utils

DataHandler

class distil.utils.DataHandler.DataHandler_CIFAR10(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load CIFAR10 dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise
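The effect of the select flag can be illustrated with a minimal sketch. This is not DISTIL's implementation: the class name UnlabeledAwareDataset is hypothetical, it does not subclass torch.utils.data.Dataset, it applies no image transforms, and the returned tuple layout (with the index appended) is an assumption made for illustration. It only shows how one map-style dataset can serve both labeled and unlabeled data.

```python
class UnlabeledAwareDataset:
    """Hypothetical map-style dataset: yields (x, index) or (x, y, index)."""

    def __init__(self, X, Y=None, select=True):
        # Labels may only be omitted in unlabeled (select=True) mode.
        if not select and Y is None:
            raise ValueError("labels are required when select=False")
        self.X = X
        self.Y = Y
        self.select = select

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        x = self.X[index]
        if self.select:
            # Unlabeled mode: no label is looked up.
            return x, index
        # Labeled mode: return the label alongside the data point.
        return x, self.Y[index], index


unlabeled = UnlabeledAwareDataset([[0.1, 0.2], [0.3, 0.4]])
labeled = UnlabeledAwareDataset([[0.1, 0.2], [0.3, 0.4]], Y=[0, 1], select=False)
```

Because the sketch implements only __len__ and __getitem__, it follows the same map-style protocol that torch.utils.data.Dataset subclasses use.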

class distil.utils.DataHandler.DataHandler_CIFAR100(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load CIFAR100 dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class distil.utils.DataHandler.DataHandler_FASHION_MNIST(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load FASHION_MNIST dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class distil.utils.DataHandler.DataHandler_KMNIST(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load KMNIST dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class distil.utils.DataHandler.DataHandler_MNIST(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load MNIST dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class distil.utils.DataHandler.DataHandler_Points(X, Y=None, select=True)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load data points. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class distil.utils.DataHandler.DataHandler_STL10(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load STL10 dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

class distil.utils.DataHandler.DataHandler_SVHN(X, Y=None, select=True, use_test_transform=False)

Bases: torch.utils.data.dataset.Dataset

Data Handler to load SVHN dataset. This class extends torch.utils.data.Dataset to handle loading data even without labels

Parameters
  • X (numpy array) – Data to be loaded

  • Y (numpy array, optional) – Labels to be loaded (default: None)

  • select (bool) – True if loading data without labels, False otherwise

Dataset

distil.utils.dataset.add_label_noise(y_trn, num_cls, noise_ratio=0.8)

Adds noise to the specified list of labels. This functionality is taken from CORDS and applied here.

Parameters
  • y_trn (list) – The list of labels to which noise is added.

  • num_cls (int) – The number of classes possible in the list.

  • noise_ratio (float, optional) – The fraction of labels to modify, between 0 and 1. The default is 0.8.

Returns

y_trn – The list of now-noisy labels

Return type

list
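One common way to inject label noise is to pick a fraction of positions at random and overwrite each with a uniformly drawn class. The sketch below illustrates that idea with the standard library only. It is an assumption about the general technique, not a copy of the DISTIL or CORDS code, and the function name add_label_noise_sketch and the seed parameter are hypothetical.

```python
import random


def add_label_noise_sketch(labels, num_cls, noise_ratio=0.8, seed=None):
    """Return a copy of labels with a fraction of entries re-drawn at random."""
    rng = random.Random(seed)
    noisy = list(labels)  # copy, so the caller's list is left untouched
    n_noisy = int(len(noisy) * noise_ratio)
    # Choose distinct positions, then overwrite each with a uniform class.
    for i in rng.sample(range(len(noisy)), n_noisy):
        noisy[i] = rng.randrange(num_cls)
    return noisy


clean = [0, 1, 2, 3] * 25
noisy = add_label_noise_sketch(clean, num_cls=4, noise_ratio=0.8, seed=0)
changed = sum(a != b for a, b in zip(clean, noisy))
```

Under this scheme a re-drawn label can coincide with the original one, so the number of labels that actually change is at most, and usually below, noise_ratio times the list length.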

distil.utils.dataset.get_CIFAR10(path, tr_load_args=None, te_load_args=None)

Downloads CIFAR10 dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_CIFAR100(path, tr_load_args=None, te_load_args=None)

Downloads CIFAR100 dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_FASHION_MNIST(path, tr_load_args=None, te_load_args=None)

Downloads FASHION_MNIST dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_KMNIST(path, tr_load_args=None, te_load_args=None)

Downloads KMNIST dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_MNIST(path, tr_load_args=None, te_load_args=None)

Downloads MNIST dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_STL10(path, tr_load_args=None, te_load_args=None)

Downloads STL10 dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_SVHN(path, tr_load_args=None, te_load_args=None)

Downloads SVHN dataset

Parameters

path (str) – Path to save the downloaded dataset

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels

distil.utils.dataset.get_dataset(name, path, tr_load_args=None, te_load_args=None)

Loads dataset

Parameters
  • name (str) – Name of the dataset to be loaded. Supports MNIST and CIFAR10

  • path (str) – Path to save the downloaded dataset

  • tr_load_args (dict) – String dictionary for train distribution shift loading

  • te_load_args (dict) – String dictionary for test distribution shift loading

Returns

  • X_tr (numpy array) – Train set

  • Y_tr (torch tensor) – Training Labels

  • X_te (numpy array) – Test Set

  • Y_te (torch tensor) – Test labels
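A name-based loader of this shape is often written as a lookup table that maps dataset names to loader functions. The sketch below shows that dispatch pattern only. The loaders _load_toy_a and _load_toy_b, the names "TOY_A" and "TOY_B", and the placeholder values they return are all hypothetical stand-ins. Nothing is downloaded, and the sketch does not claim to mirror how DISTIL dispatches internally.

```python
def _load_toy_a(path, tr_load_args=None, te_load_args=None):
    # Hypothetical loader: returns tiny placeholder train/test splits.
    return [[0.0], [1.0]], [0, 1], [[0.5]], [1]


def _load_toy_b(path, tr_load_args=None, te_load_args=None):
    # Hypothetical loader with different placeholder values.
    return [[2.0], [3.0]], [1, 0], [[2.5]], [0]


# Registry mapping a dataset name to the function that loads it.
LOADERS = {"TOY_A": _load_toy_a, "TOY_B": _load_toy_b}


def get_dataset_sketch(name, path, tr_load_args=None, te_load_args=None):
    """Look up the loader for name and forward the remaining arguments."""
    try:
        loader = LOADERS[name]
    except KeyError:
        raise ValueError(
            f"unsupported dataset {name!r}; choose from {sorted(LOADERS)}"
        ) from None
    return loader(path, tr_load_args, te_load_args)


X_tr, Y_tr, X_te, Y_te = get_dataset_sketch("TOY_A", "./data")
```

Failing fast with the list of supported names makes a misspelled dataset name easy to diagnose.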

distil.utils.dataset.get_imbalanced_idx(y_trn, num_cls, class_ratio=0.6)

Returns a list of indices that constitute a class-imbalanced subset of the supplied dataset. This functionality is taken from CORDS and applied here.

Parameters
  • y_trn (numpy ndarray) – The label set from which the imbalanced subset is chosen.

  • num_cls (int) – The number of classes possible in the list.

  • class_ratio (float, optional) – The fraction of classes to affect, between 0 and 1. The default is 0.6.

Returns

subset_idxs – The list of indices of the supplied dataset that constitute a class-imbalanced subset

Return type

list
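A class-imbalanced subset can be built by choosing a fraction of the classes and keeping only a small share of the samples in each chosen class, while leaving the other classes intact. The sketch below illustrates that idea with the standard library. The function name imbalanced_indices_sketch, the keep_fraction parameter (and its value of 0.1), and the seed parameter are assumptions for illustration and are not taken from the DISTIL or CORDS source.

```python
import random


def imbalanced_indices_sketch(labels, num_cls, class_ratio=0.6,
                              keep_fraction=0.1, seed=None):
    """Return sorted indices of a class-imbalanced subset of labels."""
    rng = random.Random(seed)
    # Pick which classes will be under-sampled.
    n_affected = int(num_cls * class_ratio)
    affected = set(rng.sample(range(num_cls), n_affected))
    subset = []
    for cls in range(num_cls):
        idxs = [i for i, y in enumerate(labels) if y == cls]
        if cls in affected:
            # Keep only a small random share of this class.
            n_keep = int(len(idxs) * keep_fraction)
            idxs = rng.sample(idxs, n_keep)
        subset.extend(idxs)
    return sorted(subset)


labels = [c for c in range(5) for _ in range(20)]
subset = imbalanced_indices_sketch(labels, num_cls=5, class_ratio=0.6, seed=0)
```

With 5 classes of 20 samples each, 3 classes are reduced to 2 samples and 2 classes keep all 20, so the subset holds 46 indices.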

distil.utils.dataset.make_data_redundant(X, Y, amtRed=2)

Modifies the input dataset so that only X.shape[0]/amtRed points are original and the rest are repeated (redundant) copies.

Parameters
  • X (numpy ndarray) – The feature set to be made redundant.

  • Y (numpy ndarray) – The label set corresponding to the X.

  • amtRed (float, optional) – Factor that determines redundancy. The default is 2.

Returns

X – Modified feature set.

Return type

numpy ndarray
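The redundancy idea can be sketched in a few lines: keep roughly len(X)/amtRed original points and fill the remaining positions with copies drawn from them. The version below is a simplified illustration that uses plain lists and ignores the labels Y. Whether the library handles redundancy per class is not stated here, so treat the function name make_redundant_sketch, the seed parameter, and the sampling details as assumptions.

```python
import math
import random


def make_redundant_sketch(X, amt_red=2, seed=None):
    """Keep ceil(len(X) / amt_red) originals; fill the rest with repeats."""
    rng = random.Random(seed)
    n_original = math.ceil(len(X) / amt_red)
    originals = X[:n_original]
    # Every remaining slot is filled with a randomly chosen original point.
    repeats = [rng.choice(originals) for _ in range(len(X) - n_original)]
    return originals + repeats


X = [[float(i)] for i in range(10)]
X_red = make_redundant_sketch(X, amt_red=2, seed=0)
```

The output has the same length as the input, but with amt_red=2 only the first 5 of the 10 rows are distinct.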