Source code for graphwar.attack.targeted.sg_attack

from collections import namedtuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import grad
from tqdm import tqdm

from graphwar.attack.targeted.targeted_attacker import TargetedAttacker
from graphwar.utils import ego_graph
from graphwar.surrogate import Surrogate

SubGraph = namedtuple('SubGraph', ['edge_index', 'sub_edges', 'non_edges',
                                   'edge_weight', 'non_edge_weight', 'selfloop_weight'])
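# Fields: `edge_index` is the full (2, E) index fed to the surrogate;
# `sub_edges`/`edge_weight` are the existing (removable) edges of the ego
# subgraph and their trainable weights; `non_edges`/`non_edge_weight` are
# candidate edges to insert (initialized to weight 0); `selfloop_weight`
# holds fixed self-loop weights. See `subgraph_processing` below.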


class SGAttack(TargetedAttacker, Surrogate):
    r"""Implementation of the `SGA` attack from the paper
    `"Adversarial Attack on Large Scale Graph"
    <https://arxiv.org/abs/2009.03488>`_ (TKDE'21).

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device on which the attack runs, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker; if None, it defaults to
        :obj:`__class__.__name__`
    kwargs : additional arguments of :class:`graphwar.attack.Attacker`

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    >>> from graphwar.dataset import GraphWarDataset
    >>> import torch_geometric.transforms as T

    >>> dataset = GraphWarDataset(root='~/data/pygdata', name='cora',
    ...                           transform=T.LargestConnectedComponents())
    >>> data = dataset[0]

    >>> surrogate_model = ...  # train your surrogate model

    >>> from graphwar.attack.targeted import SGAttack
    >>> attacker = SGAttack(data)
    >>> attacker.setup_surrogate(surrogate_model)

    >>> attacker.reset()
    >>> # attack target node `1` with the default budget (the node's degree)
    >>> attacker.attack(target=1)

    >>> attacker.reset()
    >>> attacker.attack(target=1, num_budgets=1)  # attack with budget 1

    >>> attacker.data()  # get attacked graph
    >>> attacker.edge_flips()  # get edge flips after attack
    >>> attacker.added_edges()  # get added edges after attack
    >>> attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * :class:`SGAttack` is a scalable attack that can be applied to
      large-scale graphs.
    * Please remember to call :meth:`reset` before each attack.
    """
    # SGAttack cannot guarantee that no singleton nodes remain after the attack.
    _allow_singleton = True
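
    # The attack in one sentence: extract the K-hop ego subgraph around the
    # target, extend it with candidate attacker nodes from the class the
    # surrogate most easily confuses the target with, then greedily flip the
    # edge or non-edge with the largest (sign-corrected) gradient of a
    # misclassification loss, one flip per budget unit.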

    @torch.no_grad()
    def setup_surrogate(self, surrogate: torch.nn.Module,
                        eps: float = 5.0,
                        freeze: bool = True,
                        K: int = 2):

        Surrogate.setup_surrogate(self, surrogate=surrogate,
                                  eps=eps, freeze=freeze)
        self.logits = self.surrogate(self.feat, self.edge_index,
                                     self.edge_weight)
        self.K = K
        return self
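
    # `K` should match the receptive field of the surrogate (K=2 for a
    # two-hop SGC), and `eps` is the calibration temperature applied to the
    # logits in `compute_gradients`. The clean logits are cached once here
    # and reused by `strongest_wrong_class` for every target. A rough usage
    # sketch, with `sgc` standing in for your trained surrogate:
    #
    # >>> attacker = SGAttack(data)
    # >>> attacker.setup_surrogate(sgc, eps=5.0, K=2)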

    def set_normalize(self, state):
        for layer in self.surrogate.modules():
            if hasattr(layer, 'normalize'):
                layer.normalize = state
            if hasattr(layer, 'add_self_loops'):
                layer.add_self_loops = state
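
    # The surrogate's conv layers (e.g. SGConv/GCNConv) normally normalize
    # edge weights and add self-loops internally. During the attack this is
    # switched off so that `compute_gradients` can feed its own explicitly
    # normalized, differentiable edge weights; `attack` restores the flags
    # afterwards.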

    def strongest_wrong_class(self, target, target_label):
        logit = self.logits[target].clone()
        logit[target_label] = -1e4
        return logit.argmax()
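
    # With the true class masked out by a large negative constant, `argmax`
    # returns the most-probable *wrong* class. E.g. for cached logits
    # [2.1, 0.3, 1.7] and target_label == 0, the masked logits are
    # [-1e4, 0.3, 1.7], so the strongest wrong class is 2.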

    def get_subgraph(self, target, target_label, best_wrong_label):
        sub_nodes, sub_edges = ego_graph(self.adjacency_matrix,
                                         int(target), self.K)
        if sub_edges.size == 0:
            raise RuntimeError(
                f"The target node {int(target)} is a singleton node.")
        sub_nodes = torch.as_tensor(sub_nodes, dtype=torch.long,
                                    device=self.device)
        sub_edges = torch.as_tensor(sub_edges, dtype=torch.long,
                                    device=self.device)
        attacker_nodes = torch.where(
            self.label == best_wrong_label)[0].cpu().numpy()
        neighbors = self.adjacency_matrix[target].indices

        influencers = [target]
        attacker_nodes = np.setdiff1d(attacker_nodes, neighbors)
        subgraph = self.subgraph_processing(sub_nodes, sub_edges,
                                            influencers, attacker_nodes)

        if self.direct_attack:
            influencers = [target]
            num_attackers = self.num_budgets + 1
        else:
            influencers = neighbors
            num_attackers = 3

        attacker_nodes = self.get_top_attackers(subgraph, target,
                                                target_label,
                                                best_wrong_label,
                                                num_attackers=num_attackers)

        subgraph = self.subgraph_processing(sub_nodes, sub_edges,
                                            influencers, attacker_nodes)
        return subgraph
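
    # Two-pass construction: the first subgraph wires the target to *every*
    # node of the best wrong class outside its neighborhood, so that a
    # single gradient pass can rank them; only the top candidates are kept
    # for the final subgraph. For a direct attack the influencer is the
    # target itself and `num_budgets + 1` attackers are kept; for an
    # indirect (influence) attack the influencers are the target's
    # neighbors and only 3 attacker candidates are kept.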

    def get_top_attackers(self, subgraph, target, target_label,
                          best_wrong_label, num_attackers):
        non_edge_grad, _ = self.compute_gradients(subgraph, target,
                                                  target_label,
                                                  best_wrong_label)
        _, index = torch.topk(non_edge_grad, k=num_attackers, sorted=False)
        attacker_nodes = subgraph.non_edges[1][index]
        return attacker_nodes.tolist()
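
    # `non_edges` pairs every influencer (row 0) with every candidate
    # attacker (row 1), so the column indices of the largest non-edge
    # gradients recover the attacker node ids directly.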

    def subgraph_processing(self, sub_nodes, sub_edges,
                            influencers, attacker_nodes):
        row = np.repeat(influencers, len(attacker_nodes))
        col = np.tile(attacker_nodes, len(influencers))
        # np.vstack (np.row_stack is a deprecated alias, removed in NumPy 2.0)
        non_edges = np.vstack([row, col])

        if not self.direct_attack:  # indirect attack
            mask = self.adjacency_matrix[non_edges[0], non_edges[1]].A1 == 0
            non_edges = non_edges[:, mask]

        non_edges = torch.as_tensor(non_edges, dtype=torch.long,
                                    device=self.device)
        attacker_nodes = torch.as_tensor(attacker_nodes, dtype=torch.long,
                                         device=self.device)
        selfloop = torch.unique(torch.cat([sub_nodes, attacker_nodes]))
        edge_index = torch.cat([non_edges, sub_edges,
                                non_edges.flip(0), sub_edges.flip(0),
                                selfloop.repeat((2, 1))], dim=1)

        edge_weight = torch.ones(sub_edges.size(1),
                                 device=self.device).requires_grad_()
        non_edge_weight = torch.zeros(non_edges.size(1),
                                      device=self.device).requires_grad_()
        selfloop_weight = torch.ones(selfloop.size(0), device=self.device)

        subgraph = SubGraph(edge_index=edge_index,
                            sub_edges=sub_edges,
                            non_edges=non_edges,
                            edge_weight=edge_weight,
                            non_edge_weight=non_edge_weight,
                            selfloop_weight=selfloop_weight)
        return subgraph
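
    # The ordering here matters: `edge_index` stacks [non_edges, sub_edges,
    # reversed non_edges, reversed sub_edges, self-loops], and
    # `compute_gradients` concatenates the weight tensors in exactly the
    # same order. Existing edges start at weight 1 (removal candidates),
    # non-edges at weight 0 (insertion candidates); only these two tensors
    # require gradients.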

    def attack(self, target, *, target_label=None, num_budgets=None,
               direct_attack=True, structure_attack=True,
               feature_attack=False, disable=False):

        super().attack(target, target_label, num_budgets=num_budgets,
                       direct_attack=direct_attack,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)
        self.set_normalize(False)

        if target_label is None:
            assert self.target_label is not None, \
                "please specify the argument `target_label` since the label of the target node is not available."
            target_label = self.target_label.view(-1)
        else:
            target_label = torch.as_tensor(target_label, device=self.device,
                                           dtype=torch.long).view(-1)

        best_wrong_label = self.strongest_wrong_class(
            target, target_label).view(-1)

        subgraph = self.get_subgraph(target, target_label, best_wrong_label)

        if not direct_attack:
            condition1 = subgraph.sub_edges[0] != target
            condition2 = subgraph.sub_edges[1] != target
            mask = torch.logical_and(condition1, condition2).float()

        for it in tqdm(range(self.num_budgets),
                       desc='Perturbing graph...',
                       disable=disable):
            non_edge_grad, edge_grad = self.compute_gradients(
                subgraph, target, target_label, best_wrong_label)

            with torch.no_grad():
                edge_grad *= -2 * subgraph.edge_weight + 1
                if not direct_attack:
                    edge_grad *= mask
                non_edge_grad *= -2 * subgraph.non_edge_weight + 1

            max_edge_grad, max_edge_idx = torch.max(edge_grad, dim=0)
            max_non_edge_grad, max_non_edge_idx = torch.max(non_edge_grad,
                                                            dim=0)

            if max_edge_grad > max_non_edge_grad:
                # remove one edge
                subgraph.edge_weight.data[max_edge_idx].fill_(0.)
                u, v = subgraph.sub_edges[:, max_edge_idx].tolist()
                self.remove_edge(u, v, it)
            else:
                # add one edge
                subgraph.non_edge_weight.data[max_non_edge_idx].fill_(1.)
                u, v = subgraph.non_edges[:, max_non_edge_idx].tolist()
                self.add_edge(u, v, it)

        self.set_normalize(True)
        return self
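
    # Greedy scoring used above: flipping a weight w in {0, 1} to 1 - w
    # changes the loss, to first order, by
    #     L(1 - w) - L(w) ≈ (1 - 2w) * dL/dw,
    # which is exactly `grad * (-2w + 1)`: gradients of existing edges
    # (w = 1) are negated to score removal, while gradients of non-edges
    # (w = 0) are kept as-is to score insertion. For indirect attacks,
    # `mask` zeroes out edges incident to the target so they can never be
    # flipped, e.g.:
    #
    # >>> attacker.reset()
    # >>> attacker.attack(target=1, direct_attack=False)  # influence attack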

    def compute_gradients(self, subgraph, target, target_label,
                          best_wrong_label):
        edge_weight = torch.cat([subgraph.non_edge_weight,
                                 subgraph.edge_weight,
                                 subgraph.non_edge_weight,
                                 subgraph.edge_weight,
                                 subgraph.selfloop_weight], dim=0)
        row, col = subgraph.edge_index
        norm = (self.degree + 1.).pow(-0.5)
        edge_weight = norm[row] * edge_weight * norm[col]

        logit = self.surrogate(self.feat, subgraph.edge_index, edge_weight)
        logit = logit[target].view(1, -1) / self.eps
        logit = F.log_softmax(logit, dim=1)
        loss = F.nll_loss(logit, target_label) - \
            F.nll_loss(logit, best_wrong_label)
        return grad(loss, [subgraph.non_edge_weight, subgraph.edge_weight],
                    create_graph=False)
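
    # The attack loss is a CW-style margin on the calibrated logits:
    #     loss = -log p(y_true | G') + log p(y_wrong | G'),
    # so a positive gradient w.r.t. a weight means pushing it toward 1
    # drives the target toward misclassification. Two approximations worth
    # noting: the logits are divided by `self.eps` (surrogate calibration),
    # and `norm` is built from the *clean* graph degrees (+1 for the
    # self-loop) and is not recomputed after each flip.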