utils
DataHandler
class distil.utils.DataHandler.DataHandler_CIFAR10(X, Y=None, select=True, use_test_transform=False)
Bases: torch.utils.data.dataset.Dataset
Data handler to load the CIFAR10 dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
class distil.utils.DataHandler.DataHandler_MNIST(X, Y=None, select=True)
Bases: torch.utils.data.dataset.Dataset
Data handler to load the MNIST dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
class distil.utils.DataHandler.DataHandler_Points(X, Y=None, select=True)
Bases: torch.utils.data.dataset.Dataset
Data handler to load generic data points. This class extends torch.utils.data.Dataset to handle loading data even without labels.
- Parameters
X (numpy array) – Data to be loaded
Y (numpy array, optional) – Labels to be loaded (default: None)
select (bool) – True if loading data without labels, False otherwise
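The select flag described for the handlers above can be illustrated with a small stand-in class. This is only a sketch, not distil's code: the class name SelectAwareDataset is hypothetical, plain Python lists replace numpy arrays so it runs without any dependencies, and returning the sample index alongside the data is an assumption rather than something this page states. It implements the map-style dataset protocol (__len__ and __getitem__) that torch.utils.data.Dataset subclasses provide.

```python
class SelectAwareDataset:
    """Hypothetical stand-in for a DataHandler-style class (not distil's code)."""

    def __init__(self, X, Y=None, select=True):
        # Labels are only required when the data is loaded with labels.
        if not select and Y is None:
            raise ValueError("Y is required when select=False")
        self.X = X
        self.Y = Y
        self.select = select

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        x = self.X[index]
        if self.select:
            # Unlabeled mode: return the sample and its index only (assumed layout).
            return x, index
        # Labeled mode: return the sample, its label, and its index (assumed layout).
        return x, self.Y[index], index


# Same data, loaded once without labels and once with labels.
unlabeled = SelectAwareDataset([[0.0, 1.0], [2.0, 3.0]])
labeled = SelectAwareDataset([[0.0, 1.0], [2.0, 3.0]], Y=[0, 1], select=False)
```

With this layout, code that consumes an unlabeled handler unpacks two values per item, while code that consumes a labeled handler unpacks three.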
Dataset
distil.utils.dataset.get_CIFAR10(path)
Downloads the CIFAR10 dataset.
- Parameters
path (str) – Path to save the downloaded dataset
- Returns
X_tr (numpy array) – Training set
Y_tr (torch tensor) – Training labels
X_te (numpy array) – Test set
Y_te (torch tensor) – Test labels
distil.utils.dataset.get_MNIST(path)
Downloads the MNIST dataset.
- Parameters
path (str) – Path to save the downloaded dataset
- Returns
X_tr (numpy array) – Training set
Y_tr (torch tensor) – Training labels
X_te (numpy array) – Test set
Y_te (torch tensor) – Test labels
distil.utils.dataset.get_dataset(name, path)
Loads the named dataset.
- Parameters
name (str) – Name of the dataset to be loaded. Supports MNIST and CIFAR10
path (str) – Path to save the downloaded dataset
- Returns
X_tr (numpy array) – Training set
Y_tr (torch tensor) – Training labels
X_te (numpy array) – Test set
Y_te (torch tensor) – Test labels
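A loader that selects a dataset by name, as get_dataset does, can be sketched as a lookup table of per-dataset loaders. Everything below is illustrative: get_dataset_sketch and the two stub loaders are hypothetical names, the stubs return tiny placeholder lists instead of downloading anything, and raising ValueError for an unsupported name is an assumption, since this page does not say how the real function handles that case.

```python
def _load_mnist_stub(path):
    # Placeholder for a real download; returns (X_tr, Y_tr, X_te, Y_te).
    return [[0, 0], [1, 1]], [0, 1], [[2, 2]], [1]


def _load_cifar10_stub(path):
    # Placeholder for a real download; returns (X_tr, Y_tr, X_te, Y_te).
    return [[3, 3], [4, 4]], [5, 6], [[7, 7]], [8]


# Map each supported dataset name to its loader.
_LOADERS = {
    "MNIST": _load_mnist_stub,
    "CIFAR10": _load_cifar10_stub,
}


def get_dataset_sketch(name, path):
    """Dispatch to the loader registered for `name` (hypothetical helper)."""
    try:
        loader = _LOADERS[name]
    except KeyError:
        supported = ", ".join(sorted(_LOADERS))
        raise ValueError(
            f"Unsupported dataset {name!r}; expected one of: {supported}"
        ) from None
    return loader(path)


X_tr, Y_tr, X_te, Y_te = get_dataset_sketch("MNIST", "./data")
```

Keeping the supported names in one table means a new dataset only needs a loader with the same four-value return shape and one new entry.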