"""Source code for super_gradients.training.utils.segmentation_utils."""

import os
import cv2
import numpy as np
from typing import Union, Callable
import torch
import torch.nn.functional as F
from torchvision.utils import draw_segmentation_masks

# FIXME: REFACTOR AUGMENTATIONS; CONSIDER USING MORE EFFICIENT LIBRARIES SUCH AS IMGAUG, DALI, ETC.
from super_gradients.training import utils as core_utils


def coco_sub_classes_inclusion_tuples_list():
    """
    List the (id, name) pairs of the COCO sub-classes to keep.

    The 20 non-background names match the Pascal VOC class set; the ids are presumably COCO category ids
    (per the function name) — TODO confirm against the dataset code that consumes this list.

    :return: a new list of ``(class_id, class_name)`` tuples, starting with ``(0, 'background')``.
    """
    inclusion_pairs = (
        (0, 'background'),
        (5, 'airplane'),
        (2, 'bicycle'),
        (16, 'bird'),
        (9, 'boat'),
        (44, 'bottle'),
        (6, 'bus'),
        (3, 'car'),
        (17, 'cat'),
        (62, 'chair'),
        (21, 'cow'),
        (67, 'dining table'),
        (18, 'dog'),
        (19, 'horse'),
        (4, 'motorcycle'),
        (1, 'person'),
        (64, 'potted plant'),
        (20, 'sheep'),
        (63, 'couch'),
        (7, 'train'),
        (72, 'tv'),
    )
    # A fresh list per call, so callers may mutate the result freely.
    return list(inclusion_pairs)
def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None):
    """
    Target label to one_hot tensor.

    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one
        hot result.
    :param ignore_index: optional label to exclude; pixels with this label get an all-zero one-hot vector.
        If ``0 <= ignore_index <= num_classes``, labels are expected to lie in ``[0, num_classes]``. The ignore label
        is then one of those values (typically ``ignore_index == num_classes``), and its channel is dropped.
        Any other value (e.g. ``255`` or ``-100``) is first remapped to ``num_classes``, then dropped the same way.
    :return: one hot tensor with shape [N, num_classes, H, W]
    """
    if ignore_index is None:
        return F.one_hot(target, num_classes).permute((0, 3, 1, 2))

    if not 0 <= ignore_index <= num_classes:
        # F.one_hot rejects labels >= the number of channels and negative labels. Route out-of-range ignore
        # labels to the extra channel at index `num_classes`, which is removed below.
        # masked_fill returns a new tensor, so the caller's target is left untouched.
        target = target.masked_fill(target == ignore_index, num_classes)
        ignore_index = num_classes

    # One extra channel holds the ignore label.
    one_hot = F.one_hot(target, num_classes + 1).permute((0, 3, 1, 2))
    # remove ignore_index channel
    return torch.cat([one_hot[:, :ignore_index], one_hot[:, ignore_index + 1:]], dim=1)
def reverse_imagenet_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    """
    Undo ImageNet mean/std normalization and convert a batch to cv2 layout.

    The input tensor is never modified.

    :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
    :return: images in a batch in cv2 format, BGR, (B, H, W, C), dtype uint8 (values clipped to [0, 255])
    """
    # `.numpy()` on a CPU tensor shares its storage. Copy before the in-place arithmetic below,
    # otherwise the caller's tensor would be silently overwritten.
    im_np = im_tensor.detach().cpu().numpy().copy()
    # RGB -> BGR, then channels-first -> channels-last.
    im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
    # ImageNet std/mean are listed in RGB order; reverse them to match the BGR channel order.
    im_np *= np.array([[[.229, .224, .225][::-1]]])
    im_np += np.array([[[.485, .456, .406][::-1]]])
    im_np *= 255.
    # Clip before casting: converting out-of-range floats to uint8 is not well defined.
    np.clip(im_np, 0, 255, out=im_np)
    return np.ascontiguousarray(im_np, dtype=np.uint8)
class BinarySegmentationVisualization:
    """Render binary-segmentation predictions against targets as TP / FP / FN color overlays."""

    @staticmethod
    def _visualize_image(image_np: np.ndarray, pred_mask: np.ndarray, target_mask: np.ndarray, image_scale: float,
                         checkpoint_dir: str, image_name: str):
        """
        Overlay true-positive, false-positive and false-negative pixels on a single image.

        :param image_np: image with channels last, (H, W, C), as produced by the batch's undo-preprocessing function
        :param pred_mask: predicted foreground probabilities, (H, W); pixels > 0.5 count as foreground
        :param target_mask: ground-truth mask, (H, W); non-zero pixels count as foreground
        :param image_scale: scale factor applied to the output image (nearest-neighbour resize)
        :param checkpoint_dir: directory to write ``{image_name}.jpg`` into; if None the image is returned instead
        :param image_name: file name (without extension) used when writing
        :return: the overlay image (H', W', 3) uint8 when ``checkpoint_dir`` is None, otherwise None
        """
        image_chw = torch.from_numpy(np.moveaxis(image_np, -1, 0).astype(np.uint8))
        # Thresholding creates new arrays, so the inputs are never mutated.
        pred_fg = pred_mask[np.newaxis, :, :] > 0.5
        target_fg = target_mask[np.newaxis, :, :].astype(bool)
        tp_mask = np.logical_and(pred_fg, target_fg)
        fp_mask = np.logical_and(pred_fg, np.logical_not(target_fg))
        fn_mask = np.logical_and(np.logical_not(pred_fg), target_fg)
        overlay = torch.from_numpy(np.concatenate([tp_mask, fp_mask, fn_mask]))
        # Colors are named in RGB, but with the default undo function the image is BGR. 'red' and 'blue' are
        # therefore painted into swapped channels: in the cv2-saved image FP pixels show as blue, FN as red.
        colors = ['green', 'red', 'blue']
        res_image = draw_segmentation_masks(image_chw, overlay, colors=colors).detach().numpy()
        # (C, H, W) -> (H, W, C). Make it C-contiguous explicitly, because cv2 expects a contiguous buffer.
        res_image = np.ascontiguousarray(res_image.transpose(1, 2, 0), dtype=np.uint8)
        res_image = cv2.resize(res_image, (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)
        if checkpoint_dir is None:
            return res_image
        cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + '.jpg'), res_image)

    @staticmethod
    def visualize_batch(image_tensor: torch.Tensor, pred_mask: torch.Tensor, target_mask: torch.Tensor,
                        batch_name: Union[int, str], checkpoint_dir: str = None,
                        undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = reverse_imagenet_preprocessing,
                        image_scale: float = 1.):
        """
        Visualize binary-segmentation predictions for one batch.

        Each image gets a TP / FP / FN overlay (see ``_visualize_image``). When ``checkpoint_dir`` is given, images are
        saved as ``{batch_name}_{image_idx_in_the_batch}.jpg``; otherwise they are returned.

        :param image_tensor: preprocessed images, (B, C, H, W), in the form accepted by ``undo_preprocessing_func``
        :param pred_mask: raw network output, (B, C, H, W). Only channel 0 is used; it is passed through a sigmoid,
            i.e. treated as foreground logits.
        :param target_mask: ground-truth masks, (B, H, W); non-zero pixels count as foreground
        :param batch_name: id of the current batch, used as the file-name prefix
        :param checkpoint_dir: directory where images are saved. If None, the result images are returned as a list
            of numpy image arrays.
        :param undo_preprocessing_func: converts the preprocessed image batch into cv2-like images (B, H, W, C)
        :param image_scale: scale factor for output image
        :return: list of overlay images when ``checkpoint_dir`` is None, otherwise an empty list
        """
        image_np = undo_preprocessing_func(image_tensor.detach())
        pred_probs = torch.sigmoid(pred_mask[:, 0, :, :])
        out_images = []
        for i, image in enumerate(image_np):
            preds = pred_probs[i].detach().cpu().numpy()
            targets = target_mask[i].detach().cpu().numpy()
            image_name = '_'.join([str(batch_name), str(i)])
            res_image = BinarySegmentationVisualization._visualize_image(image, preds, targets, image_scale,
                                                                         checkpoint_dir, image_name)
            if res_image is not None:
                out_images.append(res_image)
        return out_images
def visualize_batches(dataloader, module, visualization_path, num_batches=1, undo_preprocessing_func=None,
                      device='cuda:0'):
    """
    Run ``module`` on the first ``num_batches`` batches of ``dataloader`` and save binary-segmentation overlays.

    :param dataloader: iterable yielding ``(images, targets)`` pairs
    :param module: model whose output is passed as ``pred_mask`` to
        ``BinarySegmentationVisualization.visualize_batch`` (channel 0 is visualized)
    :param visualization_path: directory where images are written (created if missing)
    :param num_batches: number of batches to visualize
    :param undo_preprocessing_func: optional function converting the image batch into cv2-like images. When falsy,
        the default of ``BinarySegmentationVisualization.visualize_batch`` is used.
    :param device: device each batch is moved to (via ``core_utils.tensor_container_to_device``) before the forward
        pass. It should match the device ``module`` lives on. Defaults to ``'cuda:0'``, the previously hard-coded
        value.
    """
    os.makedirs(visualization_path, exist_ok=True)
    target_device = torch.device(device)
    # Only forward the undo function when one was given, so visualize_batch keeps its own default otherwise.
    extra_kwargs = {'undo_preprocessing_func': undo_preprocessing_func} if undo_preprocessing_func else {}
    for batch_i, (imgs, targets) in enumerate(dataloader):
        if batch_i == num_batches:
            return
        imgs = core_utils.tensor_container_to_device(imgs, target_device)
        targets = core_utils.tensor_container_to_device(targets, target_device)
        # Visualization never needs gradients; skip building the autograd graph to save memory.
        # NOTE(review): the module's train/eval mode is left untouched, so e.g. BatchNorm stats may still update.
        with torch.no_grad():
            pred_mask = module(imgs)
        # Visualize the batch
        BinarySegmentationVisualization.visualize_batch(imgs, pred_mask, targets, batch_i, visualization_path,
                                                        **extra_kwargs)
def one_hot_to_binary_edge(x: torch.Tensor, kernel_size: int, flatten_channels: bool = True) -> torch.Tensor:
    """
    Build binary edge maps from one-hot features as morphological dilation minus erosion.

    :param x: one-hot input tensor of shape [B, C, H, W]
    :param kernel_size: odd, positive side length of the square window. Resulting edges are
        `edge_width = kernel - 1` pixels wide.
    :param flatten_channels: when `True`, combine per-channel edges with a logical OR and return [B, 1, H, W]:
        a pixel is an edge pixel if it is one for at least one class. When `False`, return [B, C, H, W].
        Default is `True`.
    :return: float tensor of binary (0 / 1) edge values.
    """
    if kernel_size % 2 == 0 or kernel_size < 0:
        raise ValueError(f"kernel size must be an odd positive values, such as [1, 3, 5, ..], found: {kernel_size}")

    channels = x.size(1)
    window = torch.ones(channels, 1, kernel_size, kernel_size, dtype=torch.float32, device=x.device)
    half = (kernel_size - 1) // 2

    def any_in_window(features: torch.Tensor) -> torch.Tensor:
        # Per-channel (groups=channels) box filter clamped to [0, 1]: 1 wherever the window touches a positive value.
        return torch.clamp(F.conv2d(features, window, groups=channels), 0, 1)

    # Replicate padding prevents class shifting and spurious edges at the image borders.
    padded = F.pad(x.float(), mode="replicate", pad=[half] * 4)

    # Dilation grows every one-hot region by `half` pixels. Erosion shrinks it by the same amount:
    # dilate the complement, then complement the result back.
    dilated = any_in_window(padded)
    eroded = 1 - any_in_window(1 - padded)

    # 1D illustration with window [1, 1, 1]:
    #   input:    [0, 0, 0, 1, 1, 1, 0, 0, 0]
    #   dilated:  [0, 0, 1, 1, 1, 1, 1, 0, 0]
    #   eroded:   [0, 0, 0, 0, 1, 0, 0, 0, 0]
    #   edge:     [0, 0, 1, 1, 0, 1, 1, 0, 0]
    edges = dilated - eroded
    if not flatten_channels:
        return edges
    # Max across channels equals logical OR for binary {0, 1} values.
    return edges.max(dim=1, keepdim=True).values
def target_to_binary_edge(target: torch.Tensor,
                          num_classes: int,
                          kernel_size: int,
                          ignore_index: int = None,
                          flatten_channels: bool = True) -> torch.Tensor:
    """
    Build binary edge maps directly from a label map.

    The labels are one-hot encoded with ``to_one_hot``, and edges are then extracted with ``one_hot_to_binary_edge``.

    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one
        hot result.
    :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument
        as follows: `edge_width = kernel - 1`
    :param ignore_index: label forwarded to ``to_one_hot`` to be excluded from the one-hot encoding; None keeps all
        labels.
    :param flatten_channels: Whether to apply logical or across channels dimension, i.e. a pixel is an edge pixel if
        it is one for at least one class. If set as `False` the output tensor shape is [B, C, H, W], else
        [B, 1, H, W]. Default is `True`.
    :return: one_hot edge torch.Tensor.
    """
    return one_hot_to_binary_edge(
        to_one_hot(target, num_classes=num_classes, ignore_index=ignore_index),
        kernel_size=kernel_size,
        flatten_channels=flatten_channels,
    )