from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from torchvision import transforms
class DataHandler_Points(Dataset):
    """
    Data Handler to load data points.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded. It is converted to ``np.float32`` on construction.
    Y: numpy array, optional
        Labels to be loaded (default: None). Required when ``select`` is False.
    select: bool
        True if loading data without labels, False otherwise (default: True)

    Raises
    ------
    ValueError
        If ``select`` is False and no labels are given.
    """

    def __init__(self, X, Y=None, select=True):
        """
        Constructor
        """
        # Fail at construction time instead of with an opaque TypeError on
        # the first ``self.Y[index]`` lookup.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.X = X.astype(np.float32)
        # Defined in both modes so the attribute always exists; it is only
        # read by __getitem__ when select is False.
        self.Y = Y

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``.
        """
        x = self.X[index]
        if self.select:
            return x, index
        return x, self.Y[index], index

    def __len__(self):
        """
        Return the number of data points (``len(X)``).
        """
        return len(self.X)
class DataHandler_MNIST(Dataset):
    """
    Data Handler to load MNIST dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels. Every sample is converted to a PIL
    image, turned into a tensor and normalized with mean 0.1307 and
    standard deviation 0.3081.

    Parameters
    ----------
    X: numpy array
        Data to be loaded. Each ``X[index]`` must be accepted by
        ``PIL.Image.fromarray``.
    Y: numpy array, optional
        Labels to be loaded (default: None). Required when ``select`` is False.
    select: bool
        True if loading data without labels, False otherwise (default: True)

    Raises
    ------
    ValueError
        If ``select`` is False and no labels are given.
    """

    def __init__(self, X, Y=None, select=True):
        """
        Constructor
        """
        # Fail at construction time instead of with an opaque TypeError on
        # the first ``self.Y[index]`` lookup.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.X = X
        # Defined in both modes so the attribute always exists; it is only
        # read by __getitem__ when select is False.
        self.Y = Y
        # The same transform applies whether or not labels are present.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``; ``x`` is the transformed sample.
        """
        x = self.transform(Image.fromarray(self.X[index]))
        if self.select:
            return x, index
        return x, self.Y[index], index

    def __len__(self):
        """
        Return the number of samples (``len(X)``).
        """
        return len(self.X)
class DataHandler_CIFAR10(Dataset):
    """
    Data Handler to load CIFAR10 dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels. Every sample is converted to a PIL
    image and passed through one of two pipelines: the training pipeline
    (random 32x32 crop with padding 4, random horizontal flip, tensor
    conversion, normalization) or the test pipeline (tensor conversion and
    normalization only).

    Parameters
    ----------
    X: numpy array
        Data to be loaded. Each ``X[index]`` must be accepted by
        ``PIL.Image.fromarray``.
    Y: numpy array, optional
        Labels to be loaded (default: None). Required when ``select`` is False.
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply the test pipeline (no random augmentation), False to
        apply the training pipeline (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and no labels are given.
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail at construction time instead of with an opaque TypeError on
        # the first ``self.Y[index]`` lookup.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform
        self.X = X
        # Defined in both modes so the attribute always exists; it is only
        # read by __getitem__ when select is False.
        self.Y = Y

        # Both pipelines end with the same normalization step.
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``; ``x`` is the transformed sample.
        """
        if self.use_test_transform:
            transform = self.test_gen_transform
        else:
            transform = self.training_gen_transform

        x = transform(Image.fromarray(self.X[index]))
        if self.select:
            return x, index
        return x, self.Y[index], index

    def __len__(self):
        """
        Return the number of samples (``len(X)``).
        """
        return len(self.X)