utils

DataHandler

class distil.utils.DataHandler.DataHandler_CIFAR10(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data handler for the CIFAR10 dataset. This class extends torch.utils.data.Dataset so that data can be loaded with or without labels.

Parameters
  • X (numpy array) – Data to be loaded

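  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool, optional) – True if loading data without labels, False otherwise (default: True)

  • use_test_transform (bool, optional) – True to apply the test-time transform (default: False)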
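The meaning of use_test_transform is not spelled out beyond its name. As a minimal sketch, assuming the flag switches between a randomly augmented training transform and a deterministic test-time transform (a common setup for CIFAR10), the hypothetical class below applies a random horizontal flip only when use_test_transform is False. TransformSwitchDataset, to_chw_float, and random_hflip are illustrative names, not part of DISTIL, and the flip stands in for whatever augmentation the library actually uses. The class implements __len__ and __getitem__, the map-style protocol that torch.utils.data.DataLoader indexes into.

```python
import numpy as np


def to_chw_float(img):
    """Deterministic test-time transform: HWC uint8 -> CHW float32 in [0, 1]."""
    return img.astype(np.float32).transpose(2, 0, 1) / 255.0


def random_hflip(img, rng):
    """Training-time augmentation: flip left-right with probability 0.5."""
    return img[:, ::-1, :] if rng.random() < 0.5 else img


class TransformSwitchDataset:
    """Hypothetical map-style dataset that picks its transform from a flag.

    Assumption (not DISTIL's documented behavior): use_test_transform=True
    skips augmentation; False applies a random flip before the test transform.
    """

    def __init__(self, X, use_test_transform=False, seed=0):
        self.X = X  # uint8 images, shape (N, H, W, C)
        self.use_test_transform = use_test_transform
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        img = self.X[index]
        if not self.use_test_transform:
            img = random_hflip(img, self.rng)  # augmentation only at train time
        return to_chw_float(img)
```

Keeping augmentation out of the test path means repeated reads of the same index return identical arrays, which matters when the outputs are used for evaluation rather than training.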

class distil.utils.DataHandler.DataHandler_MNIST(X, Y=None, select=True)

Bases: torch.utils.data.dataset.Dataset

Data handler for the MNIST dataset. This class extends torch.utils.data.Dataset so that data can be loaded with or without labels.

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool, optional) – True if loading data without labels, False otherwise (default: True)
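The select flag lets one class serve both labeled and unlabeled data. The sketch below shows that pattern with NumPy only, assuming select=True means items come back without labels, as the parameter description states. OptionalLabelDataset is a hypothetical name; the scaling of 28x28 uint8 images to [0, 1] with a leading channel axis is an assumed preprocessing step, and the ValueError checks are additions of this sketch rather than documented DISTIL behavior.

```python
import numpy as np


class OptionalLabelDataset:
    """Hypothetical sketch of a dataset that may or may not carry labels.

    select=True  -> items are returned without labels (Y may be None)
    select=False -> items are returned as (x, y) pairs
    """

    def __init__(self, X, Y=None, select=True):
        # Validation below is an assumption of this sketch.
        if not select and Y is None:
            raise ValueError("labels Y are required when select=False")
        if Y is not None and len(Y) != len(X):
            raise ValueError("X and Y must have the same length")
        self.X = X
        self.Y = Y
        self.select = select

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # Assumed preprocessing: (28, 28) uint8 -> (1, 28, 28) float32 in [0, 1]
        x = self.X[index].astype(np.float32)[None, :, :] / 255.0
        if self.select:
            return x
        return x, self.Y[index]
```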

class distil.utils.DataHandler.DataHandler_Points(X, Y=None, select=True)

Bases: torch.utils.data.dataset.Dataset

Data handler for generic data points. This class extends torch.utils.data.Dataset so that data can be loaded with or without labels.

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool, optional) – True if loading data without labels, False otherwise (default: True)
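For generic data points, the main concern is usually consuming them in mini-batches. The following sketch assumes X is a 2-D array of shape (N, d) with one point per row; PointsDataset and iterate_batches are hypothetical helpers, not DISTIL functions. iterate_batches mimics a non-shuffling DataLoader, whose default collation likewise stacks per-sample arrays along a new leading dimension (PyTorch additionally converts them to tensors).

```python
import numpy as np


class PointsDataset:
    """Hypothetical map-style dataset over the rows of an (N, d) feature matrix."""

    def __init__(self, X, Y=None, select=True):
        self.X = np.asarray(X, dtype=np.float32)
        self.Y = None if Y is None else np.asarray(Y)
        self.select = select  # True: return points without labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.select:
            return self.X[index]
        return self.X[index], self.Y[index]


def iterate_batches(dataset, batch_size):
    """Yield stacked mini-batches in index order, like a non-shuffling DataLoader."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        if isinstance(items[0], tuple):
            xs, ys = zip(*items)  # split (x, y) pairs into two sequences
            yield np.stack(xs), np.stack(ys)
        else:
            yield np.stack(items)
```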

Dataset

distil.utils.dataset.get_CIFAR10(path)

Downloads the CIFAR10 dataset and returns its training and test splits.

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Training set

  • Y_tr (torch tensor) – Training labels

  • X_te (numpy array) – Test set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_MNIST(path)

Downloads the MNIST dataset and returns its training and test splits.

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Training set

  • Y_tr (torch tensor) – Training labels

  • X_te (numpy array) – Test set

  • Y_te (torch tensor) – Test labels
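get_CIFAR10 and get_MNIST return the same four values, so a single sanity check can cover either. The helper below is a hypothetical sketch; it assumes images are stored sample-first (for example (N, 32, 32, 3) for CIFAR10 or (N, 28, 28) for MNIST, matching the layouts torchvision exposes, which is an assumption about these functions) and that labels are integers in [0, num_classes). Labels returned as CPU torch tensors are accepted because np.asarray converts them.

```python
import numpy as np


def check_splits(X_tr, Y_tr, X_te, Y_te, image_shape, num_classes):
    """Hypothetical sanity check for (X_tr, Y_tr, X_te, Y_te) loader output.

    Returns (number of training samples, number of test samples).
    """
    for X, Y, split in ((X_tr, Y_tr, "train"), (X_te, Y_te, "test")):
        X, Y = np.asarray(X), np.asarray(Y)
        # Per-sample shape, e.g. (32, 32, 3) for CIFAR10 or (28, 28) for MNIST.
        if X.shape[1:] != tuple(image_shape):
            raise ValueError(f"{split}: expected per-sample shape {image_shape}, got {X.shape[1:]}")
        if len(X) != len(Y):
            raise ValueError(f"{split}: {len(X)} samples but {len(Y)} labels")
        if Y.min() < 0 or Y.max() >= num_classes:
            raise ValueError(f"{split}: labels outside [0, {num_classes})")
    return len(X_tr), len(X_te)
```

If the loaders return the full standard splits, the returned counts would be (50000, 10000) for CIFAR10 and (60000, 10000) for MNIST.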

distil.utils.dataset.get_dataset(name, path)

Loads the dataset specified by name and returns its training and test splits.

Parameters
  • name (str) – Name of the dataset to load. Supported values are MNIST and CIFAR10.

  • path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Training set

  • Y_tr (torch tensor) – Training labels

  • X_te (numpy array) – Test set

  • Y_te (torch tensor) – Test labels
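get_dataset selects a loader by name. A minimal sketch of that dispatch is shown below, assuming exact, case-sensitive name matching and a ValueError for unsupported names (the real function's behavior for other names is not documented here). LOADERS, load_by_name, and _fake_loader are hypothetical; the fake loaders return small zero-filled arrays in place of downloaded data, and labels are NumPy arrays rather than torch tensors to keep the example dependency-free.

```python
import numpy as np


def _fake_loader(n_train, n_test, image_shape):
    """Stand-in for a real download function; returns zero-filled splits."""
    def load(path):
        X_tr = np.zeros((n_train, *image_shape), dtype=np.uint8)
        X_te = np.zeros((n_test, *image_shape), dtype=np.uint8)
        Y_tr = np.zeros(n_train, dtype=np.int64)
        Y_te = np.zeros(n_test, dtype=np.int64)
        return X_tr, Y_tr, X_te, Y_te
    return load


# Hypothetical registry mapping dataset names to loader functions.
LOADERS = {
    "MNIST": _fake_loader(4, 2, (28, 28)),
    "CIFAR10": _fake_loader(4, 2, (32, 32, 3)),
}


def load_by_name(name, path):
    """Dispatch on the dataset name and return (X_tr, Y_tr, X_te, Y_te)."""
    try:
        loader = LOADERS[name]
    except KeyError:
        raise ValueError(
            f"unsupported dataset {name!r}; expected one of {sorted(LOADERS)}"
        ) from None
    return loader(path)
```

A registry dictionary keeps the supported names and the error message in one place, so supporting another dataset means adding one entry.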