import numpy as np
import torch
import torch.nn.functional as F
from .strategy import Strategy
from torch.autograd import Variable
class AdversarialDeepFool(Strategy):
    """
    Implementation of the Adversarial DeepFool active-learning strategy.

    This class extends :class:`active_learning_strategies.strategy.Strategy`.
    For every unlabeled point it runs an iterative DeepFool attack against the
    current model and records the squared L2 norm of the perturbation that was
    accumulated while trying to change the predicted label. Points with the
    smallest perturbation (presumed closest to a decision boundary) are
    selected first.

    Parameters
    ----------
    X: numpy array
        Present training/labeled data
    Y: numpy array
        Labels of present training data
    unlabeled_x: numpy array
        Data without labels
    net: class
        Pytorch Model class
    handler: class
        Data Handler, which can load data even without labels.
    nclasses: int
        Number of unique target variables
    args: dict, optional
        Specify optional parameters; the dict is forwarded to the base class.

        batch_size
            Batch size to be used inside strategy class (int, optional)
        max_iter
            Maximum number of DeepFool iterations per point (int, optional,
            default 50)
    """

    def __init__(self, X, Y, unlabeled_x, net, handler, nclasses, args=None):
        """
        Constructor method
        """
        # ``None`` instead of a shared ``{}`` default so separate instances
        # never share (and possibly mutate) the same dict object.
        if args is None:
            args = {}
        self.max_iter = args.get('max_iter', 50)
        # Forward the caller's options; passing a fresh ``{}`` here would
        # silently discard settings such as ``batch_size``.
        super(AdversarialDeepFool, self).__init__(X, Y, unlabeled_x, net, handler, nclasses, args)

    @staticmethod
    def _input_grad(output, inputs):
        """Return d(output)/d(inputs) without accumulating into any ``.grad``."""
        return torch.autograd.grad(output, inputs, retain_graph=True)[0]

    def cal_dis(self, x):
        """
        Run DeepFool on one sample and measure the perturbation it produced.

        Parameters
        ----------
        x: torch.Tensor
            A single (unbatched) sample; a batch dimension of 1 is added here.

        Returns
        -------
        torch.Tensor
            0-dim tensor: squared L2 norm of the accumulated perturbation.
            If the predicted label has not flipped after ``self.max_iter``
            iterations (or no usable direction exists), this is the
            perturbation reached so far.
        """
        nx = torch.unsqueeze(x, 0).detach().requires_grad_(True)
        eta = torch.zeros_like(nx)

        out = self.model(nx + eta)
        n_class = out.shape[1]
        # Label predicted for the clean input; iterate until it changes.
        ny = int(out.max(1)[1])
        py = ny

        i_iter = 0
        while py == ny and i_iter < self.max_iter:
            grad_py = self._input_grad(out[0, py], nx)

            best_value = np.inf
            ri = None
            for k in range(n_class):
                if k == py:
                    continue
                wk = self._input_grad(out[0, k], nx) - grad_py
                fk = float(out[0, k] - out[0, py])
                wk_norm = float(wk.norm())
                if wk_norm == 0:
                    # Identical gradients: this class gives no direction.
                    continue
                value_k = abs(fk) / wk_norm
                if value_k < best_value:
                    # Keep the running minimum so the *nearest* boundary wins.
                    best_value = value_k
                    # Minimal linearised step: |f_k| / ||w_k||^2 * w_k
                    ri = value_k / wk_norm * wk

            if ri is None:
                # Single-output model or all gradient differences vanished.
                break

            eta += ri
            out = self.model(nx + eta)
            py = int(out.max(1)[1])
            i_iter += 1

        return (eta * eta).sum()

    def select(self, budget):
        """
        Select next set of points

        Parameters
        ----------
        budget: int
            Number of indexes to be returned for next set

        Returns
        ----------
        idxs: numpy array
            Indexes (with respect to unlabeled_x) of the ``budget`` points
            with the smallest DeepFool perturbation, in ascending order.
        """
        self.model.cpu()
        self.model.eval()
        n_unlabeled = self.unlabeled_x.shape[0]
        dis = np.zeros(n_unlabeled)
        data_pool = self.handler(self.unlabeled_x)
        try:
            for i in range(n_unlabeled):
                x, _ = data_pool[i]
                dis[i] = float(self.cal_dis(x))
        finally:
            # Move the model back even if an attack fails partway through.
            self.model.to(self.device)
        idxs = dis.argsort()[:budget]
        return idxs