---
title: Data
keywords: fastai
sidebar: home_sidebar
summary: "This module defines tools for image data preprocessing and real-time data augmentation used to train a model."
description: "This module defines tools for image data preprocessing and real-time data augmentation used to train a model."
nb_path: "nbs/02_data.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Original Paper: Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature Methods 16.1 (2019): 67-70.

The code for data augmentation and mask weight generation was provided by the authors and adapted for this package.

Plot images and masks

{% raw %}

show[source]

show(*obj, file_name=None, overlay=False, pred=False, show_bbox=True, figsize=(10, 10), cmap='binary_r', **kwargs)

Show image, mask, and weight (optional)

{% endraw %} {% raw %}
{% endraw %}

The show methods in fastai all rely on some types being able to show themselves. We create a new type with a show method.

Typedispatch

Custom show_batch and show_results for DataLoader
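As a rough sketch of how this fits together (the class and method names below are illustrative, not the package's actual implementation), a tuple-like type can bundle image, mask, and weights and know how to show itself; typedispatched show_batch and show_results then only need to call show on each sample:

{% raw %}
import matplotlib.pyplot as plt

# Hypothetical type for illustration: bundles (image, mask, weights)
# and knows how to plot itself on a row of matplotlib axes.
class ImageMaskWeight(tuple):
    def show(self, ctx=None, **kwargs):
        img, msk, wgt = self
        if ctx is None:
            _, ctx = plt.subplots(1, 3, figsize=(12, 4))
        for ax, arr, title in zip(ctx, (img, msk, wgt), ('image', 'mask', 'weights')):
            ax.imshow(arr, cmap='binary_r')
            ax.set_title(title)
            ax.axis('off')
        return ctx
{% endraw %}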

{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Example image and mask

We will use an example image and mask to guide you through the documentation.

Plot example image and mask

{% raw %}
image = imageio.imread(path/'images'/'01.png')
mask = imageio.imread(path/'labels'/'01_mask.png')
show(image)
show(image, mask)
{% endraw %}

Weight Calculation

We calculate the weights for the weighted softmax cross entropy loss from the given mask (class labels).

{% raw %}

calculate_weights[source]

calculate_weights(clabels=None, instlabels=None, ignore=None, n_dims=2, bws=10, fds=10, bwf=10, fbr=0.1)

Calculates the weights from the given mask (class labels clabels or instance labels instlabels).

{% endraw %} {% raw %}
{% endraw %}

Arguments in calculate_weights:

  • clabels: class labels (segmentation mask)
  • instlabels: instance labels (segmentation mask)
  • ignore: ignored regions
  • n_dims (int): number of classes for clabels
  • bws (float): border_weight_sigma in pixels
  • fds (float): foreground_dist_sigma in pixels
  • bwf (float): border_weight_factor
  • fbr (float): foreground_background_ratio
{% raw %}
labels, weights, _ = calculate_weights(clabels=mask)
titles = ['Labels (Mask)', 'Weights', 'PDF']
show(image, labels, weights)
{% endraw %}
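To make the role of bws, bwf, and fbr more concrete, here is a simplified stand-alone sketch in the spirit of the weight maps from Falk et al. (2019). It is not the package implementation: it uses only the distance to the nearest foreground pixel rather than the two nearest objects.

{% raw %}
import numpy as np
from scipy import ndimage

def weight_map_sketch(mask, bws=10, bwf=10, fbr=0.1):
    # class balancing: foreground pixels weigh 1, background pixels weigh fbr
    fg = mask > 0
    w = np.where(fg, 1.0, fbr).astype(np.float32)
    # distance of every background pixel to the nearest foreground pixel
    dist = ndimage.distance_transform_edt(~fg)
    # emphasize background pixels close to object borders with a Gaussian of
    # width bws (border_weight_sigma) scaled by bwf (border_weight_factor)
    w[~fg] += bwf * np.exp(-dist[~fg] ** 2 / (2 * bws ** 2))
    return w
{% endraw %}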

Plot different weight parameters (foreground_dist_sigma_px, border_weight_factor)

{% raw %}
{% endraw %}
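For example, the effect of these parameters can be compared by recomputing and plotting the weights for a few settings (values chosen for illustration only):

{% raw %}
for fds, bwf in [(10, 10), (5, 30)]:
    _, w, _ = calculate_weights(clabels=mask, fds=fds, bwf=bwf)
    show(image, mask, w)
{% endraw %}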

Data augmentation

The DeformationField class ensures that all augmentations are applied identically to images, masks, and weights. The implemented augmentations are:

  • rotation
  • mirroring
  • random deformation
{% raw %}

class DeformationField[source]

DeformationField(shape=(540, 540))

Creates a deformation field for data augmentation

{% endraw %} {% raw %}
{% endraw %}
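The underlying idea, sketched below with assumed names (this is not the class implementation), is to build one displaced sampling grid and resample image, mask, and weights along the same grid so they stay pixel-aligned:

{% raw %}
import numpy as np
from scipy import ndimage

def apply_field_sketch(arr, dx, dy, offset=(0, 0), order=1):
    # dx, dy: per-pixel displacement fields with the output tile shape
    h, w = dx.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    coords = [ys + dy + offset[0], xs + dx + offset[1]]
    # use order=1 for images/weights and order=0 for masks to keep labels discrete
    return ndimage.map_coordinates(arr, coords, order=order, mode='reflect')
{% endraw %}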

Original Image

{% raw %}
show(image, labels, weights)
{% endraw %}

Add mirroring

{% raw %}
tst = DeformationField()
tst.mirror((1,1))
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Add random deformation

{% raw %}
tst.addRandomDeformation()
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Add rotation

{% raw %}
tst.rotate(1,1,1)
show(tst.apply(image, offset=(270,270)), 
     tst.apply(mask, offset=(270,270)), 
     tst.apply(weights, offset=(270,270)))
{% endraw %}

Datasets

Pytorch map-style datasets for training and validation.

Helper functions

{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %}

Base Class

{% raw %}

class BaseDataset[source]

BaseDataset(*args, **kwds) :: Dataset

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader.

Note: torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

{% endraw %} {% raw %}
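A minimal example of the map-style contract described above (a toy dataset for illustration, not part of the package):

import imageio
from torch.utils.data import Dataset

class PairDataset(Dataset):
    "Toy map-style dataset returning (image, mask) pairs loaded on demand."
    def __init__(self, files, label_fn):
        self.files, self.label_fn = files, label_fn
    def __len__(self):
        return len(self.files)
    def __getitem__(self, i):
        f = self.files[i]
        return imageio.imread(f), imageio.imread(self.label_fn(f))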
{% endraw %} {% raw %}
path = Path('sample_data')
files = get_image_files(path/'images')
label_fn = lambda o: path/'labels'/f'{o.stem}_mask{o.suffix}'
tst = BaseDataset(files, label_fn=label_fn, create_weights=True)
tst.show_data()
Creating weights for 01.png
{% endraw %} {% raw %}
tst.clear_cached_weights()
Deleting all cache at sample_data/labels/.cache
{% endraw %}

RandomTileDataset

For training

{% raw %}

class RandomTileDataset[source]

RandomTileDataset(*args, **kwds) :: BaseDataset

Pytorch Dataset that creates random tiles with augmentations from the input images.

{% endraw %} {% raw %}
{% endraw %}

Arguments:

  • csv_file (string): path to the csv file with annotations
  • root_dir (string): directory with all the images
  • tile_shape: the tile shape the network expects as input
  • padding: the padding (input shape - output shape)
  • classlabels: a list containing the corresponding class labels; 0 = ignore, 1 = background, 2-n = foreground classes. If None, the problem is treated as binary segmentation
  • n_classes: the number of classes including background
  • ignore: a list containing the corresponding ignore regions
  • weights: a list containing the corresponding weights
  • element_size_um: the target pixel size in micrometers
  • batch_size: the number of tiles to generate per batch
  • rotation_range_deg: (alpha_min, alpha_max), the range of rotation angles; a random rotation is drawn from a uniform distribution in the given range
  • flip: if True, a coin flip decides whether a mirrored tile will be generated
  • deformation_grid: (dx, dy), the distance of neighboring grid points in pixels for which random deformation vectors are drawn
  • deformation_magnitude: (sx, sy), the standard deviations of the Gaussians from which the components of the deformation vectors are drawn
  • value_minimum_range: (v_min, v_max), input intensity zero will be mapped to a random value in the given range
  • value_maximum_range: (v_min, v_max), input intensity one will be mapped to a random value within the given range
  • value_slope_range: (s_min, s_max), the slope at control points is drawn from a uniform distribution in the given range
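Assuming these docstring arguments map directly to keyword arguments of the constructor, a dataset with explicit augmentation settings could be created like this (the values are arbitrary examples, not recommendations):

{% raw %}
train_ds = RandomTileDataset(
    files, label_fn=label_fn,
    tile_shape=(540, 540),           # tile size the network expects as input
    rotation_range_deg=(0, 360),     # draw a random rotation angle from this range
    flip=True,                       # randomly mirror tiles
    deformation_grid=(150, 150),     # grid spacing for random deformation vectors
    deformation_magnitude=(10, 10),  # std dev of the deformation vectors
)
{% endraw %}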

Show data

{% raw %}
tst = RandomTileDataset(files, label_fn=label_fn, n_jobs=2, verbose=2)
tst.show_data()
Creating weights for ['01.png']
[Parallel(n_jobs=2)]: Using backend ThreadingBackend with 2 concurrent workers.
[Parallel(n_jobs=2)]: Done   1 out of   1 | elapsed:    0.1s finished
{% endraw %}

Show random tile (default padding = (184,184))

{% raw %}
tile = tst[0]
show(tile[0], tile[1], tile[2])
{% endraw %}
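Because RandomTileDataset is a standard map-style dataset, it can be wrapped in a plain PyTorch DataLoader for training. A usage sketch, assuming each item is an (image, mask, weight) triple as shown above and relying on the default collate function to stack the tiles into tensors:

{% raw %}
from torch.utils.data import DataLoader

train_dl = DataLoader(tst, batch_size=4, shuffle=True, num_workers=2)
xb, yb, wb = next(iter(train_dl))
print(xb.shape, yb.shape, wb.shape)
{% endraw %}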

Compute stats

{% raw %}
tst.compute_stats()
Computing Stats...
([array([0.34229552])], [array([0.30897086])])
{% endraw %}
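These per-channel statistics can be used to normalize tiles before feeding them to the network, e.g. (illustrative only):

{% raw %}
mean, std = tst.compute_stats()
x_norm = (tile[0] - mean[0]) / std[0]  # zero-mean, unit-variance input tile
{% endraw %}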

TileDataset

{% raw %}

class TileDataset[source]

TileDataset(*args, **kwds) :: BaseDataset

Pytorch Dataset that creates tiles for validation and prediction on new data.

{% endraw %} {% raw %}
{% endraw %}

Show data

{% raw %}
tst = TileDataset(files, label_fn=label_fn, tile_shape=(450,450), padding=(10,10))
tst.show_data()
Using cached mask weights from sample_data/labels/.cache
{% endraw %}

Show tiles

{% raw %}
for i in range(len(tst)): 
    print(f'Tile {i}')
    tile = tst[i]
    show(tile[0], tile[1], tile[2])
Tile 0
Tile 1
Tile 2
Tile 3
{% endraw %}

Reconstruct masks

{% raw %}
msk_tiles = [x[1] for x in tst]
msk = tst.reconstruct_from_tiles(msk_tiles)
plt.imshow(msk[0], cmap='binary_r');
{% endraw %}
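As a quick sanity check (assuming reconstruct_from_tiles returns one array per input image), the reconstruction can be compared against the original mask:

{% raw %}
import numpy as np

orig_mask = imageio.imread(path/'labels'/'01_mask.png')
print(msk[0].shape, orig_mask.shape)
print('foregrounds match:', np.array_equal(msk[0] > 0, orig_mask > 0))
{% endraw %}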