---
title: Data
keywords: fastai
sidebar: home_sidebar
summary: "This module defines tools for image data preprocessing and real-time data augmentation that is used to train a model."
description: "This module defines tools for image data preprocessing and real-time data augmentation that is used to train a model."
nb_path: "nbs/02_data.ipynb"
---
Original Paper: Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature Methods 16.1 (2019): 67-70.
The code for data augmentation and mask weight generation was provided by the authors and adapted for this package.
The show methods in fastai all rely on some types being able to show themselves. We create a new type with a show method.
Example image and mask
We will use an example image and mask to guide you through the documentation.
Plot example image and mask
image = imageio.imread(path/'images'/'01.png')
mask = imageio.imread(path/'labels'/'01_mask.png')
show(image)
show(image, mask)
Deformation field class to ensure that all augmentations are applied identically to images, masks, and weights. The augmentations demonstrated below are mirroring, random deformation, and rotation.
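To illustrate why a single shared field keeps images, masks, and weights aligned, here is a minimal NumPy sketch. It is not the package's `DeformationField`: the `ToyField` class, its nearest-neighbour lookup, and the tiny test arrays are assumptions made purely for illustration.

```python
import numpy as np

class ToyField:
    """Hypothetical stand-in: stores one sampling grid and reuses it for every array."""

    def __init__(self, shape):
        # rows[i, j] and cols[i, j] hold the source coordinates of output pixel (i, j)
        self.rows, self.cols = np.meshgrid(
            np.arange(shape[0]), np.arange(shape[1]), indexing="ij")

    def mirror(self, dims):
        # Flip the sampling grid along every axis whose flag is non-zero
        if dims[0]:
            self.rows, self.cols = self.rows[::-1, :], self.cols[::-1, :]
        if dims[1]:
            self.rows, self.cols = self.rows[:, ::-1], self.cols[:, ::-1]

    def apply(self, data):
        # Nearest-neighbour lookup; the same grid is used for every input
        return data[self.rows, self.cols]

# Toy data: a small "image" and a mask derived from it
image = np.arange(12, dtype=float).reshape(3, 4)
mask = (image > 5).astype(np.uint8)

field = ToyField(image.shape)
field.mirror((1, 1))
warped_image = field.apply(image)
warped_mask = field.apply(mask)
```

Because both arrays are sampled through the same grid, the warped mask still marks exactly the pixels where the warped image exceeds 5. A real deformation field would additionally interpolate between pixels and support offsets, rotation, and random displacement.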
Original Image
inst_labels, _ = ndimage.label(mask)
inp = torch.Tensor(inst_labels)
wt = WeightTransform(channels=inp.size(-1))
weights = wt(inp.view(1, *inp.shape))[0].numpy()
tst = DeformationField(shape=(260, 260))
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)),
tst.apply(weights, offset=(270,270)))
Add mirroring
tst = DeformationField()
tst.mirror((1,1))
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)),
tst.apply(weights, offset=(270,270)))
Add random deformation
tst.addRandomDeformation()
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)),
tst.apply(weights, offset=(270,270)))
Add rotation
tst.rotate(1,1,1)
show(tst.apply(image, offset=(270,270)),
tst.apply(mask, offset=(270,270)),
tst.apply(weights, offset=(270,270)))
PyTorch map-style datasets for training and validation.
path = Path('sample_data')
files = get_image_files(path/'images')
label_fn = lambda o: path/'labels'/f'{o.stem}_mask.png'
tst = BaseDataset(files, label_fn=label_fn, fbr=0.6)
tst.show_data()
tst.clear_cached_weights()
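A map-style dataset is simply an object that implements `__len__` and `__getitem__`. The sketch below shows that protocol together with the `label_fn` naming convention used above. `ToyPairDataset` is a hypothetical class written for illustration; it does not load images or compute weights the way `BaseDataset` does.

```python
from pathlib import Path

path = Path('sample_data')
# Same convention as above: '<stem>_mask.png' inside the 'labels' folder
label_fn = lambda o: path/'labels'/f'{o.stem}_mask.png'

class ToyPairDataset:
    """Hypothetical map-style dataset returning (image path, mask path) pairs."""

    def __init__(self, files, label_fn):
        self.files = list(files)
        self.label_fn = label_fn

    def __len__(self):
        # Number of samples available
        return len(self.files)

    def __getitem__(self, idx):
        # Look up the image file and derive the matching mask file
        f = self.files[idx]
        return f, self.label_fn(f)

ds = ToyPairDataset([path/'images'/'01.png', path/'images'/'02.png'], label_fn)
img_file, mask_file = ds[0]
print(len(ds), mask_file.as_posix())  # 2 sample_data/labels/01_mask.png
```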
Args:
    csv_file (string): Path to the csv file with annotations.
    root_dir (string): Directory with all the images.
    tile_shape - The tile shape the network expects as input
    padding - The padding (input shape - output shape)
    classlabels - A list containing the corresponding class labels. 0 = ignore, 1 = background, 2-n foreground classes. If None, the problem will be treated as binary segmentation
    n_classes - The number of classes including background
    ignore - A list containing the corresponding ignore regions.
    weights - A list containing the corresponding weights.
    element_size_um - The target pixel size in micrometers
    batch_size - The number of tiles to generate per batch
    rotation_range_deg - (alpha_min, alpha_max): The range of rotation angles. A random rotation is drawn from a uniform distribution in the given range
    flip - If true, a coin flip decides whether a mirrored tile will be generated
    deformation_grid - (dx, dy): The distance of neighboring grid points in pixels for which random deformation vectors are drawn
    deformation_magnitude - (sx, sy): The standard deviations of the Gaussians from which the components of the deformation vector are drawn
    value_minimum_range - (v_min, v_max): Input intensity zero will be mapped to a random value in the given range
    value_maximum_range - (v_min, v_max): Input intensity one will be mapped to a random value within the given range
    value_slope_range - (s_min, s_max): The slope at control points is drawn from a uniform distribution in the given range
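The random draws described in the argument list can be sketched as follows. This is a simplified illustration, not the package's implementation: the helper names `sample_augmentation` and `remap_intensity` and all default values are assumptions, and the intensity mapping is reduced to a straight line (the slope control via `value_slope_range` is left out).

```python
import random

def sample_augmentation(rotation_range_deg=(0, 360), flip=True,
                        value_minimum_range=(0.0, 0.0),
                        value_maximum_range=(1.0, 1.0), rng=random):
    """Hypothetical helper: draw one set of augmentation parameters."""
    return {
        # Rotation angle drawn uniformly from (alpha_min, alpha_max)
        "angle_deg": rng.uniform(*rotation_range_deg),
        # A coin flip decides whether the tile is mirrored (only if flip is enabled)
        "mirror": flip and rng.random() < 0.5,
        # Random target values for input intensities zero and one
        "v_min": rng.uniform(*value_minimum_range),
        "v_max": rng.uniform(*value_maximum_range),
    }

def remap_intensity(x, v_min, v_max):
    """Linear simplification: maps 0 to v_min and 1 to v_max."""
    return v_min + x * (v_max - v_min)

rng = random.Random(0)  # seeded generator so the draws are repeatable
params = sample_augmentation(rotation_range_deg=(-10, 10),
                             value_minimum_range=(0.0, 0.1),
                             value_maximum_range=(0.9, 1.0), rng=rng)
darkest = remap_intensity(0.0, params["v_min"], params["v_max"])
brightest = remap_intensity(1.0, params["v_min"], params["v_max"])
```

Drawing all parameters once per tile and then applying them to the image, mask, and weights keeps the three outputs consistent.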
Show data
tst = RandomTileDataset(files, label_fn=label_fn, n_jobs=2, verbose=2, scale=2)#, albumentations_tfms=get_aug())
tst.show_data()
Show random tile (default padding = (184,184))
tile = tst[0]
show(tile[0], tile[1], tile[2])
Compute stats
tst.compute_stats()
Show data
tst = TileDataset(files, label_fn=label_fn, tile_shape=(450,450), padding=(10,10), val_length=6)
tst.show_data()
Show tiles
for i in range(len(tst)):
    print(f'Tile {i}')
    tile = tst[i]
    show(tile[0], tile[1], tile[2])
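Since the padding is defined as input shape minus output shape, each tile contributes an output region of `tile_shape - padding` pixels. The sketch below estimates how many tiles are needed to cover an image under that definition. The helper `tile_origins`, the row-major layout, and the 1000 x 1000 image size are assumptions for illustration; the actual placement used by `TileDataset` may differ.

```python
import math

def tile_origins(image_shape, tile_shape, padding):
    """Hypothetical helper: top-left corners of the output regions covering an image."""
    # Output region per tile = input tile shape minus padding
    out_shape = tuple(t - p for t, p in zip(tile_shape, padding))
    # Number of tiles needed along each axis (the last tile may overhang the image)
    counts = tuple(math.ceil(s / o) for s, o in zip(image_shape, out_shape))
    origins = [(r * out_shape[0], c * out_shape[1])
               for r in range(counts[0]) for c in range(counts[1])]
    return out_shape, origins

out_shape, origins = tile_origins(image_shape=(1000, 1000),
                                  tile_shape=(450, 450), padding=(10, 10))
print(out_shape, len(origins))  # (440, 440) 9
```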
Reconstruct masks
tst = TileDataset(files, label_fn=label_fn, tile_shape=(240,240), padding=(10,10), scale=4)
msk_tiles = [x[1] for x in tst]
msk = tst.reconstruct_from_tiles(msk_tiles)
plt.imshow(msk[0], cmap='binary_r');
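The idea behind reconstructing a mask from tiles can be illustrated with a small NumPy round trip. This sketch ignores padding and scaling and uses non-overlapping blocks. `split_into_tiles` and `stitch_tiles` are hypothetical helpers and do not reproduce `reconstruct_from_tiles`.

```python
import numpy as np

def split_into_tiles(img, out):
    """Cut a 2D array into non-overlapping blocks of up to out x out pixels (row-major)."""
    h, w = img.shape
    return [img[r:r + out, c:c + out]
            for r in range(0, h, out) for c in range(0, w, out)]

def stitch_tiles(tiles, image_shape, out):
    """Place each block back at its row-major position."""
    h, w = image_shape
    result = np.zeros(image_shape, dtype=tiles[0].dtype)
    positions = [(r, c) for r in range(0, h, out) for c in range(0, w, out)]
    for tile, (r, c) in zip(tiles, positions):
        # Edge tiles may be smaller than out x out, so use the tile's own shape
        result[r:r + tile.shape[0], c:c + tile.shape[1]] = tile
    return result

toy_mask = np.arange(35).reshape(5, 7)
toy_tiles = split_into_tiles(toy_mask, out=3)
rebuilt = stitch_tiles(toy_tiles, toy_mask.shape, out=3)
```

As long as the tiles are stitched in the same order in which they were cut, the rebuilt array matches the original exactly.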