import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
from super_gradients.training.utils.ssd_utils import DefaultBoxes
class SSDLoss(_Loss):
    """
    Implements the SSD loss as the sum of:
        1. Confidence loss: cross-entropy over all labels, with hard negative mining
        2. Localization loss: Smooth L1, computed only on positive matches
    """
    def __init__(self, dboxes: DefaultBoxes, alpha: float = 1.0):
        super(SSDLoss, self).__init__()
        self.scale_xy = 1.0 / dboxes.scale_xy
        self.scale_wh = 1.0 / dboxes.scale_wh
        self.alpha = alpha
        # keep per-element losses (reduction="none") so forward() can mask
        # positives and run hard negative mining before reducing
        self.sl1_loss = nn.SmoothL1Loss(reduction="none")
        # default boxes stored as [1, 4, num_dboxes] for broadcasting; a
        # non-trainable parameter so they move with the module across devices
        self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0), requires_grad=False)
        self.con_loss = nn.CrossEntropyLoss(reduction="none")
    def _norm_relative_bbox(self, loc):
        """
        Convert bbox locations into locations relative to the dboxes,
        normalized by the dbox width and height (the standard SSD box encoding).

        :param loc: a tensor of shape [batch, 4, num_boxes]
        """
        # (x, y) offsets in units of the default box size, scaled by 1/scale_xy
        gxy = self.scale_xy * (loc[:, :2, :] - self.dboxes[:, :2, :]) / self.dboxes[:, 2:, :]
        # (w, h) as log-ratios to the default box size, scaled by 1/scale_wh
        gwh = self.scale_wh * (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log()
        return torch.cat((gxy, gwh), dim=1).contiguous()
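    # Worked example (hypothetical numbers): with dboxes.scale_xy = 0.1 and
    # dboxes.scale_wh = 0.2 (so self.scale_xy = 10 and self.scale_wh = 5),
    # a default box d = (0.5, 0.5, 0.2, 0.2) matched to a ground-truth box
    # g = (0.54, 0.5, 0.2, 0.2) encodes to
    #     gxy = 10 * (0.54 - 0.5) / 0.2 = 2.0 on x, 0.0 on y
    #     gwh =  5 * log(0.2 / 0.2)     = 0.0 on w and h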
    def match_dboxes(self, targets):
        """
        Convert ground-truth boxes into a tensor with the same size as the dboxes.
        Each default box is matched to the ground-truth box with which it has the
        highest IoU; if that IoU exceeds 0.5, the default box takes the gt box's
        location and label, so a single gt box may be assigned to several default boxes.

        :param targets: a tensor containing the boxes for a single image;
                        shape [num_boxes, 6] (index in batch, label, x, y, w, h)
        :return: two tensors
                 boxes - shape of dboxes [4, num_dboxes] (x, y, w, h)
                 labels - shape [num_dboxes]; 0 (background) for unmatched boxes
        """
        target_locations = self.dboxes.data.clone().squeeze()
        target_labels = torch.zeros((self.dboxes.data.shape[2])).to(self.dboxes.device)
        if len(targets) > 0:
            boxes = targets[:, 2:]
            # IoU between every gt box and every default box: [num_targets, num_dboxes]
            ious = calculate_bbox_iou_matrix(boxes, self.dboxes.data.squeeze().T, x1y1x2y2=False)
            # best-matching gt box for each default box
            values, indices = torch.max(ious, dim=0)
            mask = values > 0.5
            target_locations[:, mask] = targets[indices[mask], 2:].T
            target_labels[mask] = targets[indices[mask], 1]
return target_locations, target_labels
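    # Example (hypothetical shapes): with 8732 default boxes (SSD300) and one
    # gt box of class 3 covering the top-left quadrant of a unit image,
    #     targets = torch.tensor([[0., 3., 0.25, 0.25, 0.5, 0.5]])
    #     locations, labels = loss.match_dboxes(targets)
    #     locations.shape -> torch.Size([4, 8732])
    #     labels.shape    -> torch.Size([8732])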
    def forward(self, predictions, targets):
        """
        Compute the loss.

        :param predictions: a tuple (ploc, plabel) coming from the network:
                            ploc - box locations, shape [N, 4, num_dboxes] (x, y, w, h)
                            plabel - class confidences, shape [N, num_classes, num_dboxes]
        :param targets: targets for the batch; shape [num_targets, 6] (index in batch, label, x, y, w, h)
        """
batch_target_locations = []
batch_target_labels = []
(ploc, plabel) = predictions
targets = targets.to(self.dboxes.device)
for i in range(ploc.shape[0]):
target_locations, target_labels = self.match_dboxes(targets[targets[:, 0] == i])
batch_target_locations.append(target_locations)
batch_target_labels.append(target_labels)
batch_target_locations = torch.stack(batch_target_locations)
batch_target_labels = torch.stack(batch_target_labels).type(torch.long)
        # POSITIVES ARE DEFAULT BOXES MATCHED TO A GT BOX (BACKGROUND LABEL IS 0)
        mask = batch_target_labels > 0
        pos_num = mask.sum(dim=1)
vec_gd = self._norm_relative_bbox(batch_target_locations)
# SUM ON FOUR COORDINATES, AND MASK
sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
sl1 = (mask.float() * sl1).sum(dim=1)
# HARD NEGATIVE MINING
con = self.con_loss(plabel, batch_target_labels)
        # POSITIVES ARE NEVER SELECTED AS HARD NEGATIVES: ZERO THEIR LOSS FIRST
        con_neg = con.clone()
        con_neg[mask] = 0
        # DOUBLE ARGSORT: con_rank[i, j] IS THE RANK OF BOX j WHEN IMAGE i'S
        # BOXES ARE SORTED BY DESCENDING CONFIDENCE LOSS
        _, con_idx = con_neg.sort(dim=1, descending=True)
        _, con_rank = con_idx.sort(dim=1)
        # KEEP THE HARDEST NEGATIVES, AT MOST THREE PER POSITIVE
        neg_num = torch.clamp(3 * pos_num, max=mask.size(1)).unsqueeze(-1)
        neg_mask = con_rank < neg_num
closs = (con * (mask.float() + neg_mask.float())).sum(dim=1)
        # AVOID "NO OBJECT DETECTED": ZERO THE LOSS FOR IMAGES WITH NO POSITIVE
        # MATCHES AND NORMALIZE BY THE (CLAMPED) NUMBER OF POSITIVES
        total_loss = (2 - self.alpha) * sl1 + self.alpha * closs
        num_mask = (pos_num > 0).float()
        pos_num = pos_num.float().clamp(min=1e-6)
        ret = (total_loss * num_mask / pos_num).mean(dim=0)
        # RETURN THE LOSS AND A DETACHED (sl1, closs, total) TRIPLET FOR LOGGING
        return ret, torch.cat((sl1.mean().unsqueeze(0), closs.mean().unsqueeze(0), ret.unsqueeze(0))).detach()
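
if __name__ == "__main__":
    # Minimal smoke-test sketch. `_ToyDBoxes` is a hypothetical stand-in for
    # DefaultBoxes (super_gradients.training.utils.ssd_utils); it mimics only the
    # interface SSDLoss relies on: the .scale_xy / .scale_wh attributes and
    # __call__(order="xywh") returning a [num_dboxes, 4] tensor of (x, y, w, h)
    # anchors. All shapes and class counts below are illustrative.
    class _ToyDBoxes:
        scale_xy = 0.1
        scale_wh = 0.2

        def __call__(self, order="xywh"):
            # four hand-written anchors, one per quadrant of the unit image
            return torch.tensor(
                [
                    [0.25, 0.25, 0.5, 0.5],
                    [0.75, 0.25, 0.5, 0.5],
                    [0.25, 0.75, 0.5, 0.5],
                    [0.75, 0.75, 0.5, 0.5],
                ]
            )

    num_classes = 3  # background (0) + two object classes
    loss_fn = SSDLoss(_ToyDBoxes())
    ploc = torch.randn(2, 4, 4)              # [batch, 4, num_dboxes]
    plabel = torch.randn(2, num_classes, 4)  # [batch, num_classes, num_dboxes]
    # one gt box per image: class 1 in image 0, class 2 in image 1
    targets = torch.tensor(
        [
            [0.0, 1.0, 0.25, 0.25, 0.5, 0.5],
            [1.0, 2.0, 0.75, 0.75, 0.5, 0.5],
        ]
    )
    loss, (sl1_log, closs_log, total_log) = loss_fn((ploc, plabel), targets)
    print(f"loss={loss.item():.4f}  sl1={sl1_log.item():.4f}  closs={closs_log.item():.4f}")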