import os
import numpy as np
import scipy.io
from PIL import Image
from torch.utils.data import ConcatDataset
from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet
from super_gradients.common.abstractions.abstract_logger import get_logger
# Module-level logger for this file, obtained from super_gradients' get_logger.
logger = get_logger(__name__)
# Pascal VOC 2012 segmentation class names. The position of each name in the list is its label index
# (index 0 is "background"); the colour table in PascalVOC2012SegmentationDataSet follows the same order.
PASCAL_VOC_2012_CLASSES = [
    # index 0
    "background",
    # indices 1-10
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    # indices 11-20
    "diningtable", "dog", "horse", "motorbike", "person",
    "potted-plant", "sheep", "sofa", "train", "tv/monitor",
]
class PascalVOC2012SegmentationDataSet(SegmentationDataSet):
    """
    PascalVOC2012SegmentationDataSet - Segmentation Data Set Class for Pascal VOC 2012 Data Set
    """

    # Label value exposed as the ignore label of this dataset.
    # NOTE(review): presumably the value _ORIGINAL_IGNORE_LABEL pixels are remapped to; the remapping is not
    # performed in this class - confirm where it happens.
    IGNORE_LABEL = 21
    # Ignore value as stored in the original annotation files (presumably the VOC "void" value - confirm).
    _ORIGINAL_IGNORE_LABEL = 255

    def __init__(self, sample_suffix=None, target_suffix=None, *args, **kwargs):
        """
        :param sample_suffix: File extension appended to every listed id to build the image path (default ".jpg").
        :param target_suffix: File extension appended to every listed id to build the mask path (default ".png").
        :param args:          Forwarded positionally to SegmentationDataSet.
        :param kwargs:        Forwarded to SegmentationDataSet (e.g. list_file, samples_sub_directory,
                              targets_sub_directory).
        """
        # The suffixes are assigned before super().__init__ because _generate_samples_and_targets reads them;
        # NOTE(review): this assumes the base constructor is what triggers sample generation - confirm.
        self.sample_suffix = ".jpg" if sample_suffix is None else sample_suffix
        self.target_suffix = ".png" if target_suffix is None else target_suffix
        super().__init__(*args, **kwargs)
        self.classes = PASCAL_VOC_2012_CLASSES

    def decode_segmentation_mask(self, label_mask: np.ndarray) -> np.ndarray:
        """
        decode_segmentation_mask - Decodes the colors for the Segmentation Mask

        :param label_mask: an (M, N) array of integer values denoting the class label at each spatial location.
        :return: an (M, N, 3) float array with the RGB colour of each pixel divided by 255.
                 Pixels whose label is not one of 0 .. len(self.classes) - 1 keep their raw label value
                 (divided by 255) in all three channels.
        """
        label_colours = self._get_pascal_labels()
        r = label_mask.copy()
        g = label_mask.copy()
        b = label_mask.copy()
        for ll in range(len(self.classes)):
            # Compute the per-class pixel selection once and reuse it for all three channels.
            class_pixels = label_mask == ll
            r[class_pixels] = label_colours[ll, 0]
            g[class_pixels] = label_colours[ll, 1]
            b[class_pixels] = label_colours[ll, 2]
        rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb

    def _generate_samples_and_targets(self):
        """
        _generate_samples_and_targets - Fill self.samples_targets_tuples_list with (image_path, mask_path) pairs,
        one per id listed in the list file, then delegate to SegmentationDataSet._generate_samples_and_targets.

        Blank lines are ignored. Ids whose image or mask file does not exist are skipped, and a single warning
        reports how many were skipped.
        """
        # Use os.path.join like the sample/mask paths below instead of manual separator concatenation.
        list_file = os.path.join(self.root, self.list_file_path)
        num_skipped = 0
        with open(list_file, "r", encoding="utf-8") as lines:
            for line in lines:
                sample_id = line.strip()
                if not sample_id:
                    continue
                image_path = os.path.join(self.root, self.samples_sub_directory, sample_id + self.sample_suffix)
                mask_path = os.path.join(self.root, self.targets_sub_directory, sample_id + self.target_suffix)
                if os.path.exists(mask_path) and os.path.exists(image_path):
                    self.samples_targets_tuples_list.append((image_path, mask_path))
                else:
                    num_skipped += 1
        if num_skipped:
            logger.warning(f"{num_skipped} ids listed in {list_file} were skipped because their image or mask file does not exist.")
        # GENERATE SAMPLES AND TARGETS OF THE SEGMENTATION DATA SET CLASS
        super()._generate_samples_and_targets()

    def _get_pascal_labels(self):
        """Load the mapping that associates pascal classes with label colors.

        Row i holds the [R, G, B] colour (0-255) used for label i.

        Returns:
            np.ndarray with dimensions (21, 3)
        """
        return np.asarray(
            [
                [0, 0, 0],
                [128, 0, 0],
                [0, 128, 0],
                [128, 128, 0],
                [0, 0, 128],
                [128, 0, 128],
                [0, 128, 128],
                [128, 128, 128],
                [64, 0, 0],
                [192, 0, 0],
                [64, 128, 0],
                [192, 128, 0],
                [64, 0, 128],
                [192, 0, 128],
                [64, 128, 128],
                [192, 128, 128],
                [0, 64, 0],
                [128, 64, 0],
                [0, 192, 0],
                [128, 192, 0],
                [0, 64, 128],
            ]
        )
class PascalAUG2012SegmentationDataSet(PascalVOC2012SegmentationDataSet):
    """
    PascalAUG2012SegmentationDataSet - Segmentation Data Set Class for Pascal AUG 2012 Data Set

    Samples are ".jpg" images; targets are ".mat" files decoded by target_loader.
    """

    def __init__(self, *args, **kwargs):
        """
        :param args:   Forwarded positionally to PascalVOC2012SegmentationDataSet (after the fixed suffixes).
        :param kwargs: Forwarded to PascalVOC2012SegmentationDataSet.
        """
        self.sample_suffix = ".jpg"
        self.target_suffix = ".mat"
        # Pass the suffixes positionally so that extra positional *args land in the parent's *args.
        # Passing them by keyword next to *args would bind args[0] to sample_suffix as well, raising
        # "got multiple values for argument 'sample_suffix'".
        super().__init__(self.sample_suffix, self.target_suffix, *args, **kwargs)

    @staticmethod
    def target_loader(target_path: str) -> Image.Image:
        """
        target_loader - Load the segmentation mask stored in an AUG ".mat" annotation file.

        :param target_path: The path to the target data
        :return: The loaded target: a PIL image built from the file's GTcls.Segmentation array
        """
        mat = scipy.io.loadmat(target_path, mat_dtype=True, squeeze_me=True, struct_as_record=False)
        mask = mat["GTcls"].Segmentation
        # NOTE(review): with mat_dtype=True the array presumably loads as float64, which PIL turns into a
        # mode "F" image - confirm downstream target transforms expect that.
        return Image.fromarray(mask)
class PascalVOCAndAUGUnifiedDataset(ConcatDataset):
    """
    Pascal VOC + AUG train dataset, aka `SBD` dataset contributed in "Semantic contours from inverse detectors".
    This class implements the common usage of the SBD and PascalVOC datasets as a unified augmented train set.
    The unified dataset includes a total of 10,582 samples and does not contain duplicate samples from the
    PascalVOC validation set.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: Forwarded to both PascalVOC2012SegmentationDataSet and PascalAUG2012SegmentationDataSet.
                       list_file, samples_sub_directory and targets_sub_directory are optional and discarded
                       (with a warning when set), because both sub-datasets use predefined values for them.
        """
        # Pop every key unconditionally (a list, not a short-circuiting generator), each with a default so
        # absent keys are fine. A key left in kwargs would collide with the explicit keyword arguments below.
        ignored_values = [
            kwargs.pop(name, None) for name in ("list_file", "samples_sub_directory", "targets_sub_directory")
        ]
        if any(ignored_values):
            logger.warning(
                "[list_file, samples_sub_directory, targets_sub_directory] arguments passed will not be used"
                " when passed to `PascalVOCAndAUGUnifiedDataset`. Those values are predefined for initiating"
                " the Pascal VOC + AUG training set."
            )
        super().__init__(
            datasets=[
                PascalVOC2012SegmentationDataSet(
                    list_file="VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt",
                    samples_sub_directory="VOCdevkit/VOC2012/JPEGImages",
                    targets_sub_directory="VOCdevkit/VOC2012/SegmentationClass",
                    **kwargs,
                ),
                PascalAUG2012SegmentationDataSet(
                    list_file="VOCaug/dataset/aug.txt",
                    samples_sub_directory="VOCaug/dataset/img",
                    targets_sub_directory="VOCaug/dataset/cls",
                    **kwargs,
                ),
            ]
        )