"""Source code for ``graphwar.attack.untargeted.pgd_attack``."""

import math
from copy import deepcopy
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import grad
from torch.distributions.bernoulli import Bernoulli
from tqdm import tqdm

from graphwar.attack.untargeted.untargeted_attacker import UntargetedAttacker
from graphwar.surrogate import Surrogate
from graphwar.functional import to_dense_adj


class PGDAttack(UntargetedAttacker, Surrogate):
    r"""Implementation of `PGD` attack from the:
    `"Topology Attack and Defense for Graph Neural Networks:
    An Optimization Perspective" <https://arxiv.org/abs/1906.04214>`_
    paper (IJCAI'19)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`graphwar.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    >>> from graphwar.dataset import GraphWarDataset
    >>> import torch_geometric.transforms as T

    >>> dataset = GraphWarDataset(root='~/data/pygdata', name='cora',
                                  transform=T.LargestConnectedComponents())
    >>> data = dataset[0]

    >>> surrogate_model = ...  # train your surrogate model

    >>> from graphwar.attack.untargeted import PGDAttack
    >>> attacker = PGDAttack(data)
    >>> attacker.setup_surrogate(surrogate_model)
    >>> attacker.reset()
    >>> attacker.attack(0.05)  # attack with 0.05% of edge perturbations
    >>> attacker.data()  # get attacked graph
    >>> attacker.edge_flips()  # get edge flips after attack
    >>> attacker.added_edges()  # get added edges after attack
    >>> attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * MinMax attack is a variant of
      :class:`graphwar.attack.untargeted.PGDAttack` attack.
    * Please remember to call :meth:`reset` before each attack.
    """

    # PGDAttack cannot ensure that there is not singleton node after attacks.
    _allow_singleton: bool = True

    def setup_surrogate(self, surrogate: torch.nn.Module,
                        labeled_nodes: Tensor,
                        unlabeled_nodes: Optional[Tensor] = None, *,
                        eps: float = 1.0,
                        freeze: bool = True):
        """Set up the surrogate model and precompute the dense adjacency
        structures used by the attack.

        Parameters
        ----------
        surrogate : torch.nn.Module
            the (trained) surrogate model
        labeled_nodes : Tensor
            nodes whose ground-truth labels are available
        unlabeled_nodes : Optional[Tensor], optional
            nodes attacked with self-training (estimated) labels;
            if None, only the labeled nodes are attacked, by default None
        eps : float, optional
            temperature used to scale the surrogate logits, by default 1.0
        freeze : bool, optional
            whether to freeze the surrogate parameters, by default True

        Returns
        -------
        PGDAttack
            the attacker itself
        """
        Surrogate.setup_surrogate(self, surrogate=surrogate,
                                  eps=eps, freeze=freeze)

        labeled_nodes = torch.LongTensor(labeled_nodes).to(self.device)

        if unlabeled_nodes is None:
            # poisoning attack in DeepRobust
            victim_nodes = labeled_nodes
            victim_labels = self.label[labeled_nodes]
        else:
            # Evasion attack in original paper: also attack unlabeled nodes
            # using self-training (pseudo) labels from the surrogate.
            unlabeled_nodes = torch.LongTensor(unlabeled_nodes).to(self.device)
            self_training_labels = self.estimate_self_training_labels(
                unlabeled_nodes)
            victim_nodes = torch.cat([labeled_nodes, unlabeled_nodes], dim=0)
            victim_labels = torch.cat([self.label[labeled_nodes],
                                       self_training_labels], dim=0)

        adj = to_dense_adj(self.edge_index,
                           self.edge_weight,
                           num_nodes=self.num_nodes).to(self.device)

        I = torch.eye(self.num_nodes, device=self.device)
        # `complementary` maps a {0,1} perturbation matrix to signed edge
        # flips: +1 where an edge can be added, -1 where one can be removed,
        # and 0 on the diagonal (no self-loops).
        self.complementary = torch.ones_like(adj) - I - 2. * adj
        self.adj = adj
        self.victim_nodes = victim_nodes
        self.victim_labels = victim_labels
        return self

    def reset(self):
        """Reset the attacker state and the (differentiable) perturbation
        variable. Must be called before each attack."""
        super().reset()
        self.perturbations = torch.zeros_like(self.adj).requires_grad_()
        return self

    def attack(self, num_budgets=0.05, *, C=None, CW_loss=False, epochs=200,
               sample_epochs=20, structure_attack=True, feature_attack=False,
               disable=False):
        """Run the PGD attack.

        Parameters
        ----------
        num_budgets : int or float, optional
            the number (or ratio) of edge perturbations, by default 0.05
        C : Optional[float], optional
            step-size constant; if None it is chosen by :meth:`config_C`
        CW_loss : bool, optional
            whether to use the Carlini-Wagner margin loss instead of
            cross-entropy, by default False
        epochs : int, optional
            number of PGD iterations, by default 200
        sample_epochs : int, optional
            number of Bernoulli sampling rounds, by default 20
        disable : bool, optional
            whether to disable the progress bar, by default False

        Returns
        -------
        PGDAttack
            the attacker itself
        """
        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)
        self.CW_loss = CW_loss
        C = self.config_C(C)
        perturbations = self.perturbations

        for epoch in tqdm(range(epochs),
                          desc='PGD training...',
                          disable=disable):
            gradients = self.compute_gradients(perturbations,
                                               self.victim_nodes,
                                               self.victim_labels)
            # Decaying step size as in the original paper.
            lr = C / math.sqrt(epoch + 1)
            perturbations.data.add_(lr * gradients)
            # Project back onto the budget-constrained [0,1] ball.
            perturbations = self.projection(perturbations)

        best_s = self.bernoulli_sample(
            perturbations, sample_epochs, disable=disable)

        row, col = torch.where(best_s > 0.)
        for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())):
            if self.adj[u, v] > 0:
                self.remove_edge(u, v, it)
            else:
                self.add_edge(u, v, it)

        return self

    def config_C(self, C=None):
        """Return the step-size constant `C`, using loss-specific defaults
        when it is not given explicitly."""
        if C is not None:
            return C
        if self.CW_loss:
            C = 0.1
        else:
            C = 200
        return C

    def bisection(self, perturbations, a, b, epsilon):
        """Bisection search for the dual variable `miu` such that clipping
        ``perturbations - miu`` to [0, 1] satisfies the edge budget."""
        def func(x):
            clipped_matrix = self.clip(perturbations - x)
            return clipped_matrix.sum() - self.num_budgets

        miu = a
        while (b - a) > epsilon:
            miu = (a + b) / 2
            # Check if middle point is root
            if func(miu) == 0:
                break
            # Decide the side to repeat the steps
            if func(miu) * func(a) < 0:
                b = miu
            else:
                a = miu
        return miu

    def get_perturbed_adj(self, perturbations=None):
        """Return the (dense) perturbed adjacency matrix implied by the
        current (symmetrized, upper-triangular) perturbations."""
        perturbations = self.perturbations if perturbations is None else perturbations
        # Only the upper triangle is free; mirror it to keep the graph
        # undirected and drop the diagonal.
        adj_triu = torch.triu(perturbations, diagonal=1)
        perturbations = adj_triu + adj_triu.t()
        adj = self.complementary * perturbations + self.adj
        return adj

    def projection(self, perturbations):
        """Project `perturbations` onto the feasible set: entries in [0, 1]
        with total mass at most `num_budgets` (via bisection on the dual)."""
        clipped_matrix = self.clip(perturbations)
        num_modified = clipped_matrix.sum()

        if num_modified > self.num_budgets:
            left = (perturbations - 1.).min()
            right = perturbations.max()
            miu = self.bisection(perturbations, left, right, epsilon=1e-5)
            clipped_matrix = self.clip(perturbations - miu)
        else:
            pass

        perturbations.data.copy_(clipped_matrix)
        return perturbations

    def clip(self, matrix):
        """Clamp all entries of `matrix` to the interval [0, 1]."""
        clipped_matrix = torch.clamp(matrix, 0., 1.)
        return clipped_matrix

    @torch.no_grad()
    def bernoulli_sample(self, perturbations, sample_epochs=20, disable=False):
        """Sample discrete {0,1} perturbations from the relaxed solution and
        keep the one with the highest attack loss within budget.

        Returns
        -------
        Tensor
            the best binary perturbation matrix (on CPU)
        """
        best_loss = -1e4
        best_s = None
        # Sample only the upper triangle; `get_perturbed_adj` symmetrizes.
        probs = torch.triu(perturbations, diagonal=1)
        sampler = Bernoulli(probs)

        for it in tqdm(range(sample_epochs),
                       desc='Bernoulli sampling...',
                       disable=disable):
            sampled = sampler.sample()
            if sampled.sum() > self.num_budgets:
                # Reject samples that exceed the budget.
                continue

            perturbations.data.copy_(sampled)
            loss = self.compute_loss(
                perturbations, self.victim_nodes, self.victim_labels)

            if best_loss < loss:
                best_loss = loss
                best_s = sampled

        assert best_s is not None, "Something went wrong"
        return best_s.cpu()

    def compute_loss(self, perturbations, victim_nodes, victim_labels):
        """Attack loss on the victim nodes under the given perturbations:
        CW margin loss if `self.CW_loss`, otherwise cross-entropy."""
        adj = self.get_perturbed_adj(perturbations)
        logit = self.surrogate(self.feat, adj)[victim_nodes] / self.eps

        if self.CW_loss:
            # logit = F.softmax(logit, dim=1)
            one_hot = torch.eye(
                logit.size(-1), device=self.device)[victim_labels]
            range_idx = torch.arange(victim_nodes.size(0), device=self.device)
            # Margin between the true class and the best wrong class;
            # `-1000 * one_hot` masks out the true class in the argmax.
            best_wrong_class = (logit - 1000 * one_hot).argmax(1)
            margin = logit[range_idx, victim_labels] - \
                logit[range_idx, best_wrong_class] + 50
            loss = -torch.clamp(margin, min=0.)
            return loss.mean()
        else:
            # BUGFIX: use the `victim_labels` argument rather than
            # `self.victim_labels`; the parameter was silently ignored.
            # (Backward-compatible: all callers in this file pass
            # `self.victim_labels`.)
            loss = F.cross_entropy(logit, victim_labels)
            return loss

    def compute_gradients(self, perturbations, victim_nodes, victim_labels):
        """Gradient of the attack loss w.r.t. the perturbation matrix."""
        loss = self.compute_loss(perturbations, victim_nodes, victim_labels)
        return grad(loss, perturbations, create_graph=False)[0]
class MinmaxAttack(PGDAttack):
    r"""Implementation of `MinMax` attack from the:
    `"Topology Attack and Defense for Graph Neural Networks:
    An Optimization Perspective" <https://arxiv.org/abs/1906.04214>`_
    paper (IJCAI'19)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`graphwar.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    >>> from graphwar.dataset import GraphWarDataset
    >>> import torch_geometric.transforms as T

    >>> dataset = GraphWarDataset(root='~/data/pygdata', name='cora',
                                  transform=T.LargestConnectedComponents())
    >>> data = dataset[0]

    >>> surrogate_model = ...  # train your surrogate model

    >>> from graphwar.attack.untargeted import MinmaxAttack
    >>> attacker = MinmaxAttack(data)
    >>> attacker.setup_surrogate(surrogate_model)
    >>> attacker.reset()
    >>> attacker.attack(0.05)  # attack with 0.05% of edge perturbations
    >>> attacker.data()  # get attacked graph
    >>> attacker.edge_flips()  # get edge flips after attack
    >>> attacker.added_edges()  # get added edges after attack
    >>> attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * MinMax attack is a variant of
      :class:`graphwar.attack.untargeted.PGDAttack` attack.
    * Please remember to call :meth:`reset` before each attack.
    """

    def setup_surrogate(self, surrogate: torch.nn.Module,
                        labeled_nodes: Tensor,
                        unlabeled_nodes: Optional[Tensor] = None, *,
                        eps: float = 1.0):
        """Set up the surrogate model for the min-max attack.

        Unlike :class:`PGDAttack`, the surrogate is NOT frozen: the
        min-step of the attack re-trains it, so its initial weights are
        cached here and restored on every :meth:`reset`.
        """
        super().setup_surrogate(surrogate=surrogate,
                                labeled_nodes=labeled_nodes,
                                unlabeled_nodes=unlabeled_nodes,
                                eps=eps, freeze=False)
        # Snapshot the surrogate weights so each attack starts fresh.
        self.cached = deepcopy(self.surrogate.state_dict())
        return self

    def reset(self):
        """Reset the attacker state and restore the surrogate to its
        cached (pre-attack) weights."""
        super().reset()
        self.surrogate.load_state_dict(self.cached)
        return self

    def attack(self, num_budgets=0.05, *, C=None, lr=0.001, CW_loss=False,
               epochs=100, sample_epochs=20, structure_attack=True,
               feature_attack=False, disable=False):
        """Run the min-max attack: alternately re-train the surrogate to
        minimize the loss (min-step) and perturb the graph to maximize it
        (max-step).

        Parameters
        ----------
        num_budgets : int or float, optional
            the number (or ratio) of edge perturbations, by default 0.05
        C : Optional[float], optional
            PGD step-size constant; if None chosen by :meth:`config_C`
        lr : float, optional
            Adam learning rate for the surrogate min-step, by default 0.001
        CW_loss : bool, optional
            whether to use the Carlini-Wagner margin loss, by default False
        epochs : int, optional
            number of min-max iterations, by default 100
        sample_epochs : int, optional
            number of Bernoulli sampling rounds, by default 20
        disable : bool, optional
            whether to disable the progress bar, by default False

        Returns
        -------
        MinmaxAttack
            the attacker itself
        """
        # Deliberately skip `PGDAttack.attack` and call the base
        # `UntargetedAttacker.attack` — the PGD loop is replaced below.
        super(PGDAttack, self).attack(num_budgets=num_budgets,
                                      structure_attack=structure_attack,
                                      feature_attack=feature_attack)
        self.CW_loss = CW_loss
        C = self.config_C(C)
        perturbations = self.perturbations
        optimizer = torch.optim.Adam(self.surrogate.parameters(), lr=lr)

        for epoch in tqdm(range(epochs),
                          desc='Min-MAX training...',
                          disable=disable):
            # =========== Min-step: re-train the surrogate ===========
            loss = self.compute_loss(perturbations,
                                     self.victim_nodes,
                                     self.victim_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # ========================================================

            # =========== Max-step: update the perturbations =========
            gradients = self.compute_gradients(perturbations,
                                               self.victim_nodes,
                                               self.victim_labels)
            # BUGFIX: use a distinct name for the decaying PGD step size;
            # previously this overwrote the `lr` parameter (the Adam
            # learning rate), which was confusing though harmless since
            # the optimizer was already constructed.
            step_size = C / math.sqrt(epoch + 1)
            perturbations.data.add_(step_size * gradients)
            perturbations = self.projection(perturbations)
            # ========================================================

        best_s = self.bernoulli_sample(
            perturbations, sample_epochs, disable=disable)

        row, col = torch.where(best_s > 0.)
        for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())):
            if self.adj[u, v] > 0:
                self.remove_edge(u, v, it)
            else:
                self.add_edge(u, v, it)

        return self