---
title: Models
keywords: fastai
sidebar: home_sidebar
summary: "Pytorch segmentation models."
description: "Pytorch segmentation models."
nb_path: "nbs/01_models.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %}

Segmentation Models Pytorch Integration

From the website:

  • High-level API (just two lines to create a neural network)
  • 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 104 available encoders
  • All encoders have pre-trained weights for faster and better convergence

See https://github.com/qubvel/segmentation_models.pytorch for API details.

{% raw %}

get_pretrained_options

get_pretrained_options(encoder_name)

Return available options for pretrained weights for a given encoder
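A minimal sketch of the kind of lookup `get_pretrained_options` performs, assuming the encoder registry is shaped like the `encoders` dictionary in `segmentation_models_pytorch.encoders` (encoder name mapped to a dict with a `pretrained_settings` entry). The registry below is a hypothetical, hand-written stand-in so the example runs without smp installed, and `list_pretrained_options` is an illustrative name, not part of deepflash2.

```python
# Hypothetical stand-in for smp's encoder registry:
# encoder name -> {"pretrained_settings": {weights_name: settings}}.
# Illustrative entries only; consult smp for the real list.
ENCODER_REGISTRY = {
    "resnet34": {"pretrained_settings": {"imagenet": {}}},
    "resnet50": {"pretrained_settings": {"imagenet": {}, "ssl": {}, "swsl": {}}},
}


def list_pretrained_options(encoder_name, registry=ENCODER_REGISTRY):
    """Return the names of the pretrained weight sets available for an encoder."""
    if encoder_name not in registry:
        raise KeyError(f"Unknown encoder {encoder_name!r}; available: {sorted(registry)}")
    return list(registry[encoder_name]["pretrained_settings"])


print(list_pretrained_options("resnet50"))  # ['imagenet', 'ssl', 'swsl']
```

The returned names are the values one would pass as `encoder_weights` when creating an smp model.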

{% endraw %} {% raw %}
{% endraw %} {% raw %}

create_smp_model

create_smp_model(arch, **kwargs)

Create segmentation_models_pytorch model
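`create_smp_model` takes the architecture as a string and forwards the remaining keyword arguments (`encoder_name`, `encoder_weights`, `in_channels`, `classes`, ...) to the model constructor. The sketch below illustrates that name-to-class dispatch under stated assumptions: `StubModel`, `Unet`, `FPN`, `REGISTRY` and `create_model` are hypothetical stand-ins so the example runs on its own. With smp installed, the registry values would be the real classes such as `smp.Unet` or `smp.FPN`.

```python
class StubModel:
    """Hypothetical stand-in for an smp model class; it only records its arguments."""

    def __init__(self, encoder_name="resnet34", encoder_weights="imagenet",
                 in_channels=3, classes=1):
        # Defaults chosen to mirror smp.Unet's constructor defaults
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights
        self.in_channels = in_channels
        self.classes = classes


class Unet(StubModel):
    pass


class FPN(StubModel):
    pass


# Architecture name -> constructor; with smp this could be {"Unet": smp.Unet, ...}
REGISTRY = {cls.__name__: cls for cls in (Unet, FPN)}


def create_model(arch, registry=REGISTRY, **kwargs):
    """Look up `arch` by name and pass all remaining keyword arguments through."""
    if arch not in registry:
        raise ValueError(f"Unknown architecture {arch!r}; choose one of {sorted(registry)}")
    return registry[arch](**kwargs)


model = create_model("Unet", encoder_weights=None, in_channels=1, classes=2)
```

Validating the name up front gives a readable error instead of a `KeyError` deep inside the lookup.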

{% endraw %} {% raw %}
{% endraw %} {% raw %}
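import torch
from fastcore.test import test_eq
from deepflash2.models import ARCHITECTURES, ENCODERS, create_smp_model

bs = 2
tile_shapes = [512]       # alternative: 1024
in_channels = [1]         # alternatives: 1, 3, 4
classes = [2]             # alternatives: 2, 5
encoders = ENCODERS[1:2]  # optionally: + ENCODERS[-1:]

# Every architecture/encoder combination must map an input of shape
# (bs, in_c, ts, ts) to an output of shape (bs, c, ts, ts).
for ts in tile_shapes:
    for in_c in in_channels:
        for c in classes:
            inp = torch.randn(bs, in_c, ts, ts)
            out_shape = [bs, c, ts, ts]
            for arch in ARCHITECTURES:
                for encoder_name in encoders:
                    # encoder_weights=None initializes randomly (no weight download)
                    model = create_smp_model(arch=arch,
                                             encoder_name=encoder_name,
                                             encoder_weights=None,
                                             in_channels=in_c,
                                             classes=c)
                    with torch.no_grad():  # forward pass only, no gradients needed
                        out = model(inp)
                    test_eq(out.shape, out_shape)
del model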
{% endraw %} {% raw %}

save_smp_model

save_smp_model(model, arch, file, stats=None, pickle_protocol=2)

Save smp model, optionally including stats
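One way to read the `save_smp_model` signature: a single file holds the weights together with what is needed to rebuild the model (architecture name and constructor arguments) plus the optional `stats`. The sketch below uses plain `pickle` on an in-memory buffer instead of `torch.save`, and its key names (`model`, `arch`, `stats`, `kwargs`) and the helper `save_checkpoint` are hypothetical, not deepflash2's actual file layout. `pickle_protocol=2` matches the default in the signature; `torch.save` accepts a `pickle_protocol` argument as well.

```python
import io
import pickle


def save_checkpoint(state_dict, arch, fileobj, stats=None, pickle_protocol=2, **model_kwargs):
    """Hypothetical sketch: pickle weights, architecture name,
    constructor kwargs and optional stats into one dictionary."""
    checkpoint = {
        "model": state_dict,     # parameter name -> values
        "arch": arch,            # architecture name, e.g. "Unet"
        "stats": stats,          # optional extra data stored with the weights
        "kwargs": model_kwargs,  # arguments needed to rebuild the model
    }
    pickle.dump(checkpoint, fileobj, protocol=pickle_protocol)
    return checkpoint


buf = io.BytesIO()
saved = save_checkpoint({"conv.weight": [0.1, 0.2]}, "Unet", buf,
                        stats=(1, 1), encoder_name="resnet34")
```

Storing the constructor arguments next to the weights is what lets a loader recreate the model without the caller repeating them.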

{% endraw %} {% raw %}
{% endraw %} {% raw %}
from deepflash2.models import create_smp_model, save_smp_model

arch = 'Unet'
file = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
save_smp_model(tst, arch, file, stats=stats)  # save model and stats to tst.pth
{% endraw %} {% raw %}

load_smp_model

load_smp_model(file, device=None, strict=True, **kwargs)

Loads smp model from file
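Loading reverses the process: rebuild the model from the stored architecture name and arguments, copy the weights in, and hand the stats back. In the sketch below `ToyModel` and `load_checkpoint` are hypothetical stand-ins, and the checkpoint keys (`model`, `arch`, `stats`, `kwargs`) are assumptions. The `strict` flag mimics PyTorch's `Module.load_state_dict(strict=...)`, which rejects missing or unexpected parameter names when `strict=True`.

```python
import io
import pickle


class ToyModel:
    """Hypothetical stand-in for an smp model: a dict of named parameters."""

    def __init__(self, arch, **kwargs):
        self.arch = arch
        self.kwargs = kwargs
        self.params = {"encoder.weight": 0.0, "decoder.weight": 0.0}

    def load_state_dict(self, state, strict=True):
        missing = set(self.params) - set(state)
        unexpected = set(state) - set(self.params)
        # Like PyTorch, strict mode refuses any key mismatch
        if strict and (missing or unexpected):
            raise RuntimeError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
        for name in self.params.keys() & state.keys():
            self.params[name] = state[name]


def load_checkpoint(fileobj, strict=True):
    """Rebuild the model from stored arch/kwargs, load weights, return (model, stats)."""
    checkpoint = pickle.load(fileobj)
    model = ToyModel(checkpoint["arch"], **checkpoint["kwargs"])
    model.load_state_dict(checkpoint["model"], strict=strict)
    return model, checkpoint["stats"]


# Round trip through an in-memory file
buf = io.BytesIO()
pickle.dump({"model": {"encoder.weight": 1.5, "decoder.weight": -2.0},
             "arch": "Unet", "stats": (1, 1), "kwargs": {"encoder_name": "resnet34"}},
            buf, protocol=2)
buf.seek(0)
model, stats = load_checkpoint(buf)
```

With `strict=False`, parameters absent from the file simply keep their initial values.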

{% endraw %} {% raw %}
{% endraw %} {% raw %}
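from fastcore.test import test_eq
from deepflash2.models import load_smp_model

tst2, stats2 = load_smp_model(file)
# The reloaded model must have identical parameters and stats
for p1, p2 in zip(tst.parameters(), tst2.parameters()):
    test_eq(p1.detach(), p2.detach())
test_eq(stats, stats2)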
{% endraw %}

Cellpose integration

deepflash2 integrates cellpose for reliable cell and nucleus segmentation. Visit the cellpose repository (https://github.com/MouseLand/cellpose) for more information.

The cellpose integration in deepflash2 has been tested with cellpose version 0.6.6.dev13+g316927e.

{% raw %}

check_cellpose_installation

check_cellpose_installation()
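`check_cellpose_installation` is listed without a docstring. One plausible way to implement such a check, sketched here as an assumption rather than the deepflash2 code, is to ask the import system whether the package is available and which version is installed, so it can be compared with the tested cellpose version 0.6.6.dev13+g316927e. `is_importable` and `installed_version` are hypothetical helper names.

```python
import importlib.util
from importlib import metadata  # Python 3.8+

TESTED_VERSION = "0.6.6.dev13+g316927e"  # version the integration was tested with


def is_importable(package):
    """True if a top-level package can be imported in this environment."""
    return importlib.util.find_spec(package) is not None


def installed_version(distribution):
    """Installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


if not is_importable("cellpose"):
    print("cellpose is not installed")
elif installed_version("cellpose") != TESTED_VERSION:
    print(f"cellpose {installed_version('cellpose')} found; "
          f"the integration was tested on {TESTED_VERSION}")
```

`find_spec` only locates the package without importing it, so the check stays cheap even for heavy dependencies.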

{% endraw %} {% raw %}
{% endraw %} {% raw %}

get_diameters

get_diameters(masks)

Get diameters from deepflash2 prediction
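Cellpose expects an object diameter in pixels, and `get_diameters` derives it from deepflash2 masks. A common estimate, matching how cellpose's own `utils.diameters` works, is the median over all objects of the equivalent-circle diameter 2 * sqrt(A / pi) for an object of area A. The sketch below assumes `labels` is already instance-labelled (0 = background, 1..N = object ids); a binary mask would first need connected-component labelling, e.g. with `scipy.ndimage.label`. `median_diameter` is a hypothetical name, not the deepflash2 function.

```python
import numpy as np


def median_diameter(labels):
    """Median equivalent-circle diameter (pixels) of the objects in `labels`.

    labels: integer array, 0 = background and 1..N = object ids (assumed).
    Returns 0.0 if there are no objects.
    """
    _, areas = np.unique(labels[labels > 0], return_counts=True)
    if areas.size == 0:
        return 0.0
    # A circle with area A has diameter 2 * sqrt(A / pi)
    return float(np.median(2.0 * np.sqrt(areas / np.pi)))


labels = np.zeros((10, 10), dtype=int)
labels[0:4, 0:4] = 1  # 16-pixel object
labels[5:9, 5:9] = 2  # another 16-pixel object
print(median_diameter(labels))  # about 4.51
```

The median keeps a few merged or fragmented objects from skewing the estimate.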

{% endraw %} {% raw %}
{% endraw %} {% raw %}

run_cellpose

run_cellpose(probs, masks, model_type='nuclei', diameter=0, min_size=-1, gpu=True)

Run cellpose on deepflash2 predictions

{% endraw %} {% raw %}
{% endraw %} {% raw %}
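import numpy as np
from fastcore.test import test_eq
from deepflash2.models import run_cellpose

probs = [np.random.rand(512,512)]   # one random probability map
masks = [x>0. for x in probs]       # boolean foreground masks
cp_preds = run_cellpose(probs, masks, diameter=17.)
test_eq(probs[0].shape, cp_preds[0].shape)  # predictions keep the input shape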
Using diameter of 17.0
2021-10-03 21:17:53,554 [INFO] ** TORCH CUDA version installed and working. **
2021-10-03 21:17:53,554 [INFO] >>>> using GPU
2021-10-03 21:17:53,598 [INFO] ~~~ FINDING MASKS ~~~
2021-10-03 21:17:55,255 [INFO] >>>> TOTAL TIME 1.66 sec
{% endraw %}