import torch
from torch import nn
import torch.nn.functional as F
def onehot(indexes, N=None, ignore_index=None):
    """
    Creates a one-hot representation of `indexes` with N possible entries.
    If N is not specified, it is inferred from the maximum index present.
    `indexes` is a LongTensor of indices; entries equal to `ignore_index`
    become all-zero rows in the one-hot representation.
    """
    if N is None:
        N = indexes.max() + 1
    sz = list(indexes.size())
    output = indexes.new().byte().resize_(*sz, N).zero_()
    output.scatter_(-1, indexes.unsqueeze(-1), 1)
    if ignore_index is not None and ignore_index >= 0:
        output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
    return output
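
# Usage sketch (illustrative only; `_onehot_example` is not part of the original
# module): converting a batch of 3 class indices over 5 classes into a one-hot
# ByteTensor, one row per index.
def _onehot_example():
    idx = torch.tensor([1, 0, 4])
    return onehot(idx, N=5)  # shape (3, 5), dtype uint8, a single 1 per row
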
def _is_long(x):
    if hasattr(x, 'data'):
        x = x.data
    return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)
def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean',  # noqa: C901
                  smooth_eps=None, smooth_dist=None, from_logits=True):
    """Cross-entropy loss, with support for target distributions and label
    smoothing (https://arxiv.org/abs/1512.00567)."""
    smooth_eps = smooth_eps or 0

    # ordinary log-likelihood - use cross_entropy from nn
    if _is_long(target) and smooth_eps == 0:
        if from_logits:
            return F.cross_entropy(inputs, target, weight,
                                   ignore_index=ignore_index, reduction=reduction)
        else:
            return F.nll_loss(inputs, target, weight,
                              ignore_index=ignore_index, reduction=reduction)

    if from_logits:
        # log-softmax of inputs
        lsm = F.log_softmax(inputs, dim=-1)
    else:
        lsm = inputs

    masked_indices = None
    num_classes = inputs.size(-1)

    if _is_long(target) and ignore_index >= 0:
        masked_indices = target.eq(ignore_index)

    if smooth_eps > 0 and smooth_dist is not None:
        if _is_long(target):
            target = onehot(target, num_classes).type_as(inputs)
        if smooth_dist.dim() < target.dim():
            smooth_dist = smooth_dist.unsqueeze(0)
        # blend the (one-hot) target with the smoothing distribution
        target.lerp_(smooth_dist, smooth_eps)

    if weight is not None:
        lsm = lsm * weight.unsqueeze(0)

    if _is_long(target):
        eps_nll = 1. - smooth_eps
        likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        loss = -(eps_nll * likelihood + smooth_eps * lsm.mean(-1))
    else:
        loss = -(target * lsm).sum(-1)

    if masked_indices is not None:
        # zero out entries that should be ignored
        loss.masked_fill_(masked_indices, 0)

    if reduction == 'sum':
        loss = loss.sum()
    elif reduction == 'mean':
        if masked_indices is None:
            loss = loss.mean()
        else:
            loss = loss.sum() / float(loss.size(0) - masked_indices.sum())
    return loss
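
# Usage sketch (illustrative only; `_cross_entropy_example` is not part of the
# original module): label-smoothed loss on random logits. With smooth_eps > 0 and
# integer targets, the loss mixes the negative log-likelihood of the true class
# with the mean log-probability over all classes.
def _cross_entropy_example():
    logits = torch.randn(8, 10)          # batch of 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))  # integer class labels
    return cross_entropy(logits, target, smooth_eps=0.1)
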
class LabelSmoothingCrossEntropyLoss(nn.CrossEntropyLoss):
    """CrossEntropyLoss - with the ability to receive a distribution as targets,
    and optional label smoothing"""

    def __init__(self, weight=None, ignore_index=-100, reduction='mean',
                 smooth_eps=None, smooth_dist=None, from_logits=True):
        super(LabelSmoothingCrossEntropyLoss, self).__init__(weight=weight,
                                                             ignore_index=ignore_index,
                                                             reduction=reduction)
        self.smooth_eps = smooth_eps
        self.smooth_dist = smooth_dist
        self.from_logits = from_logits
    def forward(self, input, target, smooth_dist=None):
        if smooth_dist is None:
            smooth_dist = self.smooth_dist
        loss = cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
                             reduction=self.reduction, smooth_eps=self.smooth_eps,
                             smooth_dist=smooth_dist, from_logits=self.from_logits)
        # Criterion functions now return a tuple of (loss_for_backprop, additional_items),
        # where the additional items are tensors of size (n_items, ...) detached from
        # their gradients, for logging.
        return loss, loss.unsqueeze(0).detach()
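
# Usage sketch (illustrative only; `_criterion_example` is not part of the original
# module): the criterion returns (loss for backprop, detached tensor for logging).
def _criterion_example():
    criterion = LabelSmoothingCrossEntropyLoss(smooth_eps=0.1)
    logits = torch.randn(8, 10, requires_grad=True)
    target = torch.randint(0, 10, (8,))
    loss, logged = criterion(logits, target)
    loss.backward()  # `logged` is detached, shape (1,), safe to accumulate for logging
    return logged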