Source code for super_gradients.training.losses.yolox_loss

"""
Based on https://github.com/Megvii-BaseDetection/YOLOX (Apache-2.0 license)

"""

import logging
from typing import List, Tuple, Union

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix

logger = get_logger(__name__)


class IOUloss(nn.Module):
    """
    IoU loss with the following supported loss types:
    Attributes:
        reduction: str: One of ["mean", "sum", "none"] reduction to apply to the computed loss (Default="none")
        loss_type: str: One of ["iou", "giou"] where:
            * 'iou' for
                (1 - iou^2)
            * 'giou' according to "Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression"
                (1 - giou), where giou = iou - (cover_box - union_box)/cover_box
    """

    def __init__(self, reduction: str = "none", loss_type: str = "iou"):
        super(IOUloss, self).__init__()
        self._validate_args(loss_type, reduction)
        self.reduction = reduction
        self.loss_type = loss_type

    @staticmethod
    def _validate_args(loss_type, reduction):
        supported_losses = ["iou", "giou"]
        supported_reductions = ["mean", "sum", "none"]
        if loss_type not in supported_losses:
            raise ValueError("Illegal loss_type value: " + loss_type + ', expected one of: ' + str(supported_losses))
        if reduction not in supported_reductions:
            raise ValueError(
                "Illegal reduction value: " + reduction + ', expected one of: ' + str(supported_reductions))

    def forward(self, pred, target):
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        tl = torch.max((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
        br = torch.min((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            c_tl = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
            c_br = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
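

# Illustrative usage sketch (not part of the original source): compares the "iou" and "giou"
# loss types on a single pair of boxes given in (cx, cy, w, h) format. The helper name and the
# box values below are made up for demonstration.
def _iou_loss_usage_example():
    pred = torch.tensor([[50.0, 50.0, 20.0, 20.0]])
    target = torch.tensor([[55.0, 55.0, 20.0, 20.0]])
    iou_term = IOUloss(reduction="mean", loss_type="iou")(pred, target)    # 1 - IoU^2
    giou_term = IOUloss(reduction="mean", loss_type="giou")(pred, target)  # 1 - GIoU
    return iou_term, giou_term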


class YoloXDetectionLoss(_Loss):
    """
    Calculate YOLOX loss:
    L = L_objectness + L_iou + L_classification + 1[use_l1]*L_l1

    where:
        * L_iou, L_classification and L_l1 are calculated only between cells and targets that suit them;
        * L_objectness is calculated for all cells.

        L_classification:
            for cells that have suitable ground truths in their grid locations add BCEs
            to force a prediction of IoU with a GT in a multi-label way
            Coef: 1.
        L_iou:
            for cells that have suitable ground truths in their grid locations
            add (1 - IoU^2), IoU between a predicted box and each GT box, force maximum IoU
            Coef: 5.
        L_l1:
            for cells that have suitable ground truths in their grid locations
            l1 distance between the logits and GTs in "logits" format (the inverse of "logits to predictions" ops)
            Coef: 1[use_l1]
        L_objectness:
            for each cell add BCE with a label of 1 if there is a GT assigned to the cell
            Coef: 1

    Attributes:
        strides: list: List of Yolo levels output grid strides (i.e. [8, 16, 32]).
        num_classes: int: Number of classes.
        use_l1: bool: Controls the L_l1 coef as discussed above (default=False).
        center_sampling_radius: float: Sampling radius used for center sampling when creating the fg mask (default=2.5).
        iou_type: str: Iou loss type, one of ["iou", "giou"] (default="iou").
    """

    def __init__(self, strides: list, num_classes: int, use_l1: bool = False,
                 center_sampling_radius: float = 2.5, iou_type='iou'):
        super().__init__()
        self.grids = [torch.zeros(1)] * len(strides)
        self.strides = strides
        self.num_classes = num_classes

        self.center_sampling_radius = center_sampling_radius
        self.use_l1 = use_l1
        self.l1_loss = nn.L1Loss(reduction="none")
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = IOUloss(reduction="none", loss_type=iou_type)

    def forward(self, model_output: Union[list, Tuple[torch.Tensor, List]], targets: torch.Tensor):
        """
        :param model_output: Union[list, Tuple[torch.Tensor, List]]:
            When list -
                output from all Yolo levels, each of shape
                [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]
            When tuple -
                the second item is the described list (the first item is discarded)
        :param targets: torch.Tensor: [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h
        :return: loss, all losses separately in a detached tensor
        """
        if isinstance(model_output, tuple) and len(model_output) == 2:
            # in test/eval mode the Yolo model outputs a tuple where the second item is the raw predictions
            _, predictions = model_output
        else:
            predictions = model_output

        return self._compute_loss(predictions, targets)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """
        Creates a tensor of xy coordinates of size (1, 1, ny, nx, 2)

        :param nx: int: cells along the x axis (default=20)
        :param ny: int: cells along the y axis (default=20)
        :return: torch.tensor of xy coordinates of size (1, 1, ny, nx, 2)
        """
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

    def _compute_loss(self, predictions: List[torch.Tensor], targets: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param predictions: output from all Yolo levels, each of shape
                            [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]
        :param targets: [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h

        :return: loss, all losses separately in a detached tensor
        """
        x_shifts, y_shifts, expanded_strides, transformed_outputs, raw_outputs = self.prepare_predictions(predictions)

        bbox_preds = transformed_outputs[:, :, :4]  # [batch, n_anchors_all, 4]
        obj_preds = transformed_outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
        cls_preds = transformed_outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]

        # calculate targets
        total_num_anchors = transformed_outputs.shape[1]
        cls_targets = []
        reg_targets = []
        l1_targets = []
        obj_targets = []
        fg_masks = []

        num_fg, num_gts = 0., 0.
        for image_idx in range(transformed_outputs.shape[0]):
            labels_im = targets[targets[:, 0] == image_idx]
            num_gt = labels_im.shape[0]
            num_gts += num_gt
            if num_gt == 0:
                cls_target = transformed_outputs.new_zeros((0, self.num_classes))
                reg_target = transformed_outputs.new_zeros((0, 4))
                l1_target = transformed_outputs.new_zeros((0, 4))
                obj_target = transformed_outputs.new_zeros((total_num_anchors, 1))
                fg_mask = transformed_outputs.new_zeros(total_num_anchors).bool()
            else:
                # GT boxes to image coordinates
                gt_bboxes_per_image = labels_im[:, 2:6].clone()
                gt_classes = labels_im[:, 1]
                bboxes_preds_per_image = bbox_preds[image_idx]

                try:
                    # assign cells to ground truths, at most one GT per cell
                    gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = \
                        self.get_assignments(image_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
                                             bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
                                             cls_preds, obj_preds)

                # TODO: CHECK IF ERROR IS CUDA OUT OF MEMORY
                except RuntimeError:
                    logging.error("OOM RuntimeError is raised due to the huge memory cost during label assignment. \
                                   CPU mode is applied in this batch. If you want to avoid this issue, \
                                   try to reduce the batch size or image size.")
                    torch.cuda.empty_cache()

                    gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = \
                        self.get_assignments(image_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
                                             bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
                                             cls_preds, obj_preds, 'cpu')
                    torch.cuda.empty_cache()

                num_fg += num_fg_img
                cls_target = F.one_hot(gt_matched_classes.to(torch.int64),
                                       self.num_classes) * pred_ious_this_matching.unsqueeze(-1)
                obj_target = fg_mask.unsqueeze(-1)
                reg_target = gt_bboxes_per_image[matched_gt_inds]
                if self.use_l1:
                    l1_target = self.get_l1_target(transformed_outputs.new_zeros((num_fg_img, 4)),
                                                   gt_bboxes_per_image[matched_gt_inds],
                                                   expanded_strides[0][fg_mask],
                                                   x_shifts=x_shifts[0][fg_mask], y_shifts=y_shifts[0][fg_mask])

            # collect targets for all loss terms over the whole batch
            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.to(transformed_outputs.dtype))
            fg_masks.append(fg_mask)
            if self.use_l1:
                l1_targets.append(l1_target)

        # concat all targets over the batch (get rid of batch dim)
        cls_targets = torch.cat(cls_targets, 0)
        reg_targets = torch.cat(reg_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        fg_masks = torch.cat(fg_masks, 0)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)

        num_fg = max(num_fg, 1)
        # loss terms divided by the total number of foregrounds
        loss_iou = self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets).sum() / num_fg
        loss_obj = self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets).sum() / num_fg
        loss_cls = self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets).sum() / num_fg
        if self.use_l1:
            loss_l1 = self.l1_loss(raw_outputs.view(-1, 4)[fg_masks], l1_targets).sum() / num_fg
        else:
            loss_l1 = 0.0

        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

        return loss, \
            torch.cat((loss_iou.unsqueeze(0), loss_obj.unsqueeze(0), loss_cls.unsqueeze(0),
                       torch.tensor(loss_l1).unsqueeze(0).to(loss.device),
                       torch.tensor(num_fg / max(num_gts, 1)).unsqueeze(0).to(loss.device),
                       loss.unsqueeze(0))).detach()

    def prepare_predictions(self, predictions: List[torch.Tensor]) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert raw outputs of the network into a format that merges outputs from all levels

        :param predictions: output from all Yolo levels, each of shape
                            [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]
        :return: 5 tensors representing predictions:
            * x_shifts: shape [1 x num_cells x 1],
              where num_cells = grid1X * grid1Y + grid2X * grid2Y + grid3X * grid3Y,
              x coordinate of the grid cell the prediction is coming from
            * y_shifts: shape [1 x num_cells x 1],
              y coordinate of the grid cell the prediction is coming from
            * expanded_strides: shape [1 x num_cells x 1],
              stride of the output grid the prediction is coming from
            * transformed_outputs: shape [batch_size x num_cells x (num_classes + 5)],
              predictions with boxes in real coordinates and class/objectness scores as logits
            * raw_outputs: shape [batch_size x num_cells x 4],
              raw box predictions as logits (only collected when use_l1 is True)
        """
        raw_outputs = []
        transformed_outputs = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

        for k, output in enumerate(predictions):
            batch_size, num_anchors, h, w, num_outputs = output.shape

            # IN FIRST PASS CREATE GRIDS ACCORDING TO OUTPUT SHAPE (BATCH, 1, IMAGE_H/STRIDE, IMAGE_W/STRIDE, NUM_CLASSES + 5)
            if self.grids[k].shape[2:4] != output.shape[2:4]:
                self.grids[k] = self._make_grid(w, h).type_as(output)

            # e.g. [batch_size, 1, 28, 28, 85] -> [batch_size, 784, 85]
            output_raveled = output.reshape(batch_size, num_anchors * h * w, num_outputs)
            # e.g. [1, 784, 2]
            grid_raveled = self.grids[k].view(1, num_anchors * h * w, 2)

            if self.use_l1:
                # e.g. [1, 784, 4]
                raw_outputs.append(output_raveled[:, :, :4].clone())

            # box logits to coordinates
            centers = (output_raveled[..., :2] + grid_raveled) * self.strides[k]
            wh = torch.exp(output_raveled[..., 2:4]) * self.strides[k]
            classes = output_raveled[..., 4:]
            output_raveled = torch.cat([centers, wh, classes], -1)

            # outputs with boxes in real coordinates, probs as logits
            transformed_outputs.append(output_raveled)
            # x cell coordinates of all 784 predictions, 0, 0, 0, ..., 1, 1, 1, ...
            x_shifts.append(grid_raveled[:, :, 0])
            # y cell coordinates of all 784 predictions, 0, 1, 2, ..., 0, 1, 2, ...
            y_shifts.append(grid_raveled[:, :, 1])
            # e.g. [1, 784, stride of this level (one of [8, 16, 32])]
            expanded_strides.append(torch.zeros(1, grid_raveled.shape[1]).fill_(self.strides[k]).type_as(output))

        # all 4 below have shapes of [batch_size, num_cells, num_values_per_cell]
        # where num_cells is e.g. num_anchors * (28 * 28 + 14 * 14 + 7 * 7)
        transformed_outputs = torch.cat(transformed_outputs, 1)
        x_shifts = torch.cat(x_shifts, 1)
        y_shifts = torch.cat(y_shifts, 1)
        expanded_strides = torch.cat(expanded_strides, 1)
        if self.use_l1:
            raw_outputs = torch.cat(raw_outputs, 1)

        return x_shifts, y_shifts, expanded_strides, transformed_outputs, raw_outputs

    def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
        """
        :param l1_target: tensor of zeros of shape [Num_cell_gt_pairs x 4]
        :param gt: targets in coordinates [Num_cell_gt_pairs x (4 + 1 + num_classes)]

        :return: targets in the format corresponding to logits
        """
        l1_target[:, 0] = gt[:, 0] / stride - x_shifts
        l1_target[:, 1] = gt[:, 1] / stride - y_shifts
        l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
        l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
        return l1_target
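
    # Worked example for get_l1_target (illustrative only, values made up): with stride=8, x_shifts=3,
    # y_shifts=5 and a GT box (x=30, y=44, w=16, h=24), the l1 target is
    # [30/8 - 3, 44/8 - 5, log(16/8), log(24/8)] = [0.75, 0.5, ~0.693, ~1.099],
    # i.e. exactly the inverse of the box decoding in prepare_predictions:
    # center_x = (0.75 + 3) * 8 = 30 and w = exp(0.693) * 8 = 16.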

    @torch.no_grad()
    def get_assignments(self, image_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
                        bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts, cls_preds, obj_preds,
                        mode="gpu", ious_loss_cost_coeff=3.0, outside_boxes_and_center_cost_coeff=100000.0):
        """
        Match cells to ground truth:
            * at most 1 GT per cell
            * dynamic number of cells per GT

        :param outside_boxes_and_center_cost_coeff: float: Cost coefficient for cells outside the radius and bbox
            of gts in dynamic matching (default=100000).
        :param ious_loss_cost_coeff: float: Cost coefficient for the iou loss in dynamic matching (default=3).
        :param image_idx: int: Image index in batch.
        :param num_gt: int: Number of ground truth targets in the image.
        :param total_num_anchors: int: Total number of possible bboxes = sum of all grid cells.
        :param gt_bboxes_per_image: torch.Tensor: Tensor of gt bboxes for the image, shape: (num_gt, 4).
        :param gt_classes: torch.Tensor: Tensor of the gt classes in the image, shape: (num_gt).
        :param bboxes_preds_per_image: torch.Tensor: Tensor of the predicted bboxes in the image, shape: (num_cells, 4).
        :param expanded_strides: torch.Tensor: Stride of the output grid the prediction is coming from,
            shape (1 x num_cells x 1).
        :param x_shifts: torch.Tensor: X's in cell coordinates, shape (1, num_cells, 1).
        :param y_shifts: torch.Tensor: Y's in cell coordinates, shape (1, num_cells, 1).
        :param cls_preds: torch.Tensor: Class predictions in all cells, shape (batch_size, num_cells, num_classes).
        :param obj_preds: torch.Tensor: Objectness predictions in all cells, shape (batch_size, num_cells, 1).
        :param mode: str: One of ["gpu", "cpu"], controls the device the assignment operation should take place on
            (default="gpu")
        """
        if mode == "cpu":
            print("------------CPU Mode for This Batch-------------")
            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
            bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
            gt_classes = gt_classes.cpu().float()
            expanded_strides = expanded_strides.cpu().float()
            x_shifts = x_shifts.cpu()
            y_shifts = y_shifts.cpu()

        # create a mask for foreground cells
        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts,
                                                                 y_shifts, total_num_anchors, num_gt)

        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
        cls_preds_ = cls_preds[image_idx][fg_mask]
        obj_preds_ = obj_preds[image_idx][fg_mask]
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

        if mode == "cpu":
            gt_bboxes_per_image = gt_bboxes_per_image.cpu()
            bboxes_preds_per_image = bboxes_preds_per_image.cpu()

        # calculate cost between all foregrounds and all ground truths (used only for matching)
        pair_wise_ious = calculate_bbox_iou_matrix(gt_bboxes_per_image, bboxes_preds_per_image, x1y1x2y2=False)
        gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        gt_cls_per_image = gt_cls_per_image.float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

        if mode == "cpu":
            cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

        with torch.cuda.amp.autocast(enabled=False):
            cls_preds_ = (cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                          * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_())
            pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
        del cls_preds_

        cost = (pair_wise_cls_loss + ious_loss_cost_coeff * pair_wise_ious_loss
                + outside_boxes_and_center_cost_coeff * (~is_in_boxes_and_center))

        # further filter foregrounds: create pairs between cells and ground truth, based on cost and IoUs
        num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = \
            self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        # discard tensors related to cost
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

        if mode == "cpu":
            gt_matched_classes = gt_matched_classes.cuda()
            fg_mask = fg_mask.cuda()
            pred_ious_this_matching = pred_ious_this_matching.cuda()
            matched_gt_inds = matched_gt_inds.cuda()

        return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg

    def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt):
        """
        Create a mask for all cells, mask in only foreground: cells that have a center located:
            * within a GT box;
            OR
            * within a fixed radius around a GT box (center sampling);

        :param num_gt: int: Number of ground truth targets in the image.
        :param total_num_anchors: int: Sum of all grid cells.
        :param gt_bboxes_per_image: torch.Tensor: Tensor of gt bboxes for the image, shape: (num_gt, 4).
        :param expanded_strides: torch.Tensor: Stride of the output grid the prediction is coming from,
            shape (1 x num_cells x 1).
        :param x_shifts: torch.Tensor: X's in cell coordinates, shape (1, num_cells, 1).
        :param y_shifts: torch.Tensor: Y's in cell coordinates, shape (1, num_cells, 1).
        :return is_in_boxes_anchor, is_in_boxes_and_center
            where:
             - is_in_boxes_anchor masks the cells whose center is either inside a gt bbox or within
               self.center_sampling_radius cells away from a gt center, shape (num_cells)
             - is_in_boxes_and_center masks the cells whose center is both inside a gt bbox and within
               self.center_sampling_radius cells away, without reduction over the gt dim
               (i.e. shape = (num_gts, num_fgs))
        """
        expanded_strides_per_image = expanded_strides[0]

        # cell coordinates, shape [n_predictions] -> repeated to [n_gts, n_predictions]
        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
        x_centers_per_image = (x_shifts_per_image + 0.5 * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
        y_centers_per_image = (y_shifts_per_image + 0.5 * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)

        # FIND CELL CENTERS THAT ARE WITHIN GROUND TRUTH BOXES
        # ground truth boxes, shape [n_gts] -> repeated to [n_gts, n_predictions]
        # from (c1, c2, w, h) to left, right, top, bottom
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)

        # check which cell centers lay within the ground truth boxes
        b_l = x_centers_per_image - gt_bboxes_per_image_l  # x - l > 0 when l is to the left of x
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)  # shape [n_gts, n_predictions, 4]

        # to claim that a cell center is inside a gt box all 4 differences calculated above should be positive
        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0  # shape [n_gts, n_predictions]
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0  # shape [n_predictions], whether a cell is inside at least one gt

        # FIND CELL CENTERS THAT ARE WITHIN +- self.center_sampling_radius CELLS FROM GROUND TRUTH BOXES CENTERS
        # define fake boxes: instead of ground truth boxes step +- self.center_sampling_radius from their centers
        gt_bboxes_per_image_l = ((gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors)
                                 - self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))
        gt_bboxes_per_image_r = ((gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors)
                                 + self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))
        gt_bboxes_per_image_t = ((gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors)
                                 - self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))
        gt_bboxes_per_image_b = ((gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors)
                                 + self.center_sampling_radius * expanded_strides_per_image.unsqueeze(0))

        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0

        # in boxes OR in centers
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        # in boxes AND in centers, preserving a shape [num_GTs x num_FGs]
        is_in_boxes_and_center = (is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor])

        return is_in_boxes_anchor, is_in_boxes_and_center

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        """
        :param cost: pairwise cost, [num_GTs x num_FGs]
        :param pair_wise_ious: pairwise IoUs, [num_GTs x num_FGs]
        :param gt_classes: class of each GT
        :param num_gt: number of GTs
        :param fg_mask: foreground mask over all cells, updated in place to keep only the matched cells

        :return num_fg, (number of foregrounds)
                gt_matched_classes, (the classes that have been matched with fgs)
                pred_ious_this_matching
                matched_gt_inds
        """
        # create a matrix with shape [num_GTs x num_FGs]
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

        # for each GT get a dynamic k of foregrounds with a minimum cost: k = int(sum[top 10 IoUs])
        ious_in_boxes_matrix = pair_wise_ious
        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        dynamic_ks = dynamic_ks.tolist()
        for gt_idx in range(num_gt):
            try:
                _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)
            except Exception:
                logger.warning("cost[gt_idx]: " + str(cost[gt_idx]) + " dynamic_ks[gt_idx]: " + str(dynamic_ks[gt_idx]))
            matching_matrix[gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

        # leave at most one GT per foreground, choose the one with the smallest cost
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
        fg_mask_inboxes = matching_matrix.sum(0) > 0
        num_fg = fg_mask_inboxes.sum().item()

        fg_mask[fg_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        gt_matched_classes = gt_classes[matched_gt_inds]

        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
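

# Illustrative usage sketch (not part of the original source). It assumes a YoloX-style model whose
# training-mode output is a list of per-level tensors shaped
# [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + num_classes)], faked here with random tensors,
# and targets shaped [Num_targets x 6] = (image id in batch, class, x, y, w, h). The function name
# and all values are hypothetical.
def _yolox_loss_usage_example():
    num_classes = 80
    strides = [8, 16, 32]
    criterion = YoloXDetectionLoss(strides=strides, num_classes=num_classes, use_l1=False, iou_type="iou")

    batch_size, img_size = 2, 640
    predictions = [torch.randn(batch_size, 1, img_size // s, img_size // s, 4 + 1 + num_classes)
                   for s in strides]
    # two targets for image 0 and one for image 1: (image_id, class, x, y, w, h) in pixels
    targets = torch.tensor([[0., 3., 100., 120., 50., 80.],
                            [0., 17., 300., 200., 40., 40.],
                            [1., 5., 320., 320., 120., 90.]])

    loss, loss_items = criterion(predictions, targets)
    # loss_items holds the detached components: iou, obj, cls, l1, num_fg/num_gts and the total loss
    return loss, loss_items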