Source code for farabio.utils.losses

import torch.nn.functional as F
import torch
import torch.nn as nn
from torchvision.models.vgg import vgg16
from tqdm import tqdm
import numpy as np
from torch.autograd import Variable
try:
    from itertools import ifilterfalse
except ImportError:  # Python 3
    from itertools import filterfalse as ifilterfalse


class DiceLoss(nn.Module):
    r"""The Dice coefficient, or Dice-Sørensen coefficient.

    It is a common metric for pixel segmentation that can also be
    modified to act as a loss function.

    .. math::
        DSC=\frac{2|X \cap Y|}{|X|+|Y|}

    Examples
    ----------
    >>> dice_loss = DiceLoss()
    >>> dice_loss(outputs, targets)
    """
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Comment out if your model contains a sigmoid or
        # equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice
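
# Hedged usage sketch (added for illustration; `_dice_loss_demo` is not part
# of the original farabio API): with confident logits that agree with the
# target mask, the Dice loss approaches 0; with inverted logits it approaches 1.
def _dice_loss_demo():
    targets = torch.randint(0, 2, (2, 1, 16, 16)).float()
    logits = (targets * 2. - 1.) * 10.   # large-magnitude logits matching targets
    good = DiceLoss()(logits, targets)   # close to 0
    bad = DiceLoss()(-logits, targets)   # close to 1
    return good.item(), bad.item()
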
class DiceBCELoss(nn.Module):
    r"""Dice combined with BCE.

    This loss combines Dice loss with the standard binary cross-entropy (BCE)
    loss that is generally the default for segmentation models. Combining the
    two methods allows for some diversity in the loss, while benefitting from
    the stability of BCE. The equation for multi-class BCE by itself will be
    familiar to anyone who has studied logistic regression.

    .. math::
        J(\mathbf{w})=\frac{1}{N} \sum_{n=1}^{N} H\left(p_{n}, q_{n}\right)=-\frac{1}{N} \sum_{n=1}^{N}\left[y_{n} \log \hat{y}_{n}+\left(1-y_{n}\right) \log \left(1-\hat{y}_{n}\right)\right]

    Examples
    ----------
    >>> dice_bce_loss = DiceBCELoss()
    >>> dice_bce_loss(outputs, targets)
    """
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Comment out if your model contains a sigmoid or
        # equivalent activation layer
        inputs = torch.sigmoid(inputs)
        bce_weight = 0.5

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        loss_final = BCE * bce_weight + dice_loss * (1 - bce_weight)

        return loss_final
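
# Hedged sketch (illustration only): with the hard-coded bce_weight of 0.5,
# DiceBCELoss is the midpoint of its BCE and Dice terms, so it always lies
# between the two component losses.
def _dice_bce_demo():
    targets = torch.randint(0, 2, (2, 1, 16, 16)).float()
    logits = torch.randn(2, 1, 16, 16)
    combined = DiceBCELoss()(logits, targets)
    dice_only = DiceLoss()(logits, targets)
    return combined.item(), dice_only.item()
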
class IoULoss(nn.Module):
    r"""The IoU metric, or Jaccard Index.

    It is similar to the Dice metric and is calculated as the ratio between
    the overlap of the positive instances between two sets, and their mutual
    combined values.

    .. math::
        J(A, B)=\frac{|A \cap B|}{|A \cup B|}=\frac{|A \cap B|}{|A|+|B|-|A \cap B|}

    Examples
    ----------
    >>> iou_loss = IoULoss()
    >>> iou_loss(outputs, targets)
    """
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Comment out if your model contains a sigmoid or
        # equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # Intersection is equivalent to the True Positive count;
        # union is the mutually inclusive area of all labels & predictions
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        IoU = (intersection + smooth) / (union + smooth)

        return 1 - IoU
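
# Hedged sketch (illustration only): for the same prediction, the Jaccard and
# Dice coefficients are related by J = D / (2 - D), so up to the smoothing
# term, 1 - IoULoss(...) ~= D / (2 - D) with D = 1 - DiceLoss(...).
def _iou_vs_dice_demo():
    targets = torch.randint(0, 2, (2, 1, 16, 16)).float()
    logits = torch.randn(2, 1, 16, 16)
    d = 1. - DiceLoss()(logits, targets).item()   # Dice coefficient
    j = 1. - IoULoss()(logits, targets).item()    # Jaccard index
    return j, d / (2. - d)  # approximately equal (up to smooth=1)
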
class FocalLoss(nn.Module):
    r"""Focal Loss from [RetinaNet]_

    Notes
    -------
    .. math::
        \mathrm{FL}\left(p_{t}\right)=-\alpha_{t}\left(1-p_{t}\right)^{\gamma} \log \left(p_{t}\right)

    where :math:`p_t` is the model's estimated probability for each class.

    It was introduced by Facebook AI Research in 2017 to combat extremely
    imbalanced datasets where positive cases were relatively rare.
    Figure excerpt from [amaarora]_:

    .. image:: ../imgs/focal_loss.png
        :width: 400

    The loss has two hyperparameters, :math:`\alpha` and :math:`\gamma`.
    The focusing parameter :math:`\gamma` smoothly adjusts the rate at which
    easy examples are down-weighted. When :math:`\gamma = 0`, focal loss is
    equivalent to categorical cross-entropy, and as :math:`\gamma` is
    increased, the effect of the modulating factor is likewise increased
    (:math:`\gamma = 2` works best in experiments). :math:`\alpha` is a
    weighting factor: if :math:`\alpha = 1`, then class 1 and class 0 (in the
    binary case) have the same weight, so :math:`\alpha` balances the
    importance of positive/negative examples.

    Examples
    ----------
    >>> focal_loss = FocalLoss()
    >>> focal_loss(outputs, targets)

    References
    ---------------
    .. [RetinaNet] https://arxiv.org/abs/1708.02002
    .. [amaarora] https://amaarora.github.io/2020/06/29/FocalLoss.html
    """
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()

    def forward(self, inputs, targets, alpha=0.8, gamma=2, smooth=1):
        # Comment out if your model contains a sigmoid or
        # equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # First compute binary cross-entropy
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1 - BCE_EXP) ** gamma * BCE

        return focal_loss
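
# Hedged sketch (illustration only): the forward above applies the focal
# modulation to the *mean* BCE. A per-element variant closer to the RetinaNet
# formula FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t) might look like this;
# `_focal_per_element` is a hypothetical helper, not part of the original API.
def _focal_per_element(logits, targets, alpha=0.8, gamma=2.0):
    probs = torch.sigmoid(logits).view(-1)
    targets = targets.view(-1)
    bce = F.binary_cross_entropy(probs, targets, reduction='none')
    p_t = torch.exp(-bce)  # p_t = p for y=1, 1-p for y=0
    return (alpha * (1 - p_t) ** gamma * bce).mean()
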
class TverskyLoss(nn.Module):
    r"""Tversky loss from [1]_

    Notes
    -------
    .. math::
        S(P, G ; \alpha, \beta)=\frac{|P \cap G|}{|P \cap G|+\alpha|P \backslash G|+\beta|G \backslash P|}

    where:

    - :math:`P` and :math:`G` are the predicted and ground truth binary labels.
    - :math:`\alpha` and :math:`\beta` control the magnitude of the penalties
      for FPs and FNs, respectively.

    Notes:

    - :math:`\alpha = \beta = 0.5` => dice coeff
    - :math:`\alpha = \beta = 1` => tanimoto coeff
    - :math:`\alpha + \beta = 1` => F beta coeff

    Examples
    ----------
    >>> tversky_loss = TverskyLoss()
    >>> tversky_loss(outputs, targets)

    References
    ---------------
    .. [1] https://arxiv.org/abs/1706.05721
    """
    def __init__(self, weight=None, size_average=True):
        super(TverskyLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1, alpha=0.5, beta=0.5):
        # Comment out if your model contains a sigmoid or
        # equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # True Positives, False Positives & False Negatives
        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

        return 1 - Tversky
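
# Hedged sketch (illustration only): raising beta penalizes false negatives
# more heavily, which is the usual motivation for Tversky loss on datasets
# where missing foreground pixels is worse than over-segmenting.
def _tversky_beta_demo():
    targets = torch.randint(0, 2, (2, 1, 16, 16)).float()
    logits = torch.randn(2, 1, 16, 16) - 1.  # biased toward background => FNs
    balanced = TverskyLoss()(logits, targets, alpha=0.5, beta=0.5)
    fn_heavy = TverskyLoss()(logits, targets, alpha=0.3, beta=0.7)
    return balanced.item(), fn_heavy.item()  # fn_heavy is typically larger here
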
class FocalTverskyLoss(nn.Module):
    r"""A variant on the Tversky loss that also includes the gamma modifier
    from Focal Loss [1]_.

    Examples
    ----------
    >>> focal_tversky_loss = FocalTverskyLoss()
    >>> focal_tversky_loss(outputs, targets)

    References
    ---------------
    .. [1] https://arxiv.org/abs/1810.07842
    """
    def __init__(self, weight=None, size_average=True):
        super(FocalTverskyLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1, alpha=0.5, beta=0.5, gamma=1):
        # Comment out if your model contains a sigmoid or
        # equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # True Positives, False Positives & False Negatives
        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
        FocalTversky = (1 - Tversky) ** gamma

        return FocalTversky
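
# Hedged sketch (illustration only): gamma > 1 shrinks the loss for "easy"
# well-segmented examples faster than for hard ones, since the loss is
# (1 - Tversky)^gamma and 0 <= 1 - Tversky <= 1.
def _focal_tversky_gamma_demo():
    targets = torch.randint(0, 2, (2, 1, 16, 16)).float()
    logits = torch.randn(2, 1, 16, 16)
    base = FocalTverskyLoss()(logits, targets, gamma=1)     # plain Tversky loss
    focused = FocalTverskyLoss()(logits, targets, gamma=2)  # down-weights easy cases
    return base.item(), focused.item()  # focused <= base
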
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """nanmean compatible with generators."""
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


# PyTorch
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case).
    Removes labels equal to 'ignore'.
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors.
    See Alg. 1 in paper.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    r"""
    Binary Lovasz hinge loss.
      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    r"""
    Binary Lovasz hinge loss.
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss


# =====
# Multi-class Lovasz loss
# =====
def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch.
    """
    # Restored from the reference Lovasz-Softmax implementation
    # (https://github.com/bermanmaxim/LovaszSoftmax); the body was missing
    # from this listing although lovasz_softmax calls it.
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss.
      probas: [B, C, H, W] Variable, class probabilities at each prediction
        (between 0 and 1). Interpreted as binary (sigmoid) output with
        outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels,
        or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return loss


def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss.
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels,
        or a list of classes to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if (classes == 'present' and fg.sum() == 0):
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
class LovaszHingeLoss(nn.Module):
    r"""The Lovász hinge loss.

    A tractable surrogate for the optimization of the intersection-over-union
    measure in neural networks from [1]_.

    Examples
    ---------------
    >>> lovasz_hinge_loss = LovaszHingeLoss()
    >>> lovasz_hinge_loss(outputs, targets)

    References
    ---------------
    .. [1] https://arxiv.org/abs/1705.08790
    """
    def __init__(self, weight=None, size_average=True):
        super(LovaszHingeLoss, self).__init__()

    def forward(self, inputs, targets):
        inputs = torch.sigmoid(inputs)
        Lovasz = lovasz_hinge(inputs, targets, per_image=False)
        return Lovasz
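
# Hedged usage sketch (illustration only): LovaszHingeLoss consumes raw
# per-pixel scores of shape [B, H, W] together with binary masks of the
# same shape.
def _lovasz_hinge_demo():
    outputs = torch.randn(2, 8, 8)                     # per-pixel scores
    targets = torch.randint(0, 2, (2, 8, 8)).float()   # binary masks
    return LovaszHingeLoss()(outputs, targets)
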
class Losses:
    """Collection of loss helpers: BCE, Dice, categorical cross-entropy,
    their weighted combinations, and smooth-L1 helpers for detection.

    Methods
    -------
    bce_loss(self, pred, target)
        Returns binary cross-entropy loss
    dice_loss(self, pred, target, smooth=1.)
        Returns dice loss
    calc_loss(self, pred, target, bce_weight=0.5)
        Returns combined BCE and Dice loss
    cat_loss(self, pred, target)
        Returns categorical cross-entropy loss
    extract_loss(self, pred, target, device=True, cat_weight=0.5)
        Returns combined categorical cross-entropy and Dice loss
    _smooth_l1_loss(self, x, t, in_weight, sigma)
        Returns smooth-L1 loss
    _fast_rcnn_loc_loss(self, pred_loc, gt_loc, gt_label, sigma)
        Returns Fast R-CNN localization loss
    """
    def bce_loss(self, pred, target):
        """Returns binary cross-entropy loss

        Parameters
        ----------
        pred : torch.Tensor
            predictions tensor array
        target : torch.Tensor
            target tensor array

        Returns
        -------
        torch.Tensor
            binary cross-entropy loss
        """
        bce = F.binary_cross_entropy_with_logits(pred, target)
        return bce

    def dice_loss(self, pred, target, smooth=1.):
        """Returns dice loss

        Parameters
        ----------
        pred : torch.Tensor
            predictions tensor array
        target : torch.Tensor
            target tensor array
        smooth : float, optional
            smoothening coefficient, by default 1.

        Returns
        -------
        torch.Tensor
            dice loss
        """
        pred = pred.contiguous()
        target = target.contiguous()
        intersection = (pred * target).sum(dim=2).sum(dim=2)
        loss = (1 - ((2. * intersection + smooth) /
                     (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
        return loss.mean()

    def calc_loss(self, pred, target, bce_weight=0.5):
        """Calculates combined BCE and Dice losses

        Parameters
        ----------
        pred : torch.Tensor
            predictions tensor array
        target : torch.Tensor
            target tensor array
        bce_weight : float, optional
            weight to give for bce loss, by default 0.5

        Returns
        -------
        torch.Tensor
            bce and dice loss combined
        """
        bce = self.bce_loss(pred, target)
        dice = self.dice_loss(torch.sigmoid(pred), target)
        loss = bce * bce_weight + dice * (1 - bce_weight)
        return loss

    def cat_loss(self, pred, target):
        """Returns categorical cross-entropy loss

        Parameters
        ----------
        pred : torch.Tensor
            predictions tensor array
        target : torch.Tensor
            target tensor array

        Returns
        -------
        torch.Tensor
            categorical cross-entropy loss
        """
        loss = torch.nn.CrossEntropyLoss()
        categ_loss = loss(pred, target)
        return categ_loss

    def extract_loss(self, pred, target, device=True, cat_weight=0.5):
        """Calculates combined categorical cross-entropy and Dice losses

        Parameters
        ----------
        pred : torch.Tensor
            predictions tensor array
        target : torch.Tensor
            target tensor array
        device : bool, optional
            if True, move the intermediate tensor to CUDA, by default True
        cat_weight : float, optional
            weight to give the categorical loss, by default 0.5

        Returns
        -------
        torch.Tensor
            categorical and dice loss combined
        """
        categ = self.cat_loss(pred, target)

        # coefficient of max label
        pred = torch.sigmoid(pred)
        coef_tensor = torch.argmax(pred.sum(2).sum(2)[:, 1:], dim=1)

        # empty tensor, e.g. [4, 1, 512, 512]
        ttensor = torch.zeros(target.unsqueeze(1).size(), dtype=torch.float32)

        # this was previously torch.device to set on same gpu
        # ttensor = ttensor.to(device=device)
        if device:
            ttensor = ttensor.cuda()

        # get the mask of max label
        for ij in range(len(coef_tensor)):
            ttensor[ij, :, :, :] = pred[ij, coef_tensor[ij] + 1, :, :]

        # dice
        dice = self.dice_loss(ttensor, target.unsqueeze(1))

        # combined loss
        loss = categ * cat_weight + dice * (1 - cat_weight)
        return loss

    def _smooth_l1_loss(self, x, t, in_weight, sigma):
        sigma2 = sigma ** 2
        diff = in_weight * (x - t)
        abs_diff = diff.abs()
        flag = (abs_diff.data < (1. / sigma2)).float()
        y = (flag * (sigma2 / 2.) * (diff ** 2) +
             (1 - flag) * (abs_diff - 0.5 / sigma2))
        return y.sum()

    def _fast_rcnn_loc_loss(self, pred_loc, gt_loc, gt_label, sigma):
        in_weight = torch.zeros(gt_loc.shape).cuda()
        # Localization loss is calculated only for positive rois.
        # NOTE: unlike the original implementation, we don't need
        # inside_weight and outside_weight; they can be calculated from gt_label.
        in_weight[(gt_label > 0).view(-1, 1).expand_as(in_weight).cuda()] = 1
        loc_loss = self._smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma)
        # Normalize by the total number of negative and positive rois,
        # ignoring gt_label == -1 (for rpn_loss).
        loc_loss /= ((gt_label >= 0).sum().float())
        return loc_loss


class GeneratorLoss(nn.Module):
    """Generator loss combining image MSE with adversarial, VGG16
    perceptual, and total variation terms.
    """
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial loss
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception loss
        perception_loss = self.mse_loss(self.loss_network(out_images),
                                        self.loss_network(target_images))
        # Image loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


class TVLoss(nn.Module):
    """Total variation (TV) loss.

    Penalizes differences between neighbouring pixels to encourage
    spatial smoothness.
    """
    def __init__(self, tv_loss_weight=1):
        """
        Parameters
        ----------
        tv_loss_weight : int, optional
            weight of the TV term, by default 1
        """
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Returns the weighted total variation of a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            input tensor of shape [B, C, H, W]

        Returns
        -------
        torch.Tensor
            total variation loss
        """
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        """Returns the number of elements per sample (C * H * W).

        Parameters
        ----------
        t : torch.Tensor
            input tensor of shape [B, C, H, W]

        Returns
        -------
        int
            number of elements in one sample
        """
        return t.size()[1] * t.size()[2] * t.size()[3]


def ap_per_class(tp, conf, pred_cls, target_cls):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:         True positives (list).
        conf:       Objectness value from 0-1 (list).
        pred_cls:   Predicted object classes (list).
        target_cls: True object classes (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
""" # Sort by objectness i = np.argsort(-conf) tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] # Find unique classes unique_classes = np.unique(target_cls) # Create Precision-Recall curve and compute AP for each class ap, p, r = [], [], [] for c in tqdm(unique_classes, desc="Computing AP"): i = pred_cls == c n_gt = (target_cls == c).sum() # Number of ground truth objects n_p = i.sum() # Number of predicted objects if n_p == 0 and n_gt == 0: continue elif n_p == 0 or n_gt == 0: ap.append(0) r.append(0) p.append(0) else: # Accumulate FPs and TPs fpc = (1 - tp[i]).cumsum() tpc = (tp[i]).cumsum() # Recall recall_curve = tpc / (n_gt + 1e-16) r.append(recall_curve[-1]) # Precision precision_curve = tpc / (tpc + fpc) p.append(precision_curve[-1]) # AP from recall-precision curve ap.append(compute_ap(recall_curve, precision_curve)) # Compute F1 score (harmonic mean of precision and recall) p, r, ap = np.array(p), np.array(r), np.array(ap) f1 = 2 * p * r / (p + r + 1e-16) return p, r, ap, f1, unique_classes.astype("int32") def compute_ap(recall, precision): """ Compute the average precision, given the recall and precision curves. Code originally from https://github.com/rbgirshick/py-faster-rcnn. # Arguments recall: The recall curve (list). precision: The precision curve (list). # Returns The average precision as computed in py-faster-rcnn. """ # correct AP calculation # first append sentinel values at the end mrec = np.concatenate(([0.0], recall, [1.0])) mpre = np.concatenate(([0.0], precision, [0.0])) # compute the precision envelope for i in range(mpre.size - 1, 0, -1): mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) # to calculate area under PR curve, look for points # where X axis (recall) changes value i = np.where(mrec[1:] != mrec[:-1])[0] # and sum (\Delta recall) * prec ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) return ap def get_batch_statistics(outputs, targets, iou_threshold): """ Compute true positives, predicted scores and predicted labels per sample """ batch_metrics = [] for sample_i in range(len(outputs)): if outputs[sample_i] is None: continue output = outputs[sample_i] pred_boxes = output[:, :4] pred_scores = output[:, 4] pred_labels = output[:, -1] true_positives = np.zeros(pred_boxes.shape[0]) annotations = targets[targets[:, 0] == sample_i][:, 1:] target_labels = annotations[:, 0] if len(annotations) else [] if len(annotations): detected_boxes = [] target_boxes = annotations[:, 1:] for pred_i, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)): # If targets are found break if len(detected_boxes) == len(annotations): break # Ignore if label is not one of the target labels if pred_label not in target_labels: continue iou, box_index = bbox_iou( pred_box.unsqueeze(0), target_boxes).max(0) if iou >= iou_threshold and box_index not in detected_boxes: true_positives[pred_i] = 1 detected_boxes += [box_index] batch_metrics.append([true_positives, pred_scores, pred_labels]) return batch_metrics def bbox_wh_iou(wh1, wh2): wh2 = wh2.t() w1, h1 = wh1[0], wh1[1] w2, h2 = wh2[0], wh2[1] inter_area = torch.min(w1, w2) * torch.min(h1, h2) union_area = (w1 * h1 + 1e-16) + w2 * h2 - inter_area return inter_area / union_area def bbox_iou(box1, box2, x1y1x2y2=True): """ Returns the IoU of two bounding boxes """ if not x1y1x2y2: # Transform from center and width to exact coordinates b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 
0] + box2[:, 2] / 2 b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 else: # Get the coordinates of bounding boxes b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] # get the corrdinates of the intersection rectangle inter_rect_x1 = torch.max(b1_x1, b2_x1) inter_rect_y1 = torch.max(b1_y1, b2_y1) inter_rect_x2 = torch.min(b1_x2, b2_x2) inter_rect_y2 = torch.min(b1_y2, b2_y2) # Intersection area inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp( inter_rect_y2 - inter_rect_y1 + 1, min=0 ) # Union Area b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) iou = inter_area / (b1_area + b2_area - inter_area + 1e-16) return iou def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): """ Removes detections with lower object confidence score than 'conf_thres' and performs Non-Maximum Suppression to further filter detections. Returns detections with shape: (x1, y1, x2, y2, object_conf, class_score, class_pred) """ # From (center x, center y, width, height) to (x1, y1, x2, y2) prediction[..., :4] = xywh2xyxy(prediction[..., :4]) output = [None for _ in range(len(prediction))] for image_i, image_pred in enumerate(prediction): # Filter out confidence scores below threshold image_pred = image_pred[image_pred[:, 4] >= conf_thres] # If none are remaining => process next image if not image_pred.size(0): continue # Object confidence times class confidence score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0] # Sort by it image_pred = image_pred[(-score).argsort()] class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True) detections = torch.cat( (image_pred[:, :5], class_confs.float(), class_preds.float()), 1) # Perform non-maximum suppression keep_boxes = [] while detections.size(0): large_overlap = bbox_iou(detections[0, :4].unsqueeze( 0), detections[:, :4]) > nms_thres label_match = detections[0, -1] == detections[:, -1] # Indices of boxes with lower confidence scores, large IOUs and matching labels invalid = large_overlap & label_match weights = detections[invalid, 4:5] # Merge overlapping bboxes by order of confidence detections[0, :4] = ( weights * detections[invalid, :4]).sum(0) / weights.sum() keep_boxes += [detections[0]] detections = detections[~invalid] if keep_boxes: output[image_i] = torch.stack(keep_boxes) return output def xywh2xyxy(x): """[summary] Parameters ---------- x : [type] [description] Returns ------- [type] [description] """ y = x.new(x.shape) y[..., 0] = x[..., 0] - x[..., 2] / 2 y[..., 1] = x[..., 1] - x[..., 3] / 2 y[..., 2] = x[..., 0] + x[..., 2] / 2 y[..., 3] = x[..., 1] + x[..., 3] / 2 return y
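

# Hedged end-to-end sketch (illustration only): YOLO-style raw predictions of
# shape [B, N, 5 + num_classes] in (cx, cy, w, h, obj, class scores...) format
# can be filtered with non_max_suppression, whose (x1, y1, x2, y2) outputs
# then feed bbox_iou / get_batch_statistics / ap_per_class.
def _detection_metrics_demo():
    preds = torch.rand(1, 10, 5 + 2)         # 1 image, 10 boxes, 2 classes
    preds[..., :2] = preds[..., :2] * 100.   # centers on a 100x100 canvas
    preds[..., 2:4] = preds[..., 2:4] * 20.  # widths / heights
    detections = non_max_suppression(preds, conf_thres=0.1, nms_thres=0.4)
    return detections  # one [K, 7] tensor (or None) per image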