---
title: Models
keywords: fastai
sidebar: home_sidebar
summary: "Pytorch segmentation models."
description: "Pytorch segmentation models."
nb_path: "nbs/01_models.ipynb"
---
The models are built with the segmentation_models.pytorch library. See https://github.com/qubvel/segmentation_models.pytorch for API details.
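```python
import torch
from fastcore.test import test_eq
# Defined in this notebook; assumed to be exported as deepflash2.models
from deepflash2.models import ARCHITECTURES, ENCODERS, create_smp_model

bs = 2
tile_shapes = [512]  # also try: 1024
in_channels = [1]  # e.g. 1, 3, 4
classes = [2]  # e.g. 2, 5
encoders = ENCODERS[1:2]  # + ENCODERS[-1:]

# Every architecture/encoder combination must map an input of shape
# (bs, in_channels, ts, ts) to logits of shape (bs, classes, ts, ts)
for ts in tile_shapes:
    for in_c in in_channels:
        for c in classes:
            inp = torch.randn(bs, in_c, ts, ts)
            out_shape = [bs, c, ts, ts]
            for arch in ARCHITECTURES:
                for encoder_name in encoders:
                    model = create_smp_model(arch=arch,
                                             encoder_name=encoder_name,
                                             encoder_weights=None,
                                             in_channels=in_c,
                                             classes=c)
                    out = model(inp)
                    test_eq(out.shape, out_shape)
                    del model  # free memory before the next combination
```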
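The shape test above relies on two properties of a segmentation model: the output keeps the input resolution, and the channel axis switches from `in_channels` to `classes`. The sketch below makes both explicit. The helpers `expected_output_shape` and `is_valid_tile_shape` are hypothetical, not part of deepflash2, and the divisibility rule is an assumption: with the default encoder depth of 5, segmentation_models.pytorch encoders downsample by 2**5 = 32, so tile sides that are not multiples of 32 are expected to fail.

```python
def expected_output_shape(bs, classes, tile_shape):
    """Shape a segmentation model should return for a square tile:
    one logit map per class at the input resolution."""
    return [bs, classes, tile_shape, tile_shape]


def is_valid_tile_shape(tile_shape, encoder_depth=5):
    """Assumption: each encoder stage halves the spatial size, so the tile
    side must be divisible by 2**encoder_depth for the decoder to
    upsample back to the input resolution."""
    return tile_shape % (2 ** encoder_depth) == 0


# Check the tile sizes used (and commented out) in the test, plus a bad one
for ts in [512, 1024, 500]:
    print(ts, is_valid_tile_shape(ts), expected_output_shape(2, 2, ts))
```

Both 512 and 1024 pass the check, which is consistent with the tile sizes the test uses.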
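```python
from fastcore.test import test_eq
# Defined in this notebook; assumed to be exported as deepflash2.models
from deepflash2.models import create_smp_model, save_smp_model, load_smp_model

arch = 'Unet'
path = 'tst.pth'
stats = (1, 1)
kwargs = {'encoder_name': 'resnet34'}

# Save a model together with its stats, reload it and compare the round trip
tst = create_smp_model(arch, **kwargs)
path = save_smp_model(tst, arch, path, stats=stats)
tst2, stats2 = load_smp_model(path)
for p1, p2 in zip(tst.parameters(), tst2.parameters()):
    test_eq(p1.detach(), p2.detach())
test_eq(stats, stats2)
path.unlink()  # the returned path is a pathlib.Path; remove the test file
```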
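The round trip works because a single file carries everything needed to rebuild the model: the architecture name, the constructor kwargs, the weights and the normalization stats. The sketch below illustrates that idea only. It is not deepflash2's implementation: `save_checkpoint` and `load_checkpoint` are hypothetical, and it uses `pickle` with NumPy arrays in place of `torch.save` and tensors so that it runs without PyTorch.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def save_checkpoint(state_dict, arch, path, stats=None, **kwargs):
    """Bundle weights with the metadata needed to rebuild the model:
    architecture name, constructor kwargs and normalization stats."""
    path = Path(path)
    payload = {"arch": arch, "kwargs": kwargs, "stats": stats,
               "state_dict": state_dict}
    with path.open("wb") as f:
        pickle.dump(payload, f)
    return path


def load_checkpoint(path):
    """Return (arch, kwargs, state_dict, stats) from a saved checkpoint."""
    with Path(path).open("rb") as f:
        payload = pickle.load(f)
    return payload["arch"], payload["kwargs"], payload["state_dict"], payload["stats"]


# Hypothetical weights standing in for model.state_dict()
weights = {"encoder.conv1.weight": np.random.rand(64, 1, 7, 7)}
with tempfile.TemporaryDirectory() as tmp:
    p = save_checkpoint(weights, "Unet", Path(tmp) / "tst.pkl",
                        stats=(1, 1), encoder_name="resnet34")
    arch, kwargs, state, stats = load_checkpoint(p)
```

With PyTorch, the loader would then rebuild the network from `arch` and `kwargs` and load `state` into it; keeping the stats next to the weights means inference can normalize inputs exactly as during training.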
Cellpose is a generalist algorithm for reliable cell and nucleus segmentation. Visit [cellpose](https://www.cellpose.org) for more information.

Cellpose integration for deepflash2 is tested on version 0.6.6.dev13+g316927e.
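Because the integration is only tested on one Cellpose version, it helps to know which version is installed. The sketch below shows one way to look that up with the standard library. It is not how `check_cellpose_installation` works internally; `installed_version` and `check_version` are hypothetical helpers.

```python
import importlib.metadata
import importlib.util

TESTED_VERSION = "0.6.6.dev13+g316927e"  # version named in the text


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    if importlib.util.find_spec(package) is None:
        return None
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        # Importable but without distribution metadata (e.g. a source checkout)
        return None


def check_version(package, tested=TESTED_VERSION):
    """Warn if `package` is missing or differs from the tested version."""
    version = installed_version(package)
    if version is None:
        print(f"{package} is not installed")
    elif version != tested:
        print(f"{package} {version} found; integration was tested on {tested}")
    return version


check_version("cellpose")
```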
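```python
import numpy as np
from fastcore.test import test_eq
# Defined in this notebook; assumed to be exported as deepflash2.models
from deepflash2.models import check_cellpose_installation, run_cellpose

check_cellpose_installation()
```

```python
# Random foreground probabilities; the mask marks every pixel with p > 0
probs = [np.random.rand(512, 512)]
masks = [x > 0. for x in probs]
# The Cellpose predictions must keep the spatial shape of the input
cp_preds = run_cellpose(probs, masks, diameter=17.)
test_eq(probs[0].shape, cp_preds[0].shape)
```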
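The `diameter` argument sets the expected object size in pixels. To our understanding, Cellpose 0.6 resizes each image by `diam_mean / diameter`, where `diam_mean` is 30 px for the `cyto` model and 17 px for the `nuclei` model; whether `run_cellpose` passes `diameter` through unchanged is an assumption. The hypothetical helpers below only reproduce that arithmetic.

```python
def rescale_factor(diameter, diam_mean=30.0):
    """Factor by which an image is resized so that objects of `diameter`
    pixels match the model's mean training diameter `diam_mean`."""
    if diameter <= 0:
        raise ValueError("diameter must be positive")
    return diam_mean / diameter


def rescaled_shape(shape, diameter, diam_mean=30.0):
    """Image shape after resizing by rescale_factor (rounded to pixels)."""
    f = rescale_factor(diameter, diam_mean)
    return tuple(round(s * f) for s in shape)


# diameter=17. as in the test, for the cyto (30 px) and nuclei (17 px) defaults
print(rescale_factor(17.0))
print(rescale_factor(17.0, diam_mean=17.0))
print(rescaled_shape((512, 512), 15.0))
```

With the nuclei default, `diameter=17.` leaves the image size unchanged, while smaller diameters enlarge the image and therefore the compute cost.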