from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from torchvision import transforms
class DataHandler_Points(Dataset):
    """
    Data Handler to load data points.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded. It is converted with ``astype(np.float32)``
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.X = X.astype(np.float32)
        if not self.select:
            self.Y = Y

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``.
        """
        x = self.X[index]
        if self.select:
            return x, index
        return x, self.Y[index], index

    def __len__(self):
        """
        Number of data points in ``X``.
        """
        return len(self.X)
class DataHandler_SVHN(Dataset):
    """
    Data Handler to load SVHN dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # ImageNet mean/std
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        # NOTE(review): horizontal flips may not be label-preserving for
        # digit images -- confirm this augmentation is intended.
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        # Moves axis 0 to the end; presumably (C, H, W) -> (H, W, C) as
        # expected by PIL -- confirm against the caller's array layout.
        image = Image.fromarray(np.transpose(x, (1, 2, 0)))
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)
class DataHandler_MNIST(Dataset):
    """
    Data Handler to load MNIST dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # Single-channel normalization shared by both pipelines.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        # NOTE(review): horizontal flips may not be label-preserving for
        # digit images -- confirm this augmentation is intended.
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        image = Image.fromarray(x)
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)
class DataHandler_KMNIST(Dataset):
    """
    Data Handler to load KMNIST dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # Use mean/std of MNIST
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        # NOTE(review): horizontal flips may not be label-preserving for
        # character images -- confirm this augmentation is intended.
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        image = Image.fromarray(x)
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)
class DataHandler_FASHION_MNIST(Dataset):
    """
    Data Handler to load FASHION_MNIST dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # Use mean/std of MNIST
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        image = Image.fromarray(x)
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)
class DataHandler_CIFAR10(Dataset):
    """
    Data Handler to load CIFAR10 dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # Three-channel normalization shared by both pipelines.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        image = Image.fromarray(x)
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)
class DataHandler_CIFAR100(Dataset):
    """
    Data Handler to load CIFAR100 dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # Three-channel normalization shared by both pipelines.
        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        image = Image.fromarray(x)
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)
class DataHandler_STL10(Dataset):
    """
    Data Handler to load STL10 dataset.

    This class extends :class:`torch.utils.data.Dataset` to handle
    loading data even without labels.

    Parameters
    ----------
    X: numpy array
        Data to be loaded
    Y: numpy array, optional
        Labels to be loaded (default: None). Only stored when ``select``
        is False
    select: bool
        True if loading data without labels, False otherwise (default: True)
    use_test_transform: bool
        True to apply ``test_gen_transform`` (ToTensor + Normalize only),
        False to apply ``training_gen_transform`` which additionally does a
        random crop and a random horizontal flip (default: False)

    Raises
    ------
    ValueError
        If ``select`` is False and ``Y`` is None
    """

    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        """
        Constructor
        """
        # Fail early: labeled mode indexes Y in __getitem__, so a missing Y
        # would otherwise only surface later as an opaque TypeError.
        if not select and Y is None:
            raise ValueError("Y must be provided when select is False")

        self.select = select
        self.use_test_transform = use_test_transform

        # ImageNet mean/std
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.training_gen_transform = transforms.Compose([
            transforms.RandomCrop(96, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_gen_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.X = X
        if not self.select:
            self.Y = Y

    def _transform_sample(self, x):
        """
        Convert one raw sample to a PIL image and apply the active transform.
        """
        # Moves axis 0 to the end; presumably (C, H, W) -> (H, W, C) as
        # expected by PIL -- confirm against the caller's array layout.
        image = Image.fromarray(np.transpose(x, (1, 2, 0)))
        if self.use_test_transform:
            return self.test_gen_transform(image)
        return self.training_gen_transform(image)

    def __getitem__(self, index):
        """
        Return ``(x, index)`` when ``select`` is True (unlabeled data),
        otherwise ``(x, y, index)``, where ``x`` is the transformed image.
        """
        if self.select:
            return self._transform_sample(self.X[index]), index
        x, y = self.X[index], self.Y[index]
        return self._transform_sample(x), y, index

    def __len__(self):
        """
        Number of samples in ``X``.
        """
        return len(self.X)