import numpy as np
import torch
import torchmetrics
from torchmetrics import Metric
from typing import Optional, Tuple
from torchmetrics.utilities.distributed import reduce
from abc import ABC, abstractmethod
def batch_pix_accuracy(predict, target):
    """Count correctly classified and labeled pixels for one batch.

    Both label maps are shifted by +1, so a target value of -1 becomes 0. Every
    pixel whose shifted target is positive counts as labeled.

    Args:
        predict: input 4D tensor of per-class scores. The predicted class is the
            argmax over dim 1.
        target: label 3D tensor.

    Returns:
        Tuple ``(pixel_correct, pixel_labeled)`` of numpy integer counts.
    """
    pred_labels = torch.max(predict, dim=1).indices.cpu().numpy() + 1
    true_labels = target.cpu().numpy() + 1
    is_labeled = true_labels > 0
    pixel_labeled = np.sum(is_labeled)
    # Only matches on labeled pixels count as correct.
    pixel_correct = np.sum((pred_labels == true_labels) * is_labeled)
    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
    """Per-class intersection and union pixel areas for one batch.

    Labels are shifted by +1, so a target of -1 becomes 0 (unlabeled), and
    predictions on unlabeled pixels are discarded. Areas are histogrammed over
    ``range=(1, nclass)`` with ``nclass`` bins. For integer labels, each shifted
    class 1..nclass lands in its own bin.

    Args:
        predict: input 4D tensor of per-class scores (argmax taken over dim 1).
        target: label 3D tensor.
        nclass: number of categories (int).

    Returns:
        Tuple ``(area_inter, area_union)`` of numpy arrays of length ``nclass``.
    """
    pred_labels = torch.max(predict, dim=1).indices.cpu().numpy() + 1
    true_labels = target.cpu().numpy() + 1
    # Zero out predictions on unlabeled pixels so they fall outside the histogram range.
    pred_labels = pred_labels * (true_labels > 0).astype(pred_labels.dtype)
    # Keep the class id where prediction and label agree, 0 elsewhere.
    hits = pred_labels * (pred_labels == true_labels)

    def _class_areas(values):
        # One histogram bin per class over [1, nclass].
        counts, _ = np.histogram(values, bins=nclass, range=(1, nclass))
        return counts

    area_inter = _class_areas(hits)
    area_union = _class_areas(pred_labels) + _class_areas(true_labels) - area_inter
    assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"
    return area_inter, area_union
# ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
def pixel_accuracy(im_pred, im_lab):
    """Count correct and labeled pixels for a single prediction/label image pair.

    Pixels whose label is not positive are treated as unlabeled, so predictions
    on them are never penalized. The accuracy itself
    (``pixel_correct / pixel_labeled``) is left to the caller, which should
    guard against ``pixel_labeled == 0``.

    :param im_pred: array-like of predicted class indices.
    :param im_lab: array-like of ground-truth class indices, same shape as ``im_pred``.
    :return: tuple ``(pixel_correct, pixel_labeled)`` of numpy integer counts.
    """
    pred = np.asarray(im_pred)
    label = np.asarray(im_lab)
    labeled = label > 0
    return np.sum((pred == label) * labeled), np.sum(labeled)
def _dice_from_confmat(
    confmat: torch.Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
    absent_score: float = 0.0,
    reduction: str = "elementwise_mean",
) -> torch.Tensor:
    """Computes Dice coefficient from confusion matrix.

    For each class ``c`` the score is
    ``2 * confmat[c, c] / (confmat[:, c].sum() + confmat[c, :].sum())``.

    Args:
        confmat: Confusion matrix without normalization. It is NOT modified;
            a copy is taken when the ignored row must be zeroed.
        num_classes: Number of classes for a given prediction and target tensor.
        ignore_index: optional int specifying a target class to ignore. If given,
            this class index does not contribute to the returned score,
            regardless of reduction method. Values outside
            ``[0, num_classes)`` are ignored.
        absent_score: score to use for an individual class if no instances of
            the class index were present in `pred` AND no instances of the class
            index were present in `target`.
        reduction: a method to reduce metric score over labels.
            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'``: no reduction will be applied

    Returns:
        The reduced Dice score, or per-class scores when ``reduction='none'``.
    """
    ignore_in_range = ignore_index is not None and 0 <= ignore_index < num_classes
    if ignore_in_range:
        # Zero the ignored row on a copy. Doing it in place would silently
        # mutate the caller's tensor, e.g. a metric's accumulated `confmat` state.
        confmat = confmat.clone()
        confmat[ignore_index] = 0.0
    intersection = torch.diag(confmat)
    denominator = confmat.sum(0) + confmat.sum(1)
    scores = 2 * intersection.float() / denominator.float()
    # A class absent from both target AND pred (denominator == 0) gets absent_score instead of NaN.
    scores[denominator == 0] = absent_score
    if ignore_in_range:
        # Drop the ignored class so it cannot influence any reduction.
        scores = torch.cat([scores[:ignore_index], scores[ignore_index + 1 :]])
    return reduce(scores, reduction=reduction)
def intersection_and_union(im_pred, im_lab, num_class):
    """Per-class intersection and union areas for one prediction/label image pair.

    Pixels whose label is not positive are unlabeled, and predictions on them
    are discarded. Areas are histogrammed over ``range=(1, num_class - 1)`` with
    ``num_class - 1`` bins. For integer labels, each class 1..num_class-1 lands
    in its own bin, and label 0 is never counted.

    :param im_pred: array-like of predicted class indices.
    :param im_lab: array-like of ground-truth class indices, same shape as ``im_pred``.
    :param num_class: number of classes including the unlabeled value 0.
    :return: tuple ``(area_inter, area_union)`` of numpy arrays of length ``num_class - 1``.
    """
    pred = np.asarray(im_pred)
    label = np.asarray(im_lab)
    # Do not penalize predictions in unlabeled portions of the image.
    pred = pred * (label > 0)
    # Keep the class id where prediction and label agree, 0 elsewhere.
    matched = pred * (pred == label)
    n_bins = num_class - 1
    hist_range = (1, num_class - 1)
    area_inter = np.histogram(matched, bins=n_bins, range=hist_range)[0]
    area_pred = np.histogram(pred, bins=n_bins, range=hist_range)[0]
    area_lab = np.histogram(label, bins=n_bins, range=hist_range)[0]
    return area_inter, area_pred + area_lab - area_inter
class AbstractMetricsArgsPrepFn(ABC):
    """
    Interface for callables that turn raw model outputs and targets into the
    ``(predictions, target)`` tensor pair handed to a metric's ``update``.
    """

    @abstractmethod
    def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Implementations must return a ``(predictions, target)`` tuple of torch tensors.
        """
        raise NotImplementedError
class PreprocessSegmentationMetricsArgs(AbstractMetricsArgsPrepFn):
    """
    Default preprocessing for segmentation metric inputs before ``update``.

    For multi-output predictions (tuple/list) it keeps only the first element,
    the main segmentation map. It then optionally applies argmax over dim 1 or
    a sigmoid, and casts the target to ``long``.
    """

    def __init__(self, apply_arg_max: bool = False, apply_sigmoid: bool = False):
        """
        :param apply_arg_max: Whether to apply argmax on predictions tensor.
        :param apply_sigmoid: Whether to apply sigmoid on predictions tensor.
            Has no effect when ``apply_arg_max`` is True.
        """
        self.apply_arg_max = apply_arg_max
        self.apply_sigmoid = apply_sigmoid

    def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # With multiple outputs, element 0 is the main segmentation map.
        main_map = preds[0] if isinstance(preds, (tuple, list)) else preds
        if self.apply_arg_max:
            main_map = torch.max(main_map, dim=1).indices
        elif self.apply_sigmoid:
            main_map = torch.sigmoid(main_map)
        return main_map, target.long()
class PixelAccuracy(Metric):
    """Fraction of labeled pixels whose predicted class matches the target.

    Pixels whose target equals ``ignore_label`` are excluded from both the
    numerator and the denominator.
    """

    def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
        """
        :param ignore_label: Target value marking pixels that do not contribute to the metric.
        :param dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
        :param metrics_args_prep_fn: Preprocessing applied to ``(preds, target)`` before counting.
            Defaults to ``PreprocessSegmentationMetricsArgs(apply_arg_max=True)``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ignore_label = ignore_label
        self.greater_is_better = True
        # Pixel counts are exact integers, so accumulate them as int64. A float32
        # accumulator cannot represent every integer above 2**24 (~16.7M pixels),
        # which a handful of full-resolution images already exceeds.
        self.add_state("total_correct", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
        self.add_state("total_label", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
        self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Preprocess the inputs and add this batch's correct/labeled pixel counts."""
        predict, target = self.metrics_args_prep_fn(preds, target)
        labeled_mask = target.ne(self.ignore_label)
        self.total_correct += torch.sum((predict == target) & labeled_mask)
        self.total_label += torch.sum(labeled_mask)

    def compute(self):
        """Return the pixel accuracy as a numpy float64 (0.0 when no pixel was labeled)."""
        total_correct = self.total_correct.detach().cpu().numpy().astype("int64")
        total_label = self.total_label.detach().cpu().numpy().astype("int64")
        # Adding np.spacing(1) (machine epsilon) keeps the division finite when total_label == 0.
        return np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
class IoU(torchmetrics.JaccardIndex):
    """Multi-class Intersection-over-Union for segmentation.

    A thin wrapper around ``torchmetrics.JaccardIndex``. It runs
    ``metrics_args_prep_fn`` on ``(preds, target)`` before every update.
    """

    def __init__(
        self,
        num_classes: int,
        dist_sync_on_step: bool = False,
        ignore_index: Optional[int] = None,
        reduction: str = "elementwise_mean",
        threshold: float = 0.5,
        metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None,
    ):
        """
        :param num_classes: Number of classes. Must be at least 2; use ``BinaryIOU`` for binary masks.
        :param dist_sync_on_step: Forwarded to ``JaccardIndex``.
        :param ignore_index: Forwarded to ``JaccardIndex``.
        :param reduction: Forwarded to ``JaccardIndex``.
        :param threshold: Forwarded to ``JaccardIndex``.
        :param metrics_args_prep_fn: Input preprocessing. Defaults to argmax over the class dimension.
        :raises ValueError: If ``num_classes <= 1``.
        """
        if num_classes <= 1:
            raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}")
        super().__init__(
            num_classes=num_classes,
            dist_sync_on_step=dist_sync_on_step,
            ignore_index=ignore_index,
            reduction=reduction,
            threshold=threshold,
        )
        self.greater_is_better = True
        self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)

    def update(self, preds, target: torch.Tensor):
        """Preprocess ``(preds, target)`` and forward them to ``JaccardIndex.update``."""
        prepared_preds, prepared_target = self.metrics_args_prep_fn(preds, target)
        super().update(preds=prepared_preds, target=prepared_target)
class Dice(torchmetrics.JaccardIndex):
    """Multi-class Dice coefficient for segmentation.

    Reuses the confusion matrix accumulated by ``torchmetrics.JaccardIndex``.
    Instead of IoU it reports ``2 * diag / (column_sum + row_sum)`` per class,
    computed by ``_dice_from_confmat``.
    """

    def __init__(
        self,
        num_classes: int,
        dist_sync_on_step: bool = False,
        ignore_index: Optional[int] = None,
        reduction: str = "elementwise_mean",
        threshold: float = 0.5,
        metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None,
    ):
        """
        :param num_classes: Number of classes. Must be at least 2; use ``BinaryDice`` for binary masks.
        :param dist_sync_on_step: Forwarded to ``JaccardIndex``.
        :param ignore_index: Class excluded from the score (see ``_dice_from_confmat``).
        :param reduction: ``'elementwise_mean'``, ``'sum'`` or ``'none'``.
        :param threshold: Forwarded to ``JaccardIndex``.
        :param metrics_args_prep_fn: Input preprocessing. Defaults to argmax over the class dimension.
        :raises ValueError: If ``num_classes <= 1``.
        """
        if num_classes <= 1:
            raise ValueError(f"Dice class only for multi-class usage! For binary usage, please call {BinaryDice.__name__}")
        super().__init__(
            num_classes=num_classes,
            dist_sync_on_step=dist_sync_on_step,
            ignore_index=ignore_index,
            reduction=reduction,
            threshold=threshold,
        )
        self.greater_is_better = True
        self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)

    def update(self, preds, target: torch.Tensor):
        """Preprocess ``(preds, target)`` and forward them to ``JaccardIndex.update``."""
        prepared_preds, prepared_target = self.metrics_args_prep_fn(preds, target)
        super().update(preds=prepared_preds, target=prepared_target)

    def compute(self) -> torch.Tensor:
        """Computes Dice coefficient"""
        return _dice_from_confmat(
            self.confmat,
            self.num_classes,
            self.ignore_index,
            self.absent_score,
            self.reduction,
        )
class BinaryIOU(IoU):
    """Binary IoU that reports foreground ("target"), background and mean IoU separately.

    Internally this is a 2-class ``IoU`` with ``reduction="none"``. By default,
    predictions are passed through a sigmoid before ``update``, and
    ``threshold`` is forwarded to the parent.
    NOTE(review): if ``ignore_index`` makes the parent drop one of the two
    classes from the per-class scores, indexing ``[1]`` in ``compute`` would
    fail. Confirm before using ``ignore_index`` in {0, 1}.
    """

    def __init__(
        self,
        dist_sync_on_step=True,
        ignore_index: Optional[int] = None,
        threshold: float = 0.5,
        metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None,
    ):
        super().__init__(
            num_classes=2,
            dist_sync_on_step=dist_sync_on_step,
            ignore_index=ignore_index,
            reduction="none",
            threshold=threshold,
            metrics_args_prep_fn=metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True),
        )
        # Every reported component is a "higher is better" score.
        self.greater_component_is_better = {name: True for name in ("target_IOU", "background_IOU", "mean_IOU")}
        self.component_names = list(self.greater_component_is_better)

    def compute(self):
        """Return a dict with the class-1, class-0 and mean IoU."""
        per_class = super().compute()
        return {
            "target_IOU": per_class[1],
            "background_IOU": per_class[0],
            "mean_IOU": per_class.mean(),
        }
class BinaryDice(Dice):
    """Binary Dice that reports foreground ("target"), background and mean Dice separately.

    Internally this is a 2-class ``Dice`` with ``reduction="none"``. By default,
    predictions are passed through a sigmoid before ``update``.
    NOTE(review): ``_dice_from_confmat`` removes an in-range ``ignore_index``
    from the per-class scores. With ``ignore_index`` of 0 or 1, only one score
    remains and ``compute`` would fail on ``[1]``.
    """

    def __init__(
        self,
        dist_sync_on_step=True,
        ignore_index: Optional[int] = None,
        threshold: float = 0.5,
        metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None,
    ):
        super().__init__(
            num_classes=2,
            dist_sync_on_step=dist_sync_on_step,
            ignore_index=ignore_index,
            reduction="none",
            threshold=threshold,
            metrics_args_prep_fn=metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_sigmoid=True),
        )
        # Every reported component is a "higher is better" score.
        self.greater_component_is_better = {name: True for name in ("target_Dice", "background_Dice", "mean_Dice")}
        self.component_names = list(self.greater_component_is_better)

    def compute(self):
        """Return a dict with the class-1, class-0 and mean Dice."""
        per_class = super().compute()
        return {
            "target_Dice": per_class[1],
            "background_Dice": per_class[0],
            "mean_Dice": per_class.mean(),
        }