Source code for graphwar.nn.layers.elastic_conv

from typing import Optional

import torch
from torch import nn
from torch import Tensor
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import SparseTensor, mul

from graphwar.functional import spmm


def get_inc(edge_index: Adj, num_nodes: Optional[int] = None) -> SparseTensor:
    """Compute the incident matrix
    """
    device = edge_index.device
    if torch.is_tensor(edge_index):
        row_index, col_index = edge_index
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
    else:
        row_index = edge_index.storage.row()
        col_index = edge_index.storage.col()
        num_nodes = edge_index.sizes()[1]

    mask = row_index > col_index  # remove duplicate edges and self-loops

    row_index = row_index[mask]
    col_index = col_index[mask]
    num_edges = row_index.numel()

    row = torch.cat([torch.arange(num_edges, device=device),
                     torch.arange(num_edges, device=device)])
    col = torch.cat([row_index, col_index])
    value = torch.cat([torch.ones(num_edges, device=device),
                       -torch.ones(num_edges, device=device)])
    inc_mat = SparseTensor(row=row, rowptr=None, col=col, value=value,
                           sparse_sizes=(num_edges, num_nodes))
    return inc_mat
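
# Illustrative sketch (not part of the library): for an undirected path graph
# 0-1-2 stored with both edge directions, `get_inc` keeps one direction per
# edge (row > col) and builds a [num_edges, num_nodes] incidence matrix with
# a +1/-1 pair per edge:
#
#   edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
#   get_inc(edge_index).to_dense()
#   # tensor([[-1.,  1.,  0.],
#   #         [ 0., -1.,  1.]])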


def inc_norm(inc: SparseTensor, edge_index: Adj,
             num_nodes: Optional[int] = None) -> SparseTensor:
    """Normalize the incident matrix
    """

    if torch.is_tensor(edge_index):
        deg = degree(edge_index[0], num_nodes=num_nodes,
                     dtype=torch.float).clamp(min=1)
    else:
        deg = edge_index.sum(1).clamp(min=1)

    deg_inv_sqrt = deg.pow(-0.5)
    inc = mul(inc, deg_inv_sqrt.view(1, -1))  # scale column j by deg(j)^{-1/2}
    return inc
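
# Illustrative sketch (not part of the library): `inc_norm` scales column j of
# the incidence matrix by deg(j)^{-1/2}, matching the symmetric normalization
# used by `gcn_norm`. Continuing the toy example above (node degrees [1, 2, 1]):
#
#   edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
#   inc = get_inc(edge_index, num_nodes=3)
#   inc_norm(inc, edge_index, num_nodes=3).to_dense()
#   # tensor([[-1.0000,  0.7071,  0.0000],
#   #         [ 0.0000, -0.7071,  1.0000]])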


class ElasticConv(nn.Module):
    r"""The ElasticGNN operator from the `"Elastic Graph Neural Networks"
    <https://arxiv.org/abs/2107.06996>`_ paper (ICML'21)

    Parameters
    ----------
    K : int, optional
        the number of propagation steps, by default 3
    lambda_amp : float, optional
        trade-off of adaptive message passing, by default 0.1
    normalize : bool, optional
        whether to add self-loops and compute symmetric normalization
        coefficients on the fly, by default True
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    lambda1 : float, optional
        trade-off hyperparameter, by default 3
    lambda2 : float, optional
        trade-off hyperparameter, by default 3
    L21 : bool, optional
        whether to use row-wise projection onto the l2 ball of radius λ1,
        by default True
    cached : bool, optional
        whether to cache the incidence matrix, by default True

    Note
    ----
    As in :class:`torch_geometric`, for the inputs :obj:`x`, :obj:`edge_index`,
    and :obj:`edge_weight`, our implementation supports:

    * :obj:`edge_index` is :class:`torch.LongTensor`: edge indices with shape :obj:`[2, M]`
    * :obj:`edge_index` is :class:`torch_sparse.SparseTensor`: sparse matrix with sparse shape :obj:`[N, N]`

    See also
    --------
    :class:`graphwar.nn.models.ElasticGNN`
    """

    _cached_inc: Optional[SparseTensor] = None  # cached incidence matrix

    def __init__(self, K: int = 3, lambda_amp: float = 0.1,
                 normalize: bool = True, add_self_loops: bool = True,
                 lambda1: float = 3., lambda2: float = 3.,
                 L21: bool = True, cached: bool = True):
        super().__init__()
        self.K = K
        self.lambda_amp = lambda_amp
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.L21 = L21
        self.cached = cached

    def reset_parameters(self):
        self.cache_clear()

    def cache_clear(self):
        """Clear cached inputs or intermediate results."""
        self._cached_inc = None
        return self

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:

        if self.normalize:
            if torch.is_tensor(edge_index):
                # NOTE: dense adjacency matrices are not supported here
                edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                                   x.size(0), improved=False,
                                                   add_self_loops=self.add_self_loops,
                                                   dtype=x.dtype)
            else:
                edge_index = gcn_norm(edge_index, x.size(0), improved=False,
                                      add_self_loops=self.add_self_loops,
                                      dtype=x.dtype)

        cache = self._cached_inc
        if cache is None:
            # compute the (unnormalized) incidence matrix
            inc_mat = get_inc(edge_index, num_nodes=x.size(0))
            # normalize the incidence matrix
            inc_mat = inc_norm(inc_mat, edge_index, num_nodes=x.size(0))

            if self.cached:
                self._cached_inc = inc_mat
            # initialize z here so that `emp_forward` can use it even
            # when caching is disabled
            self.init_z = x.new_zeros((inc_mat.sizes()[0], x.size()[-1]))
        else:
            inc_mat = self._cached_inc

        return self.emp_forward(x, inc_mat, edge_index, edge_weight)

    def emp_forward(self, x: Tensor, inc_mat: SparseTensor, edge_index: Adj,
                    edge_weight: OptTensor = None) -> Tensor:
        lambda1 = self.lambda1
        lambda2 = self.lambda2

        gamma = 1 / (1 + lambda2)
        beta = 1 / (2 * gamma)

        hh = x

        if lambda1:
            z = self.init_z

        for k in range(self.K):

            if lambda2:
                if torch.is_tensor(edge_index):
                    out = spmm(x, edge_index, edge_weight)
                else:
                    out = edge_index @ x
                y = gamma * hh + (1 - gamma) * out
            else:
                y = gamma * hh + (1 - gamma) * x  # y = x - gamma * (x - hh)

            if lambda1:
                x_bar = y - gamma * (inc_mat.t() @ z)
                z_bar = z + beta * (inc_mat @ x_bar)

                if self.L21:
                    z = self.L21_projection(z_bar, lambda_=lambda1)
                else:
                    z = self.L1_projection(z_bar, lambda_=lambda1)

                x = y - gamma * (inc_mat.t() @ z)
            else:
                x = y  # z = 0

        return x
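
    # A reading of the loop above, following the Elastic Message Passing (EMP)
    # derivation in the ElasticGNN paper: with Delta the normalized incidence
    # matrix (`inc_mat`), gamma = 1 / (1 + lambda2) and beta = 1 / (2 * gamma),
    # each of the K steps computes
    #   y     = gamma * x^(0) + (1 - gamma) * A_hat @ x    (l2 smoothing step)
    #   x_bar = y - gamma * Delta^T @ z
    #   z_bar = z + beta * Delta @ x_bar
    #   z     = proj(z_bar)    (l_inf ball for L1, row-wise l2 ball for L21)
    #   x     = y - gamma * Delta^T @ z                    (sparsity-promoting step)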

    def L1_projection(self, x: Tensor, lambda_: float) -> Tensor:
        """Component-wise projection onto the l∞ ball of radius λ1."""
        return torch.clamp(x, min=-lambda_, max=lambda_)

    def L21_projection(self, x: Tensor, lambda_: float) -> Tensor:
        """Row-wise projection onto the l2 ball of radius λ1."""
        row_norm = torch.norm(x, p=2, dim=1)
        scale = torch.clamp(row_norm, max=lambda_)
        index = row_norm > 0
        scale[index] = scale[index] / row_norm[index]  # avoid division by zero
        return scale.unsqueeze(1) * x
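
    # Illustrative sketch (not part of the library): `L21_projection` rescales
    # each row z_i to z_i * min(||z_i||_2, lambda_) / ||z_i||_2, e.g. with
    # lambda_ = 3:
    #
    #   z = torch.tensor([[3., 4.], [0.3, 0.4]])   # row norms [5.0, 0.5]
    #   # ElasticConv().L21_projection(z, 3.) -> tensor([[1.8000, 2.4000],
    #   #                                                [0.3000, 0.4000]])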

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(lambda_amp={self.lambda_amp}, '
                f'K={self.K}, lambda1={self.lambda1}, lambda2={self.lambda2})')
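

if __name__ == '__main__':
    # Minimal usage sketch (assumes a tiny toy graph; not part of the library).
    conv = ElasticConv(K=3, lambda1=3., lambda2=3.)
    x = torch.randn(4, 16)                           # 4 nodes, 16 features
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])  # undirected path 0-1-2-3
    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([4, 16])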