import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms
def get_dataset(name, path):
    """
    Loads dataset

    Parameters
    ----------
    name: str
        Name of the dataset to be loaded. Supports 'MNIST' and 'CIFAR10'
        (exact, case-sensitive match).
    path: str
        Directory under which the downloaded dataset is stored.

    Returns
    ----------
    X_tr
        Train set (a torch tensor for MNIST, a numpy array for CIFAR10;
        see ``get_MNIST`` / ``get_CIFAR10``)
    Y_tr: torch tensor
        Training labels
    X_te
        Test set (same type as ``X_tr``)
    Y_te: torch tensor
        Test labels

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported datasets.
    """
    if name == 'MNIST':
        return get_MNIST(path)
    if name == 'CIFAR10':
        return get_CIFAR10(path)
    # Fail loudly here: falling through would return None and the caller's
    # four-way tuple unpacking would fail later with an unrelated TypeError.
    raise ValueError(
        f"Unsupported dataset {name!r}; expected 'MNIST' or 'CIFAR10'"
    )
def get_MNIST(path):
    """
    Downloads MNIST dataset

    Parameters
    ----------
    path: str or os.PathLike
        Directory under which an ``MNIST`` sub-folder is created for the data.

    Returns
    ----------
    X_tr: torch tensor
        Train set images (``MNIST.data`` of the training split)
    Y_tr: torch tensor
        Training labels (``MNIST.targets`` of the training split)
    X_te: torch tensor
        Test set images (``MNIST.data`` of the test split)
    Y_te: torch tensor
        Test labels (``MNIST.targets`` of the test split)
    """
    root = os.path.join(path, 'MNIST')
    raw_tr = datasets.MNIST(root, train=True, download=True)
    raw_te = datasets.MNIST(root, train=False, download=True)
    # ``train_data``/``train_labels``/``test_data``/``test_labels`` are
    # deprecated torchvision aliases for ``data``/``targets`` that only emit
    # warnings; read the canonical attributes directly.
    return raw_tr.data, raw_tr.targets, raw_te.data, raw_te.targets
def get_CIFAR10(path):
    """
    Downloads CIFAR10 dataset

    Parameters
    ----------
    path: str or os.PathLike
        Directory under which a ``CIFAR10`` sub-folder is created for the data.

    Returns
    ----------
    X_tr: numpy array
        Train set images (``CIFAR10.data`` of the training split)
    Y_tr: torch tensor
        Training labels, dtype ``torch.long`` (int64)
    X_te: numpy array
        Test set images (``CIFAR10.data`` of the test split)
    Y_te: torch tensor
        Test labels, dtype ``torch.long`` (int64)
    """
    root = os.path.join(path, 'CIFAR10')
    # No transform is applied here: the raw image arrays are returned and any
    # augmentation/normalisation is left to the caller.
    data_tr = datasets.CIFAR10(root, train=True, download=True)
    data_te = datasets.CIFAR10(root, train=False, download=True)
    # ``targets`` is a sequence of ints. Build int64 labels explicitly so the
    # dtype does not depend on the platform's default integer (np.array yields
    # int32 on Windows with NumPy < 2), since losses such as CrossEntropyLoss
    # require int64 class indices.
    Y_tr = torch.tensor(data_tr.targets, dtype=torch.long)
    Y_te = torch.tensor(data_te.targets, dtype=torch.long)
    return data_tr.data, Y_tr, data_te.data, Y_te