Source code for distil.active_learning_strategies.core_set

import numpy as np
import pdb
from .strategy import Strategy
from sklearn.neighbors import NearestNeighbors
import pickle
from datetime import datetime
from sklearn.metrics import pairwise_distances

class CoreSet(Strategy):
    """
    Implementation of CoreSet :footcite:`sener2018active` Strategy. This class extends
    :class:`active_learning_strategies.strategy.Strategy` to include coreset sampling
    technique to select data points for active learning.

    Parameters
    ----------
    X: numpy array
        Present training/labeled data
    Y: numpy array
        Labels of present training data
    unlabeled_x: numpy array
        Data without labels
    net: class
        Pytorch Model class
    handler: class
        Data Handler, which can load data even without labels.
    nclasses: int
        Number of unique target variables
    args: dict, optional
        Specify optional parameters

        batch_size
            Batch size to be used inside strategy class (int, optional)
        tor
            Stored on the instance as ``self.tor``; defaults to ``1e-4`` when
            absent (float, optional)
    """

    def __init__(self, X, Y, unlabeled_x, net, handler, nclasses, args=None):
        """
        Constructor method
        """
        # Build a fresh dict per call: a mutable default would be shared by every
        # instance and by the parent constructor that receives it.
        if args is None:
            args = {}

        self.tor = args.get('tor', 1e-4)
        super(CoreSet, self).__init__(X, Y, unlabeled_x, net, handler, nclasses, args)

    def furthest_first(self, X, X_set, n):
        """
        Selects points with maximum distance

        Greedy k-center selection: repeatedly pick the candidate whose distance
        to its nearest already-covered point (labeled or previously picked) is
        largest.

        Parameters
        ----------
        X: numpy array
            Embeddings of unlabeled set
        X_set: numpy array
            Embeddings of labeled set
        n: int
            Number of points to return

        Returns
        ----------
        idxs: list
            List of selected data point indexes with respect to unlabeled_x.
            Indexes are unique; at most ``min(n, len(X))`` are returned.
        """
        m = np.shape(X)[0]

        # Never ask for more points than exist, and bail out early when there is
        # nothing to select (argmax on an empty array would raise).
        n = min(n, m)
        if n <= 0:
            return []

        if np.shape(X_set)[0] == 0:
            # No labeled points yet: every candidate is infinitely far from the set.
            min_dist = np.full(m, np.inf)
        else:
            min_dist = pairwise_distances(X, X_set).min(axis=1)

        idxs = []
        for _ in range(n):
            idx = min_dist.argmax()
            idxs.append(idx)

            # Distance of every candidate to the newly chosen center.
            dist_to_new = pairwise_distances(X, X[[idx], :])[:, 0]
            min_dist = np.minimum(min_dist, dist_to_new)

            # Exclude the chosen point from later rounds so that ties at distance
            # zero can never yield the same index twice.
            min_dist[idx] = -np.inf

        return idxs

    def select(self, budget):
        """
        Select next set of points

        Parameters
        ----------
        budget: int
            Number of indexes to be returned for next set

        Returns
        ----------
        chosen: list
            List of selected data point indexes with respect to unlabeled_x
        """
        # get_embedding is not defined in this class (presumably inherited from
        # Strategy); its result must expose .numpy() -- presumably a torch tensor.
        embedding_unlabeled = self.get_embedding(self.unlabeled_x).numpy()
        embedding_labeled = self.get_embedding(self.X).numpy()

        return self.furthest_first(embedding_unlabeled, embedding_labeled, budget)