Source code for super_gradients.training.losses.ohem_ce_loss

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from super_gradients.training.exceptions.loss_exceptions import IllegalRangeForLossAttributeException, RequiredLossComponentReductionException


class OhemLoss(_Loss):
    """
    OhemLoss - Online Hard Example Mining loss wrapper.

    Wraps an element-wise ``criteria`` (``reduction='none'``) and averages only the "hard" elements:
    either every element whose loss exceeds ``-log(threshold)``, or, if there are not enough of those,
    the ``num_mining`` highest-loss elements.
    """

    def __init__(self, threshold: float, mining_percent: float = 0.1, ignore_lb: int = -100, num_pixels_exclude_ignored: bool = True, criteria: _Loss = None):
        """
        :param threshold: Probability threshold. An element whose loss exceeds ``-log(threshold)`` is considered hard
            (for a cross-entropy criteria this means the target probability is below ``threshold``).
        :param mining_percent: Fraction in [0, 1] of the pixels that is always kept (the ``num_mining`` highest losses),
            even if fewer elements are hard. At least one element is always kept.
        :param ignore_lb: label index to be ignored in loss calculation.
        :param num_pixels_exclude_ignored: How to calculate the total pixels from which the mining percent of the samples
            is extracted. Ignored elements never contribute a non-zero loss either way.
        :param criteria: loss to mine the examples from; must be constructed with ``reduction='none'``.

        e.g. for num_pixels=100, ignore_pixels=30, mining_percent=0.1:
            num_pixels_exclude_ignored=False => num_mining = 100 * 0.1 = 10
            num_pixels_exclude_ignored=True  => num_mining = (100 - 30) * 0.1 = 7

        :raises IllegalRangeForLossAttributeException: if ``mining_percent`` is outside [0, 1].
        :raises ValueError: if ``criteria`` is None.
        :raises RequiredLossComponentReductionException: if ``criteria.reduction`` is not ``'none'``.
        """
        super().__init__()
        if mining_percent < 0 or mining_percent > 1:
            raise IllegalRangeForLossAttributeException((0, 1), "mining percent")
        # Loss-space threshold: loss > -log(p) <=> p < threshold for a negative-log-likelihood criteria.
        self.thresh = -torch.log(torch.tensor(threshold, dtype=torch.float))
        self.mining_percent = mining_percent
        self.ignore_lb = ignore_lb
        self.num_pixels_exclude_ignored = num_pixels_exclude_ignored
        if criteria is None:
            raise ValueError("criteria must be a loss module constructed with reduction='none'")
        if criteria.reduction != 'none':
            raise RequiredLossComponentReductionException("criteria", criteria.reduction, 'none')
        self.criteria = criteria

    def forward(self, logits, labels):
        """
        Compute the mean loss over the mined hard elements.

        :param logits: model output, passed unchanged to ``self.criteria``.
        :param labels: targets, passed unchanged to ``self.criteria``; flattened and compared with ``ignore_lb``.
            Assumes the criteria output has one element per label element (the masking below relies on it).
        :return: 0-dim tensor with the mean of the selected element losses (zero if there is nothing to average).
        """
        loss = self.criteria(logits, labels).reshape(-1)
        valid = labels.reshape(-1) != self.ignore_lb
        if self.num_pixels_exclude_ignored:
            # remove ignore label elements
            loss = loss[valid]
            # num pixels in a batch -> num_pixels = batch_size * width * height - ignore_pixels
            num_pixels = loss.numel()
        else:
            # Ignored elements still count towards num_pixels, but must not contribute loss. Criteria without an
            # ignore_index (e.g. BCEWithLogitsLoss) would otherwise produce large bogus values for them.
            loss = loss.masked_fill(~valid, 0.0)
            num_pixels = labels.numel()

        # Nothing to average (e.g. every pixel is ignored): return a zero that lives on the right device and stays
        # connected to the graph, so backward() and summing with other losses work.
        if loss.numel() == 0:
            return loss.sum() * 0.0

        # Keep at least one element so the mean below is never taken over an empty tensor (NaN), and never more
        # elements than exist.
        num_mining = int(self.mining_percent * num_pixels)
        num_mining = max(1, min(num_mining, loss.numel()))

        # self.thresh is a plain tensor attribute, so it is moved to the loss device explicitly.
        self.thresh = self.thresh.to(loss.device)
        loss, _ = torch.sort(loss, descending=True)
        # If more than num_mining elements are hard, keep all hard ones; otherwise keep the num_mining largest.
        if num_mining < loss.numel() and loss[num_mining] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:num_mining]
        return torch.mean(loss)
class OhemCELoss(OhemLoss):
    """
    OhemCELoss - Online Hard Example Mining Cross Entropy Loss
    """

    def __init__(self, threshold: float, mining_percent: float = 0.1, ignore_lb: int = -100, num_pixels_exclude_ignored: bool = True):
        """
        :param threshold: Sample below probability threshold, is considered hard.
        :param mining_percent: Fraction in [0, 1] of the pixels that is always kept.
        :param ignore_lb: label index to be ignored in loss calculation; ``None`` means PyTorch's default (-100).
        :param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the samples.
        """
        # Only None is replaced by the default. Any other value, including negatives such as -1, is forwarded
        # unchanged. Otherwise labels equal to it would not be ignored by CrossEntropyLoss and would raise a
        # "target out of bounds" error.
        ignore_lb = -100 if ignore_lb is None else ignore_lb
        criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
        super(OhemCELoss, self).__init__(threshold=threshold, mining_percent=mining_percent, ignore_lb=ignore_lb,
                                         num_pixels_exclude_ignored=num_pixels_exclude_ignored, criteria=criteria)
class OhemBCELoss(OhemLoss):
    """
    OhemBCELoss - Online Hard Example Mining Binary Cross Entropy Loss
    """

    def __init__(self, threshold: float, mining_percent: float = 0.1, ignore_lb: int = -100, num_pixels_exclude_ignored: bool = True, ):
        """
        :param threshold: Sample below probability threshold, is considered hard.
        :param mining_percent: Fraction in [0, 1] of the pixels that is always kept.
        :param ignore_lb: label value to be ignored in loss calculation.
        :param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the samples.
        """
        # BCEWithLogitsLoss has no ignore_index: ignore_lb is handled only by OhemLoss.forward.
        super(OhemBCELoss, self).__init__(threshold=threshold, mining_percent=mining_percent, ignore_lb=ignore_lb,
                                          num_pixels_exclude_ignored=num_pixels_exclude_ignored,
                                          criteria=nn.BCEWithLogitsLoss(reduction='none'))

    def forward(self, logits, labels):
        """
        :param logits: raw binary logits, either with a singleton channel dim at index 1 or without a channel dim.
        :param labels: binary targets; cast to float for BCEWithLogitsLoss. Presumably shaped like the logits
            without the channel dim -- TODO confirm with callers.
        :return: OHEM-reduced loss from OhemLoss.forward.
        """
        # Remove the single class channel when dealing with binary data. 1-D logits of shape (N,) have no channel
        # dim, and indexing shape[1] on them would raise.
        if logits.dim() > 1 and logits.shape[1] == 1:
            logits = logits.squeeze(1)
        return super(OhemBCELoss, self).forward(logits, labels.float())