Source code for graphwar.nn.layers.adaptive_conv

import torch
from torch import nn
from torch import Tensor
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor

from graphwar import is_edge_index
from graphwar.functional import spmm
from graphwar.nn.layers.gcn_conv import dense_gcn_norm


[docs]class AdaptiveConv(nn.Module): r"""The AirGNN operator from the `"Graph Neural Networks with Adaptive Residual" <https://openreview.net/forum?id=hfkER_KJiNw>`_ paper (NeurIPS'21) Parameters ---------- K : int, optional the number of propagation steps during message passing, by default 3 lambda_amp : float, optional trade-off for adaptive message passing, by default 0.1 normalize : bool, optional Whether to add self-loops and compute symmetric normalization coefficients on the fly, by default True add_self_loops : bool, optional whether to add self-loops to the input graph, by default True Note ---- Different from that in :class:`torch_geometric`, for the inputs :obj:`x`, :obj:`edge_index`, and :obj:`edge_weight`, our implementation supports: * :obj:`edge_index` is :class:`torch.FloatTensor`: dense adjacency matrix with shape :obj:`[N, N]` * :obj:`edge_index` is :class:`torch.LongTensor`: edge indices with shape :obj:`[2, M]` * :obj:`edge_index` is :class:`torch_sparse.SparseTensor`: sparse matrix with sparse shape :obj:`[N, N]` See also -------- :class:`graphwar.nn.models.AirGNN` """ def __init__(self, K: int = 3, lambda_amp: float = 0.1, normalize: bool = True, add_self_loops: bool = True): super().__init__() self.K = K self.lambda_amp = lambda_amp self.add_self_loops = add_self_loops self.normalize = normalize
[docs] def reset_parameters(self): pass
[docs] def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: is_edge_like = is_edge_index(edge_index) if self.normalize: if is_edge_like: edge_index, edge_weight = gcn_norm(edge_index, edge_weight, x.size(0), improved=False, add_self_loops=self.add_self_loops, dtype=x.dtype) elif isinstance(edge_index, SparseTensor): edge_index = gcn_norm(edge_index, x.size(0), improved=False, add_self_loops=self.add_self_loops, dtype=x.dtype) else: # N by N dense adjacency matrix edge_index = dense_gcn_norm(edge_index, improved=False, add_self_loops=self.add_self_loops) return self.amp_forward(x, edge_index, edge_weight)
[docs] def amp_forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: lambda_amp = self.lambda_amp gamma = 1 / (2 * (1 - lambda_amp)) # or simply gamma = 1 hh: Tensor = x for k in range(self.K): y = x - gamma * 2 * \ (1 - lambda_amp) * self.compute_LX(x, edge_index, edge_weight) # Equation (9) # Equation (11) and (12) x = hh + self.proximal_L21(x=y - hh, lambda_=gamma * lambda_amp) return x
[docs] def proximal_L21(self, x: Tensor, lambda_: float) -> Tensor: row_norm = torch.norm(x, p=2, dim=1) score = torch.clamp(row_norm - lambda_, min=0) # Deal with the case when the row_norm is 0 index = torch.where(row_norm > 0) # score is the adaptive score in Equation (14) score[index] = score[index] / row_norm[index] return score.unsqueeze(1) * x
[docs] def compute_LX(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: is_edge_like = is_edge_index(edge_index) if is_edge_like: out = spmm(x, edge_index, edge_weight) else: out = edge_index @ x return x - out
def __repr__(self) -> str: return (f'{self.__class__.__name__}(lambda_amp={self.lambda_amp}, K={self.K})')