import copy
import os
from abc import ABC, abstractmethod
from multiprocessing import Value, Lock
import random
import numpy as np
import torch.nn.functional as F
import torchvision
from PIL import Image
import torch
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataSet
from super_gradients.common.abstractions.abstract_logger import get_logger
from deprecated import deprecated
from matplotlib.patches import Rectangle
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from super_gradients.training.datasets.auto_augment import rand_augment_transform
from torchvision.transforms import transforms, InterpolationMode, RandomResizedCrop
from tqdm import tqdm
from super_gradients.training.utils.utils import AverageMeter
from super_gradients.training.utils.detection_utils import DetectionVisualization
import matplotlib.pyplot as plt
def get_mean_and_std_torch(data_dir=None, dataloader=None, num_workers=4, RandomResizeSize=224):
    """
    A function for getting the mean and std of large datasets using pytorch dataloader and gpu functionality.

    :param data_dir: String, path to none-library dataset folder. For example "/data/Imagenette" or "/data/TinyImagenet"
    :param dataloader: a torch DataLoader, as it would feed the data into the trainer (including transforms etc).
    :param num_workers: Int, number of dataloader workers (used only when the dataloader is built here from data_dir).
    :param RandomResizeSize: Int, the size of the RandomResizeCrop as it appears in the DataInterface (for example, for Imagenet,
        this value should be 224).
    :return: 2 lists, mean and std, each one of len 3 (1 for each channel)
    """
    assert data_dir is None or dataloader is None, 'Please provide either path to data folder or DataLoader, not both.'
    if dataloader is None:
        traindir = os.path.join(os.path.abspath(data_dir), 'train')
        trainset = ImageFolder(traindir, transforms.Compose([transforms.RandomResizedCrop(RandomResizeSize),
                                                             transforms.RandomHorizontalFlip(),
                                                             transforms.ToTensor()]))
        dataloader = torch.utils.data.DataLoader(trainset, batch_size=1, num_workers=num_workers)
    # BUGFIX: the original referenced `dataloader.dataset.targets` (ImageFolder-specific) and
    # `trainset` (undefined when a ready-made dataloader is passed in, raising NameError).
    # The dataset length works for both code paths.
    num_samples = len(dataloader.dataset)
    print(f'Calculating on {num_samples} Training Samples')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    h, w = 0, 0
    chsum = None
    # First pass: accumulate the per-channel sum to compute the mean.
    # Assumes every batch has the same spatial size (h, w) as the first one.
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        if batch_idx == 0:
            h, w = inputs.size(2), inputs.size(3)
            print(f'Min: {inputs.min()}, Max: {inputs.max()}')
            chsum = inputs.sum(dim=(0, 2, 3), keepdim=True)
        else:
            chsum += inputs.sum(dim=(0, 2, 3), keepdim=True)
    mean = chsum / num_samples / h / w
    print(f'mean: {mean.view(-1)}')
    # Second pass: accumulate squared deviations for an unbiased (N-1) std estimate.
    chsum = None
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        if batch_idx == 0:
            chsum = (inputs - mean).pow(2).sum(dim=(0, 2, 3), keepdim=True)
        else:
            chsum += (inputs - mean).pow(2).sum(dim=(0, 2, 3), keepdim=True)
    std = torch.sqrt(chsum / (num_samples * h * w - 1))
    print(f'std: {std.view(-1)}')
    return mean.view(-1).cpu().numpy().tolist(), std.view(-1).cpu().numpy().tolist()
@deprecated(reason='Use get_mean_and_std_torch() instead. It is faster and more accurate')
def get_mean_and_std(dataset):
    """
    Compute the per-channel mean and std of a dataset.

    Iterates the dataset one sample at a time and averages each image's
    per-channel mean/std over the whole dataset.
    :param dataset: a torch Dataset yielding (image, target) pairs with 3-channel images
    :return: (mean, std) tensors, each of shape [3]
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for sample_idx, (inputs, targets) in enumerate(loader):
        # progress indication every 10 samples
        if sample_idx % 10 == 0:
            print(sample_idx)
        for channel in range(3):
            mean[channel] += inputs[:, channel, :, :].mean()
            std[channel] += inputs[:, channel, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
class AbstractCollateFunction(ABC):
    """
    Interface for collate functions (for torch DataLoader): subclasses must
    implement __call__ to turn a list of samples into a batch.
    """

    @abstractmethod
    def __call__(self, batch):
        pass
class ComposedCollateFunction(AbstractCollateFunction):
    """
    A collate function (for torch DataLoader) that applies a sequence of
    sub collate functions, feeding each one's output into the next.
    """

    def __init__(self, functions: list):
        """:param functions: collate callables, applied in list order."""
        self.functions = functions

    def __call__(self, batch):
        result = batch
        for collate_fn in self.functions:
            result = collate_fn(result)
        return result
class AtomicInteger:
    """
    Descriptor exposing an integer stored in shared memory (multiprocessing.Value).

    Because the Value lives on the descriptor object itself — which is a class
    attribute of the owning class — the integer is shared by every instance of
    that class and across worker processes.
    """

    def __init__(self, value: int = 0):
        # 'i' -> C signed int in process-shared memory
        self._shared = Value('i', value)

    def __set__(self, instance, value):
        self._shared.value = value

    def __get__(self, instance, owner):
        return self._shared.value
class MultiScaleCollateFunction(AbstractCollateFunction):
    """
    a collate function to implement multi-scale data augmentation
    according to https://arxiv.org/pdf/1612.08242.pdf
    """
    # Class-level AtomicInteger descriptors backed by multiprocessing shared memory:
    # the call counter and the currently selected size are shared across all
    # instances and across DataLoader worker processes.
    _counter = AtomicInteger(0)
    _current_size = AtomicInteger(0)
    _lock = Lock()

    def __init__(self, target_size: int = None, min_image_size: int = None, max_image_size: int = None,
                 image_size_steps: int = 32,
                 change_frequency: int = 10):
        """
        set parameters for the multi-scale collate function
        the possible image sizes are in range [min_image_size, max_image_size] in steps of image_size_steps
        a new size will be randomly selected every change_frequency calls to the collate_fn()
        :param target_size: scales will be [0.66 * target_size, 1.5 * target_size]
        :param min_image_size: the minimum size to scale down to (in pixels)
        :param max_image_size: the maximum size to scale up to (in pixels)
        :param image_size_steps: typically, the stride of the net, which defines the possible image
            size multiplications
        :param change_frequency: a new size is drawn once every this many collate calls
        """
        assert target_size is not None or (max_image_size is not None and min_image_size is not None), \
            'either target_size or min_image_size and max_image_size has to be set'
        assert target_size is None or max_image_size is None, 'target_size and max_image_size cannot be both defined'
        if target_size is not None:
            # derive the bounds from target_size, rounded to multiples of image_size_steps
            min_image_size = int(0.66 * target_size - ((0.66 * target_size) % image_size_steps) + image_size_steps)
            max_image_size = int(1.5 * target_size - ((1.5 * target_size) % image_size_steps))
        print('Using multi-scale %g - %g' % (min_image_size, max_image_size))
        self.sizes = np.arange(min_image_size, max_image_size + image_size_steps, image_size_steps)
        self.image_size_steps = image_size_steps
        self.frequency = change_frequency
        # NOTE: this assignment goes through AtomicInteger.__set__ (writes the shared
        # Value); it does not create a per-instance attribute.
        self._current_size = random.choice(self.sizes)

    def __call__(self, batch):
        with self._lock:
            # Important: this implementation was tailored for a specific input. it assumes the batch is a tuple where
            # the images are the first item
            assert isinstance(batch, tuple), 'this collate function expects the input to be a tuple (images, labels)'
            images = batch[0]
            # draw a new target size once every `frequency` calls (shared counter)
            if self._counter % self.frequency == 0:
                self._current_size = random.choice(self.sizes)
            self._counter += 1
            assert images.shape[2] % self.image_size_steps == 0 and images.shape[3] % self.image_size_steps == 0, \
                'images sized not divisible by %d. (resize images before calling multi_scale)' % self.image_size_steps
            if self._current_size != max(images.shape[2:]):
                # rescale so the longer spatial side matches the selected size, keeping aspect ratio
                ratio = float(self._current_size) / max(images.shape[2:])
                new_size = (int(round(images.shape[2] * ratio)), int(round(images.shape[3] * ratio)))
                images = F.interpolate(images, size=new_size, mode='bilinear', align_corners=False)
            return images, batch[1]
# Legacy lookup table mapping PIL.Image resampling constants to printable names.
# NOTE(review): RandomResizedCropAndInterpolation stores torchvision InterpolationMode
# members, which are NOT keys of this table — kept for backward compatibility with
# code that still passes raw PIL constants.
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}
def _pil_interp(method):
    """
    Translate an interpolation-name string into a torchvision InterpolationMode.

    :param method: one of 'bilinear', 'bicubic', 'lanczos', 'hamming', 'nearest', 'box'
    :return: the matching InterpolationMode member
    :raises ValueError: for any other string
    """
    modes = {
        'bicubic': InterpolationMode.BICUBIC,
        'lanczos': InterpolationMode.LANCZOS,
        'hamming': InterpolationMode.HAMMING,
        'nearest': InterpolationMode.NEAREST,
        'bilinear': InterpolationMode.BILINEAR,
        'box': InterpolationMode.BOX,
    }
    if method in modes:
        return modes[method]
    raise ValueError("interpolation type must be one of ['bilinear', 'bicubic', 'lanczos', 'hamming', "
                     "'nearest', 'box'] for explicit interpolation type, or 'random' for random")
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
class RandomResizedCropAndInterpolation(RandomResizedCrop):
    """
    Crop the given PIL Image to random size and aspect ratio with explicitly chosen or random interpolation.
    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.
    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: 'default' (bilinear), 'random', or one of
            ['bilinear', 'bicubic', 'lanczos', 'hamming', 'nearest', 'box']
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='default'):
        super(RandomResizedCropAndInterpolation, self).__init__(size=size, scale=scale, ratio=ratio, interpolation=interpolation)
        # self.interpolation ends up as either a single InterpolationMode member or,
        # for 'random', a tuple of modes to sample from on every forward() call.
        if interpolation == 'random':
            self.interpolation = _RANDOM_INTERPOLATION
        elif interpolation == 'default':
            self.interpolation = InterpolationMode.BILINEAR
        else:
            self.interpolation = _pil_interp(interpolation)

    def forward(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.
        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return torchvision.transforms.functional.resized_crop(img, i, j, h, w, self.size, interpolation)

    @staticmethod
    def _interpolation_name(interpolation):
        """Printable name for either an InterpolationMode member or a legacy PIL constant."""
        # BUGFIX: self.interpolation holds torchvision InterpolationMode members, which are
        # not keys of the legacy PIL table `_pil_interpolation_to_str`; the original direct
        # lookup raised KeyError in __repr__. Fall back to the enum's string value.
        return _pil_interpolation_to_str.get(interpolation, getattr(interpolation, 'value', str(interpolation)))

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join([self._interpolation_name(x) for x in self.interpolation])
        else:
            interpolate_str = self._interpolation_name(self.interpolation)
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string
STAT_LOGGER_FONT_SIZE = 15
class DatasetStatisticsTensorboardLogger:
    """
    Collects statistics from a (detection) dataset — sample images, class distribution,
    box-size distribution and anchors coverage — and writes them as figures and text
    to a tensorboard SummaryWriter.
    """
    logger = get_logger(__name__)
    # Defaults are merged with (and overridden by) the user-supplied summary_params in __init__.
    DEFAULT_SUMMARY_PARAMS = {
        'sample_images': 32,  # by default, 32 images will be sampled from each dataset
        'plot_class_distribution': True,
        'plot_box_size_distribution': True,
        'plot_anchors_coverage': True,
        'max_batches': 30
    }

    def __init__(self, writer: SummaryWriter, summary_params: dict = DEFAULT_SUMMARY_PARAMS):
        """
        :param writer: tensorboard SummaryWriter that all statistics are written to
        :param summary_params: optional overrides for DEFAULT_SUMMARY_PARAMS
        """
        self.writer = writer
        self.summary_params = {**DatasetStatisticsTensorboardLogger.DEFAULT_SUMMARY_PARAMS, **summary_params}

    def analyze(self, data_loader: torch.utils.data.DataLoader, dataset_params: dict, title: str, anchors: list = None):
        """
        :param data_loader: the dataset data loader
        :param dataset_params: the dataset parameters
        :param title: the title for this dataset (i.e. Coco 2017 test set)
        :param anchors: the list of anchors used by the model. applicable only for detection datasets
        """
        if isinstance(data_loader.dataset, DetectionDataSet):
            self._analyze_detection(data_loader=data_loader, dataset_params=dataset_params, title=title, anchors=anchors)
        else:
            DatasetStatisticsTensorboardLogger.logger.warning('only DetectionDataSet are currently supported')

    def _analyze_detection(self, data_loader, dataset_params, title, anchors=None):
        """
        Analyze a detection dataset
        :param data_loader: the dataset data loader
        :param dataset_params: the dataset parameters
        :param title: the title for this dataset (i.e. Coco 2017 test set)
        :param anchors: the list of anchors used by the model. if not provided, anchors coverage will not be analyzed
        """
        try:
            color_mean = AverageMeter()
            color_std = AverageMeter()
            all_labels = []
            for i, (images, labels) in enumerate(tqdm(data_loader)):
                # chained comparison: stop after max_batches batches; a value <= 0 disables the cap
                if i >= self.summary_params['max_batches'] > 0:
                    break
                if i == 0:
                    # visualize up to sample_images images from the first batch, GT boxes drawn
                    if images.shape[0] > self.summary_params['sample_images']:
                        samples = images[:self.summary_params['sample_images']]
                    else:
                        samples = images
                    # empty predictions -> only ground-truth boxes appear in the visualization
                    pred = [torch.zeros(size=(0, 6)) for _ in range(len(samples))]
                    class_names = data_loader.dataset.all_classes_list
                    result_images = DetectionVisualization.visualize_batch(image_tensor=samples, pred_boxes=pred,
                                                                           target_boxes=copy.deepcopy(labels),
                                                                           batch_name=title, class_names=class_names,
                                                                           box_thickness=1,
                                                                           gt_alpha=1.0)
                    # NHWC -> NCHW; the [:, ::-1] flips the channel axis (presumably BGR->RGB
                    # from a cv2-based visualizer — TODO confirm)
                    self.writer.add_images(tag=f'{title} sample images', img_tensor=np.stack(result_images)
                                           .transpose([0, 3, 1, 2])[:, ::-1, :, :])
                all_labels.append(labels)
                color_mean.update(torch.mean(images, dim=[0, 2, 3]), 1)
                color_std.update(torch.std(images, dim=[0, 2, 3]), 1)
            # drop the first column so rows become [class_label, x_center, y_center, w, h],
            # the layout the helper methods below document and index into
            all_labels = torch.cat(all_labels, dim=0)[:, 1:].numpy()
            if self.summary_params['plot_class_distribution']:
                self._analyze_class_distribution(labels=all_labels, num_classes=dataset_params.num_classes, title=title)
            if self.summary_params['plot_box_size_distribution']:
                self._analyze_object_size_distribution(labels=all_labels, title=title)
            summary = ''
            # NOTE(review): len(data_loader) is the number of batches, not samples
            summary += f'dataset size: {len(data_loader)} \n'
            summary += f'color mean: {color_mean.average} \n'
            summary += f'color std: {color_std.average} \n'
            if anchors is not None:
                coverage = self._analyze_anchors_coverage(anchors=anchors, image_size=dataset_params.train_image_size,
                                                          title=title, labels=all_labels)
                summary += f'anchors: {anchors} \n'
                summary += f'anchors coverage: {coverage} \n'
            self.writer.add_text(tag=f'{title} Statistics', text_string=summary)
            self.writer.flush()
        except Exception as e:
            # any exception is caught here. we dont want the DatasetStatisticsLogger to crash any training
            DatasetStatisticsTensorboardLogger.logger.error(f'dataset analysis failed: {e}')

    def _analyze_class_distribution(self, labels: list, num_classes: int, title: str):
        """
        Log a per-class object-count histogram, both as a figure and as plain text.
        :param labels: array of labels, rows of the shape [class_label, x_center, y_center, w, h]
        :param num_classes: the number of classes in the dataset
        :param title: the dataset title
        """
        hist, edges = np.histogram(labels[:, 0], num_classes)
        f = plt.figure(figsize=[10, 8])
        plt.bar(range(num_classes), hist, width=0.5, color='#0504aa', alpha=0.7)
        plt.xlim(-1, num_classes)
        plt.grid(axis='y', alpha=0.75)
        plt.xlabel('Value', fontsize=STAT_LOGGER_FONT_SIZE)
        plt.ylabel('Frequency', fontsize=STAT_LOGGER_FONT_SIZE)
        plt.xticks(fontsize=STAT_LOGGER_FONT_SIZE)
        plt.yticks(fontsize=STAT_LOGGER_FONT_SIZE)
        plt.title(f'{title} class distribution', fontsize=STAT_LOGGER_FONT_SIZE)
        self.writer.add_figure(f"{title} class distribution", figure=f)
        text_dist = ''
        for i, val in enumerate(hist):
            text_dist += f'[{i}]: {val}, '
        self.writer.add_text(tag=f"{title} class distribution", text_string=text_dist)

    def _analyze_object_size_distribution(self, labels: list, title: str):
        """
        This function will add two plots to the tensorboard.
        one is a 2D histogram and the other is a scatter plot. in both cases the X axis is the object width and Y axis
        is the object width (both normalized by image size)
        :param labels: all the labels of the dataset of the shape [class_label, x_center, y_center, w, h]
        :param title: the dataset title
        """
        # histogram plot
        hist, xedges, yedges = np.histogram2d(labels[:, 4], labels[:, 3], 50)  # x and y are deliberately switched
        fig = plt.figure(figsize=(10, 6))
        fig.suptitle(f'{title} boxes w/h distribution')
        ax = fig.add_subplot(121)
        ax.set_xlabel('W', fontsize=STAT_LOGGER_FONT_SIZE)
        ax.set_ylabel('H', fontsize=STAT_LOGGER_FONT_SIZE)
        # log scale makes sparse bins visible next to dense ones
        plt.imshow(np.log(hist + 1), interpolation='nearest', origin='lower',
                   extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
        # scatter plot
        if len(labels) > 10000:
            # we randomly sample just 10000 objects so that the scatter plot will not get too dense
            labels = labels[np.random.randint(0, len(labels) - 1, 10000)]
        ax = fig.add_subplot(122)
        ax.set_xlabel('W', fontsize=STAT_LOGGER_FONT_SIZE)
        ax.set_ylabel('H', fontsize=STAT_LOGGER_FONT_SIZE)
        plt.scatter(labels[:, 3], labels[:, 4], marker='.')
        self.writer.add_figure(tag=f'{title} boxes w/h distribution', figure=fig)

    @staticmethod
    def _get_rect(w, h):
        """Rectangle spanning the w/h region an anchor covers: from (w/4, h/4) to (4w, 4h)."""
        min_w = w / 4.0
        min_h = h / 4.0
        return Rectangle((min_w, min_h), w * 4 - min_w, h * 4 - min_h, linewidth=1, edgecolor='b', facecolor='none')

    @staticmethod
    def _get_score(anchors: np.ndarray, points: np.ndarray, image_size: int):
        """
        Calculate the ratio (and 1/ratio) between each anchor width and height and each point (representing a possible
        object width and height).
        i.e. for an anchor with w=10,h=20 the point w=11,h=25 will have the ratios 11/10=1.1 and 25/20=1.25
        or 10/11=0.91 and 20/25=0.8 respectively
        :param anchors: array of anchors of the shape [2,N]
        :param points: array of points of the shape [2,M]
        :param image_size the size of the input image
        :returns: an array of size [image_size - 1, image_size - 1] where each cell i,j represent the minimum ratio
        for that cell (point) from all anchors

        NOTE(review): as called from _analyze_anchors_coverage, anchors is actually
        shaped [N, 2] and points [2, M]; the broadcast below yields [N, 2, M].
        """
        ratio = anchors[:, :, None] / points[:, ]
        inv_ratio = 1 / ratio
        # per (anchor, point): 1 - min(r, 1/r) grows with the mismatch in either dimension
        min_ratio = 1 - np.minimum(ratio, inv_ratio)
        min_ratio = np.max(min_ratio, axis=1)
        # distance of each point to its closest anchor
        to_closest_anchor = np.min(min_ratio, axis=0)
        # points too far from every anchor are clamped to a sentinel value for the plot
        to_closest_anchor[to_closest_anchor > 0.75] = 2
        return to_closest_anchor.reshape(image_size - 1, -1)

    def _analyze_anchors_coverage(self, anchors: list, image_size: int, labels: list, title: str):
        """
        This function will add anchors coverage plots to the tensorboard.
        :param anchors: a list of anchors
        :param image_size: the input image size for this training
        :param labels: all the labels of the dataset of the shape [class_label, x_center, y_center, w, h]
        :param title: the dataset title
        :return: the fraction of labels whose w and h both fall inside [anchor/4, anchor*4]
            (normalized by image size) for at least one anchor
        """
        fig = plt.figure(figsize=(12, 5))
        fig.suptitle(f'{title} anchors coverage')
        # box style plot
        ax = fig.add_subplot(121)
        ax.set_xlabel('W', fontsize=STAT_LOGGER_FONT_SIZE)
        ax.set_ylabel('H', fontsize=STAT_LOGGER_FONT_SIZE)
        ax.set_xlim([0, image_size])
        ax.set_ylim([0, image_size])
        anchors = np.array(anchors).reshape(-1, 2)
        for i in range(len(anchors)):
            # one translucent, randomly-colored rectangle per anchor
            rect = self._get_rect(anchors[i][0], anchors[i][1])
            rect.set_alpha(0.3)
            rect.set_facecolor([random.random(), random.random(), random.random(), 0.3])
            ax.add_patch(rect)
        # distance from anchor plot
        ax = fig.add_subplot(122)
        ax.set_xlabel('W', fontsize=STAT_LOGGER_FONT_SIZE)
        ax.set_ylabel('H', fontsize=STAT_LOGGER_FONT_SIZE)
        # evaluate _get_score on every integer (w, h) point of the image plane
        x = np.arange(1, image_size, 1)
        y = np.arange(1, image_size, 1)
        xx, yy = np.meshgrid(x, y, sparse=False)
        points = np.concatenate([xx.reshape(1, -1), yy.reshape(1, -1)])
        color = self._get_score(anchors, points, image_size)
        ax.set_xlabel('W', fontsize=STAT_LOGGER_FONT_SIZE)
        ax.set_ylabel('H', fontsize=STAT_LOGGER_FONT_SIZE)
        plt.imshow(color, interpolation='nearest', origin='lower',
                   extent=[0, image_size, 0, image_size])
        # calculate the coverage for the dataset labels
        cover_masks = []
        for i in range(len(anchors)):
            # labels w/h are normalized by image size; anchor range is [anchor/4, anchor*4]
            w_max = (anchors[i][0] / image_size) * 4
            w_min = (anchors[i][0] / image_size) * 0.25
            h_max = (anchors[i][1] / image_size) * 4
            h_min = (anchors[i][1] / image_size) * 0.25
            cover_masks.append(np.logical_and(
                np.logical_and(np.logical_and(labels[:, 3] < w_max, labels[:, 3] > w_min), labels[:, 4] < h_max),
                labels[:, 4] > h_min))
        cover_masks = np.stack(cover_masks)
        # a label is covered if any anchor's range contains it
        coverage = np.count_nonzero(np.any(cover_masks, axis=0)) / len(labels)
        self.writer.add_figure(tag=f'{title} anchors coverage', figure=fig)
        return coverage
def get_color_augmentation(rand_augment_config_string: str, color_jitter: tuple, crop_size=224, img_mean=[0.485, 0.456, 0.406]):
    """
    Returns color augmentation class. As these augmentation cannot work on top one another, only one is returned
    according to rand_augment_config_string.
    :param rand_augment_config_string: string which defines the auto augment configurations. If none, color jitter
        will be returned. For possible values see auto_augment.py
    :param color_jitter: tuple for color jitter value.
    :param crop_size: relevant only for auto augment
    :param img_mean: relevant only for auto augment
    :return: RandAugment transform or ColorJitter
    """
    if not rand_augment_config_string:
        # RandAugment already includes color-jitter-like augmentations; the two are mutually exclusive.
        return transforms.ColorJitter(*color_jitter)
    scaled_mean = tuple([min(255, round(255 * channel)) for channel in img_mean])
    auto_augment_params = dict(translate_const=int(crop_size * 0.45), img_mean=scaled_mean)
    return rand_augment_transform(rand_augment_config_string, auto_augment_params)