Source code for super_gradients.training.losses.bce_dice_loss

import torch

from super_gradients.training.losses.bce_loss import BCE
from super_gradients.training.losses.dice_loss import BinaryDiceLoss


[docs]class BCEDiceLoss(torch.nn.Module): """ Binary Cross Entropy + Dice Loss Weighted average of BCE and Dice loss Attributes: loss_weights: list of size 2 s.t loss_weights[0], loss_weights[1] are the weights for BCE, Dice respectively. """ def __init__(self, loss_weights=[0.5, 0.5], logits=True): super(BCEDiceLoss, self).__init__() self.loss_weights = loss_weights self.bce = BCE() self.dice = BinaryDiceLoss(apply_sigmoid=logits)
[docs] def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @param input: Network's raw output shaped (N,1,H,W) @param target: Ground truth shaped (N,H,W) """ return self.loss_weights[0] * self.bce(input, target) + self.loss_weights[1] * self.dice(input, target)