---
title: Data
keywords: fastai
sidebar: home_sidebar
summary: "This module defines tools for image data preprocessing and real-time data augmentation that are used to train a model."
description: "This module defines tools for image data preprocessing and real-time data augmentation that are used to train a model."
nb_path: "nbs/02_data.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Original Paper: Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature methods 16.1 (2019): 67-70.

The code for data augmentation and mask weight generation was provided by the authors and adapted for this package.

Plot images and masks

{% raw %}

show[source]

show(*obj, file_name=None, overlay=False, pred=False, show_bbox=True, figsize=(10, 10), cmap='binary_r', **kwargs)

Show image, mask, and weight (optional)

{% endraw %} {% raw %}
{% endraw %}

The show methods in fastai all rely on some types being able to show themselves. We create a new type with a show method.

Typedispatch

Custom show_batch and show_results for DataLoader

{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Example image and mask

We will use an example image and mask to guide you through the documentation.

Plot example image and mask

{% raw %}
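import imageio
from pathlib import Path

path = Path('sample_data')  # sample data directory, as in the Datasets section
image = imageio.imread(path/'images'/'01.png')
mask = imageio.imread(path/'labels'/'01_mask.png')
show(image)        # image only
show(image, mask)  # image together with its mask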
{% endraw %}

Data augmentation

The deformation field class ensures that every augmentation is applied identically to images, masks, and weights. The implemented augmentations are:

  • rotation
  • mirroring
  • random deformation
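
The idea of sharing one field across all inputs can be sketched as follows. This is a minimal illustration and not the package's `DeformationField`: the class name `SimpleField`, its methods, and the use of `scipy.ndimage.map_coordinates` are assumptions made for the example. The field stores sampling coordinates once, so the image and its mask are always transformed in exactly the same way.

```python
import numpy as np
from scipy import ndimage


class SimpleField:
    """Hypothetical stand-in for a deformation field (not the package's class)."""

    def __init__(self, shape):
        self.shape = shape
        axes = [np.arange(s, dtype=float) for s in shape]
        # Identity field: coords[k][i, j] is the source coordinate along axis k.
        self.coords = list(np.meshgrid(*axes, indexing="ij"))

    def mirror(self, dims):
        # Reverse the sampling coordinates along every axis flagged with 1.
        for axis, flag in enumerate(dims):
            if flag:
                self.coords[axis] = self.shape[axis] - 1 - self.coords[axis]

    def apply(self, data, order=1):
        # Sample `data` at the stored coordinates.
        # order=0 (nearest neighbour) keeps label values intact for masks.
        return ndimage.map_coordinates(data, self.coords, order=order, mode="reflect")


rng = np.random.default_rng(0)
image = rng.random((8, 8))
mask = (image > 0.5).astype(np.uint8)

field = SimpleField(image.shape)
field.mirror((1, 0))                    # mirror along the first axis only
image_aug = field.apply(image, order=1)
mask_aug = field.apply(mask, order=0)   # same field, so the mask stays aligned
```

Rotation and random deformation would modify the same stored coordinates, which is why a single field object is enough to keep all inputs aligned.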
{% raw %}

class DeformationField[source]

DeformationField(shape=(540, 540), scale=1)

Creates a deformation field for data augmentation

{% endraw %} {% raw %}
{% endraw %}

Original Image

{% raw %}
inst_labels, _ = ndimage.label(mask)  # the ndimage.measurements namespace is deprecated
inp = torch.Tensor(inst_labels)
wt = WeightTransform(channels=inp.size(-1))
weights = wt(inp.view(1, *inp.shape))[0].numpy()
tst = DeformationField(shape=(260, 260))
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Add mirroring

{% raw %}
tst = DeformationField()
tst.mirror((1,1))
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Add random deformation

{% raw %}
tst.addRandomDeformation()
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Add rotation

{% raw %}
tst.rotate(1,1,1)
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Datasets

PyTorch map-style datasets for training and validation.
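
As a reminder of what "map-style" means, the sketch below shows the two methods such a dataset has to provide. It is written without PyTorch so that it runs on its own; the class name `ToyTileDataset` and its contents are hypothetical, and the datasets in this module subclass `torch.utils.data.Dataset` instead.

```python
class ToyTileDataset:
    """Hypothetical map-style dataset: integer index -> (image, mask) pair."""

    def __init__(self, images, masks):
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images = images
        self.masks = masks

    def __len__(self):
        # Samplers and data loaders use this to know how many items exist.
        return len(self.images)

    def __getitem__(self, idx):
        # Fetch one sample for the given key.
        return self.images[idx], self.masks[idx]


ds = ToyTileDataset(["img0", "img1", "img2"], ["mask0", "mask1", "mask2"])
sample = ds[1]
```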

Helper functions

{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Base Class

{% raw %}

class BaseDataset[source]

BaseDataset(*args, **kwds) :: Dataset

An abstract class representing a `Dataset`.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite `__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite `__len__`, which is expected to return the size of the dataset by many `torch.utils.data.Sampler` implementations and the default options of `torch.utils.data.DataLoader`.

Note: `torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
path = Path('sample_data')
files = get_image_files(path/'images')
label_fn = lambda o: path/'labels'/f'{o.stem}_mask.png'
tst = BaseDataset(files, label_fn=label_fn, fbr=0.6)
tst.show_data()
Preprocessing ['01.png']
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:    0.0s finished
{% endraw %} {% raw %}
tst.clear_cached_weights()
Deleting all cache at sample_data/labels/.cache
{% endraw %}

RandomTileDataset

For training

{% raw %}

class RandomTileDataset[source]

RandomTileDataset(*args, **kwds) :: BaseDataset

PyTorch Dataset that creates random, augmented tiles from the input images.

{% endraw %} {% raw %}
{% endraw %}

Args:

  • csv_file (string): Path to the csv file with annotations.
  • root_dir (string): Directory with all the images.
  • tile_shape: The tile shape the network expects as input.
  • padding: The padding (input shape - output shape).
  • classlabels: A list containing the corresponding class labels. 0 = ignore, 1 = background, 2-n foreground classes. If None, the problem will be treated as binary segmentation.
  • n_classes: The number of classes including background.
  • ignore: A list containing the corresponding ignore regions.
  • weights: A list containing the corresponding weights.
  • element_size_um: The target pixel size in micrometers.
  • batch_size: The number of tiles to generate per batch.
  • rotation_range_deg: (alpha_min, alpha_max). The range of rotation angles. A random rotation is drawn from a uniform distribution in the given range.
  • flip: If true, a coin flip decides whether a mirrored tile will be generated.
  • deformation_grid: (dx, dy). The distance of neighboring grid points in pixels for which random deformation vectors are drawn.
  • deformation_magnitude: (sx, sy). The standard deviations of the Gaussians from which the components of the deformation vector are drawn.
  • value_minimum_range: (v_min, v_max). Input intensity zero will be mapped to a random value in the given range.
  • value_maximum_range: (v_min, v_max). Input intensity one will be mapped to a random value within the given range.
  • value_slope_range: (s_min, s_max). The slope at control points is drawn from a uniform distribution in the given range.
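
To make the augmentation arguments more concrete, here is a rough sketch of how one set of random parameters could be drawn from them. It is not the package's implementation. All values are illustrative rather than defaults, the helper `remap` is hypothetical, and the uniform draws for the intensity range are an assumption (the description only says "a random value"). The slope control from `value_slope_range` is left out.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative settings only; these are not the package defaults.
tile_shape = (540, 540)
rotation_range_deg = (0, 360)
flip = True
deformation_grid = (150, 150)
deformation_magnitude = (10, 10)
value_minimum_range = (0.0, 0.1)
value_maximum_range = (0.9, 1.0)

# Rotation: one angle drawn uniformly from the given range.
alpha = rng.uniform(*rotation_range_deg)

# Flip: a coin flip decides whether the tile is mirrored.
mirrored = bool(flip and rng.random() < 0.5)

# Deformation: one Gaussian vector per coarse grid point, with a separate
# standard deviation for each vector component.
n_points = tuple(int(np.ceil(t / d)) + 1 for t, d in zip(tile_shape, deformation_grid))
displacement = np.stack(
    [rng.normal(0.0, sigma, size=n_points) for sigma in deformation_magnitude]
)

# Intensity: input 0 maps to v0 and input 1 maps to v1 (uniform draws assumed).
v0 = rng.uniform(*value_minimum_range)
v1 = rng.uniform(*value_maximum_range)


def remap(intensity):
    # Linear map through the points (0, v0) and (1, v1).
    return v0 + (v1 - v0) * intensity
```

In a full pipeline the coarse `displacement` grid would then be interpolated to the tile resolution before being added to the sampling coordinates.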

Show data

{% raw %}
tst = RandomTileDataset(files, label_fn=label_fn, n_jobs=2, verbose=2, scale=2)#, albumentations_tfms=get_aug())
tst.show_data()
Preprocessing ['01.png']
[Parallel(n_jobs=2)]: Using backend ThreadingBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done   1 out of   1 | elapsed:    0.0s finished
{% endraw %}

Show random tile (default padding = (184,184))

{% raw %}
tile = tst[0]
show(tile[0], tile[1], tile[2])
{% endraw %}

Compute stats

{% raw %}
tst.compute_stats()
Computing Stats...
([array([0.39269134])], [array([0.33001896])])
{% endraw %}

TileDataset

{% raw %}

class TileDataset[source]

TileDataset(*args, **kwds) :: BaseDataset

PyTorch Dataset that creates random tiles for validation and prediction on new data.

{% endraw %} {% raw %}
{% endraw %}

Show data

{% raw %}
tst = TileDataset(files, label_fn=label_fn, tile_shape=(450,450), padding=(10,10), val_length=6)
tst.show_data()
Using preprocessed masks from sample_data/labels/.cache
Reducing validation from lenght 6 to 4
{% endraw %}

Show tiles

{% raw %}
for i in range(len(tst)): 
    print(f'Tile {i}')
    tile = tst[i]
    show(tile[0], tile[1], tile[2])
Tile 0
Tile 1
Tile 2
Tile 3
{% endraw %}

Reconstruct masks

{% raw %}
tst = TileDataset(files, label_fn=label_fn, tile_shape=(240,240), padding=(10,10), scale=4)
msk_tiles = [x[1] for x in tst]
msk = tst.reconstruct_from_tiles(msk_tiles)
plt.imshow(msk[0], cmap='binary_r');
Using preprocessed masks from sample_data/labels/.cache
{% endraw %}
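
The relation between `tile_shape`, `padding`, and reconstruction can be illustrated with plain NumPy. The sketch below is an assumption about the general technique and not the code behind `reconstruct_from_tiles`: the functions `split_into_tiles` and `stitch_tiles` are hypothetical, the padding is assumed to be even, and scaling is ignored. Each tile has the input size `tile_shape`, while only its central region of size `tile_shape - padding` is written back.

```python
import numpy as np


def split_into_tiles(img, tile_shape, padding):
    """Cut overlapping input tiles whose centers cover the image on a grid."""
    out_shape = tuple(t - p for t, p in zip(tile_shape, padding))
    half = tuple(p // 2 for p in padding)
    n_tiles = tuple(int(np.ceil(s / o)) for s, o in zip(img.shape, out_shape))
    # Pad so that every tile, including those at the border, is complete.
    pad_width = [
        (h, n * o - s + p - h)
        for h, n, o, s, p in zip(half, n_tiles, out_shape, img.shape, padding)
    ]
    padded = np.pad(img, pad_width, mode="reflect")
    tiles = []
    for i in range(n_tiles[0]):
        for j in range(n_tiles[1]):
            y, x = i * out_shape[0], j * out_shape[1]
            tiles.append(padded[y:y + tile_shape[0], x:x + tile_shape[1]])
    return tiles, n_tiles


def stitch_tiles(tiles, n_tiles, img_shape, tile_shape, padding):
    """Place the central region of every tile back at its grid position."""
    out_shape = tuple(t - p for t, p in zip(tile_shape, padding))
    half = tuple(p // 2 for p in padding)
    full = np.zeros(tuple(n * o for n, o in zip(n_tiles, out_shape)), dtype=tiles[0].dtype)
    for idx, tile in enumerate(tiles):
        i, j = divmod(idx, n_tiles[1])
        center = tile[half[0]:half[0] + out_shape[0], half[1]:half[1] + out_shape[1]]
        full[i * out_shape[0]:(i + 1) * out_shape[0],
             j * out_shape[1]:(j + 1) * out_shape[1]] = center
    # Crop away the area that only existed because of the padding.
    return full[:img_shape[0], :img_shape[1]]


img = np.arange(50 * 70).reshape(50, 70)
tiles, n_tiles = split_into_tiles(img, tile_shape=(30, 30), padding=(10, 10))
restored = stitch_tiles(tiles, n_tiles, img.shape, tile_shape=(30, 30), padding=(10, 10))
```

Because the central regions do not overlap, stitching them side by side recovers the original array exactly in this simplified setting.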