from typing import Union
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, OptPairTensor
from torch_sparse import SparseTensor
from graphwar.functional import spmm
from graphwar import is_edge_index
from graphwar.nn.layers.gcn_conv import dense_gcn_norm
[docs]class RobustConv(nn.Module):
r"""The robust graph convolutional operator
from the `"Robust Graph Convolutional Networks
Against Adversarial Attacks"
<http://pengcui.thumedialab.com/papers/RGCN.pdf>`_ paper (KDD'19)
Parameters
----------
in_channels : int
dimensions of int samples
out_channels : int
dimensions of output samples
gamma : float, optional
the scale of attention on the variances, by default 1.0
add_self_loops : bool, optional
whether to add self-loops to the input graph, by default True
bias : bool, optional
whether to use bias in the layers, by default True
Note
----
Different from that in :class:`torch_geometric`,
For the inputs :obj:`x`, :obj:`edge_index`, and :obj:`edge_weight`,
our implementation supports:
* :obj:`edge_index` is :class:`torch.FloatTensor`: dense adjacency matrix with shape :obj:`[N, N]`
* :obj:`edge_index` is :class:`torch.LongTensor`: edge indices with shape :obj:`[2, M]`
* :obj:`edge_index` is :class:`torch_sparse.SparseTensor`: sparse matrix with sparse shape :obj:`[N, N]`
See also
--------
:class:`graphwar.nn.models.RobustGCN`
"""
def __init__(self, in_channels: int, out_channels: int,
gamma: float = 1.0,
add_self_loops: bool = True,
bias: bool = True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_self_loops = add_self_loops
self.gamma = gamma
self.lin_mean = Linear(in_channels, out_channels, bias=False,
weight_initializer='glorot')
self.lin_var = Linear(in_channels, out_channels, bias=False,
weight_initializer='glorot')
if bias:
self.bias_mean = nn.Parameter(torch.Tensor(out_channels))
self.bias_var = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias_mean', None)
self.register_parameter('bias_var', None)
self.reset_parameters()
[docs] def reset_parameters(self):
self.lin_mean.reset_parameters()
self.lin_var.reset_parameters()
zeros(self.bias_mean)
zeros(self.bias_var)
[docs] def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
if isinstance(x, Tensor):
x_mean = x_var = x
else:
x_mean, x_var = x
mean = self.lin_mean(x_mean)
var = self.lin_var(x_var)
if self.bias_mean is not None:
mean = mean + self.bias_mean
var = var + self.bias_var
mean = F.relu(mean)
var = F.relu(var)
is_edge_like = is_edge_index(edge_index)
if is_edge_like:
edge_index, edge_weight = gcn_norm(edge_index, edge_weight, mean.size(0),
improved=False,
add_self_loops=self.add_self_loops,
dtype=mean.dtype)
elif isinstance(edge_index, SparseTensor):
adj = gcn_norm(edge_index, mean.size(0),
improved=False,
add_self_loops=self.add_self_loops, dtype=mean.dtype)
else:
# N by N dense adjacency matrix
adj = dense_gcn_norm(edge_index, improved=False,
add_self_loops=self.add_self_loops)
attention = torch.exp(-self.gamma * var)
mean = mean * attention
var = var * attention * attention
# TODO: actually, using .square() is not always right,
# particularly weighted graph
if is_edge_like:
mean = spmm(mean, edge_index, edge_weight)
var = spmm(var, edge_index, edge_weight.square())
else:
mean = adj @ mean
var = (adj * adj) @ var
return mean, var
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels})')