Source code for super_gradients.training.utils.segmentation_utils

import random
from PIL import Image, ImageOps, ImageFilter
import collections
from typing import Optional, Union, Tuple, List
import math

# FIXME: REFACTOR AUGMENTATIONS; CONSIDER USING MORE EFFICIENT LIBRARIES SUCH AS IMGAUG, DALI, ETC.

# Resampling filters shared by every geometric transform in this module:
# images are interpolated bilinearly for smooth results, while masks use
# nearest-neighbour so every output pixel copies an existing label value.
image_resample, mask_resample = Image.BILINEAR, Image.NEAREST


class RandomFlip:
    """
    Mirror an image/mask pair left-to-right, applying the same flip to both.

    The flip happens with probability ``prob``; otherwise the pair is returned untouched.
    """

    def __init__(self, prob: float = 0.5):
        """
        :param prob: probability of flipping; must lie in [0, 1].
        """
        assert 0. <= prob <= 1., f"Probability value must be between 0 and 1, found {prob}"
        self.prob = prob

    def __call__(self, sample: dict):
        """
        :param sample: dict holding the ``"image"`` and ``"mask"`` entries (PIL images).
        :return: the same dict, with both entries flipped together or left as they were.
        """
        image, mask = sample["image"], sample["mask"]
        should_flip = random.random() < self.prob
        if should_flip:
            # Flip the image first, then the mask, with the identical transpose.
            image, mask = (item.transpose(Image.FLIP_LEFT_RIGHT) for item in (image, mask))
        sample["image"], sample["mask"] = image, mask
        return sample
class Rescale:
    """
    Resize an image/mask pair by one shared factor, keeping the aspect ratio.

    Exactly one sizing rule is applied. When several are given, precedence is
    ``scale_factor``, then ``short_size``, then ``long_size``.

    Args:
        scale_factor: multiply both sides by this value:
            out_size = (scale_factor * w, scale_factor * h)
        short_size: desired length of the shorter side; factor = short_size / min(h, w).
        long_size: desired length of the longer side; factor = long_size / max(h, w).
    """

    def __init__(self, scale_factor: Optional[float] = None, short_size: Optional[int] = None,
                 long_size: Optional[int] = None):
        self.scale_factor = scale_factor
        self.short_size = short_size
        self.long_size = long_size
        self.check_valid_arguments()

    def _factor_for(self, width, height):
        """Return the resize factor for a ``width`` x ``height`` input, following the rule precedence."""
        if self.scale_factor is not None:
            return self.scale_factor
        if self.short_size is not None:
            return self.short_size / min(width, height)
        return self.long_size / max(width, height)

    def __call__(self, sample: dict):
        """
        :param sample: dict holding ``"image"`` and ``"mask"``; both are resized to the same size.
        :return: the same dict with the resized image and mask.
        """
        image, mask = sample["image"], sample["mask"]
        width, height = image.size
        factor = self._factor_for(width, height)
        new_size = int(factor * width), int(factor * height)
        # The image uses the smooth filter; the mask uses nearest so labels are not blended.
        resized_image = image.resize(new_size, image_resample)
        resized_mask = mask.resize(new_size, mask_resample)
        sample["image"], sample["mask"] = resized_image, resized_mask
        return sample

    def check_valid_arguments(self):
        """
        Ensure at least one sizing rule is set and that every rule that is set is positive.

        :raises ValueError: if no rule is set or a set rule is not positive.
        """
        nothing_set = self.scale_factor is None and self.short_size is None and self.long_size is None
        if nothing_set:
            raise ValueError("Must assign one rescale argument: scale_factor, short_size or long_size")
        if self.scale_factor is not None and self.scale_factor <= 0:
            raise ValueError(f"Scale factor must be a positive number, found: {self.scale_factor}")
        if self.short_size is not None and self.short_size <= 0:
            raise ValueError(f"Short size must be a positive number, found: {self.short_size}")
        if self.long_size is not None and self.long_size <= 0:
            raise ValueError(f"Long size must be a positive number, found: {self.long_size}")
class RandomRescale:
    """
    Randomly rescale an image/mask pair by one shared factor, keeping the aspect ratio.

    On every call the factor is drawn uniformly from ``[min, max]``.

    Args:
        scales: ``(min, max)`` range for the factor. A single number ``s`` is expanded to
            ``(s, 1)`` when ``s <= 1`` and to ``(1, s)`` otherwise. Every bound must be
            strictly positive; a reversed pair is reordered.
    """

    def __init__(self, scales: Union[float, Tuple, List] = (0.5, 2.0)):
        self.scales = scales
        self.check_valid_arguments()

    def __call__(self, sample: dict):
        """
        :param sample: dict holding ``"image"`` and ``"mask"``; both are resized to the same size.
        :return: the same dict with the resized image and mask.
        """
        image = sample["image"]
        mask = sample["mask"]
        w, h = image.size
        scale = random.uniform(self.scales[0], self.scales[1])
        out_size = int(scale * w), int(scale * h)
        # The image uses the smooth filter; the mask uses nearest so labels are not blended.
        image = image.resize(out_size, image_resample)
        mask = mask.resize(out_size, mask_resample)
        sample["image"] = image
        sample["mask"] = mask
        return sample

    def check_valid_arguments(self):
        """
        Normalize ``self.scales`` into an ordered ``(min, max)`` pair and validate it.

        :return: the normalized scales (also stored on ``self.scales``).
        :raises ValueError: if any bound is not strictly positive.
        """
        if not isinstance(self.scales, collections.abc.Iterable):
            if self.scales <= 1:
                self.scales = (self.scales, 1)
            else:
                self.scales = (1, self.scales)
        # A zero bound lets the sampled factor reach 0, which collapses the output size to 0.
        # Bounds must therefore be strictly positive, as the error message already states.
        if self.scales[0] <= 0 or self.scales[1] <= 0:
            raise ValueError(f"RandomRescale scale values must be positive numbers, found: {self.scales}")
        if self.scales[0] > self.scales[1]:
            self.scales = (self.scales[1], self.scales[0])
        return self.scales
class RandomRotate:
    """
    Randomly rotate an image/mask pair (synchronously) by an angle drawn from [min_deg, max_deg].

    Areas uncovered by the rotation are filled with ``fill_image`` in the image and with
    ``fill_mask`` in the mask.
    """

    def __init__(self, min_deg: float = -10, max_deg: float = 10, fill_mask: int = 0,
                 fill_image: Union[int, Tuple, List] = 0):
        """
        :param min_deg: lower bound of the rotation angle, in degrees.
        :param max_deg: upper bound of the rotation angle, in degrees.
        :param fill_mask: integer in [0, 255] used for the mask's uncovered area.
        :param fill_image: a single integer in [0, 255] (expanded to a grey RGB triple) or an
            RGB triple of integers in [0, 255], used for the image's uncovered area.
        """
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.fill_mask = fill_mask
        # Pass the value through unchanged. _validate_fill_values_arguments already expands a
        # scalar into a grey RGB triple. Pre-wrapping it here turned an RGB tuple/list into a
        # tuple of tuples, which validation then always rejected.
        self.fill_image = fill_image
        self.check_valid_arguments()

    def __call__(self, sample: dict):
        """
        :param sample: dict holding ``"image"`` and ``"mask"``; both are rotated by the same random angle.
        :return: the same dict with the rotated image and mask.
        """
        image = sample["image"]
        mask = sample["mask"]
        deg = random.uniform(self.min_deg, self.max_deg)
        # The image uses the smooth filter; the mask uses nearest so labels are not blended.
        image = image.rotate(deg, resample=image_resample, fillcolor=self.fill_image)
        mask = mask.rotate(deg, resample=mask_resample, fillcolor=self.fill_mask)
        sample["image"] = image
        sample["mask"] = mask
        return sample

    def check_valid_arguments(self):
        """Validate the fill values and store the normalized (fill_mask, fill_image) pair."""
        self.fill_mask, self.fill_image = _validate_fill_values_arguments(self.fill_mask, self.fill_image)
class CropImageAndMask:
    """
    Crop an image/mask pair with the same window.

    In "center" mode the window is centred on the input. In "random" mode its top-left
    corner is drawn uniformly from the positions where the window fits inside the input.
    """

    def __init__(self, crop_size: Union[float, Tuple, List], mode: str):
        """
        :param crop_size: (width, height) of the crop; a scalar gives a square crop (crop_size, crop_size).
        :param mode: 'center' to crop around the centre of the input, 'random' to crop at a random position.
        """
        self.crop_size = crop_size
        self.mode = mode
        self.check_valid_arguments()

    def __call__(self, sample: dict):
        """
        :param sample: dict holding ``"image"`` and ``"mask"``; both are cropped with the same box.
        :return: the same dict with the cropped image and mask.
        """
        image, mask = sample["image"], sample["mask"]
        width, height = image.size
        crop_w, crop_h = self.crop_size[0], self.crop_size[1]
        if self.mode != "random":
            left = int(round((width - crop_w) / 2.))
            top = int(round((height - crop_h) / 2.))
        else:
            # Horizontal offset is drawn before the vertical one.
            left = random.randint(0, width - crop_w)
            top = random.randint(0, height - crop_h)
        box = (left, top, left + crop_w, top + crop_h)
        cropped_image, cropped_mask = image.crop(box), mask.crop(box)
        sample["image"], sample["mask"] = cropped_image, cropped_mask
        return sample

    def check_valid_arguments(self):
        """
        Validate the mode and expand a scalar crop size into a square (width, height) pair.

        :raises ValueError: on an unknown mode or a non-positive crop dimension.
        """
        if self.mode not in ("center", "random"):
            raise ValueError(f"Unsupported mode: found: {self.mode}, expected: 'center' or 'random'")
        if not isinstance(self.crop_size, collections.abc.Iterable):
            self.crop_size = (self.crop_size, self.crop_size)
        if self.crop_size[0] <= 0 or self.crop_size[1] <= 0:
            raise ValueError(f"Crop size must be positive numbers, found: {self.crop_size}")
class RandomGaussianBlur:
    """
    Blur the image (never the mask) with a Gaussian filter, with probability 'prob'.

    When applied, the blur radius is drawn uniformly from [0, 1).
    """

    def __init__(self, prob: float = 0.5):
        """
        :param prob: probability of blurring; must lie in [0, 1].
        """
        assert 0. <= prob <= 1., "Probability value must be between 0 and 1"
        self.prob = prob

    def __call__(self, sample: dict):
        """
        :param sample: dict holding ``"image"`` and ``"mask"``; only the image may be blurred.
        :return: the same dict.
        """
        image, mask = sample["image"], sample["mask"]
        apply_blur = random.random() < self.prob
        if apply_blur:
            radius = random.random()
            image = image.filter(ImageFilter.GaussianBlur(radius=radius))
        sample["image"], sample["mask"] = image, mask
        return sample
class PadShortToCropSize:
    """
    Pad an image/mask pair symmetrically so each side is at least as long as 'crop_size'.

    Sides that already reach the crop size are left unpadded. Intended to run after
    "Rescale" or "RandomRescale" in an augmentation pipeline.
    """

    def __init__(self, crop_size: Union[float, Tuple, List], fill_mask: int = 0,
                 fill_image: Union[int, Tuple, List] = 0):
        """
        :param crop_size: (width, height) minimum size; a scalar means a square (crop_size, crop_size).
        :param fill_mask: value used for the padded area of the mask.
        :param fill_image: single grey level or RGB triple used for the padded area of the image.
        """
        self.crop_size = crop_size
        self.fill_mask = fill_mask
        self.fill_image = fill_image
        self.check_valid_arguments()

    @staticmethod
    def _split_padding(target, current):
        """Split the shortfall of ``current`` below ``target`` into (before, after); an odd pixel goes before."""
        half = (target - current) / 2 if current < target else 0
        return math.ceil(half), math.floor(half)

    def __call__(self, sample: dict):
        """
        :param sample: dict holding ``"image"`` and ``"mask"``; both receive the same border.
        :return: the same dict with the (possibly) padded image and mask.
        """
        image, mask = sample["image"], sample["mask"]
        width, height = image.size
        target_w, target_h = self.crop_size[0], self.crop_size[1]
        if width < target_w or height < target_h:
            pad_left, pad_right = self._split_padding(target_w, width)
            pad_top, pad_bottom = self._split_padding(target_h, height)
            border = (pad_left, pad_top, pad_right, pad_bottom)
            image = ImageOps.expand(image, border=border, fill=self.fill_image)
            mask = ImageOps.expand(mask, border=border, fill=self.fill_mask)
        sample["image"], sample["mask"] = image, mask
        return sample

    def check_valid_arguments(self):
        """
        Expand a scalar crop size into a square pair, validate it, and normalize the fill values.

        :raises ValueError: on a non-positive crop dimension or invalid fill values.
        """
        if not isinstance(self.crop_size, collections.abc.Iterable):
            self.crop_size = (self.crop_size, self.crop_size)
        if self.crop_size[0] <= 0 or self.crop_size[1] <= 0:
            raise ValueError(f"Crop size must be positive numbers, found: {self.crop_size}")
        self.fill_mask, self.fill_image = _validate_fill_values_arguments(self.fill_mask, self.fill_image)
def _validate_fill_values_arguments(fill_mask: int, fill_image: Union[int, Tuple, List]): if not isinstance(fill_image, collections.abc.Iterable): # If fill_image is single value, turn to grey color in RGB mode. fill_image = (fill_image, fill_image, fill_image) elif len(fill_image) != 3: raise ValueError(f"fill_image must be an RGB tuple of size equal to 3, found: {fill_image}") # assert values are integers if not isinstance(fill_mask, int) or not all(isinstance(x, int) for x in fill_image): raise ValueError(f"Fill value must be integers," f" found: fill_image = {fill_image}, fill_mask = {fill_mask}") # assert values in range 0-255 if min(fill_image) < 0 or max(fill_image) > 255 or fill_mask < 0 or fill_mask > 255: raise ValueError(f"Fill value must be a value from 0 to 255," f" found: fill_image = {fill_image}, fill_mask = {fill_mask}") return fill_mask, fill_image