Source code for graphwar.utils.modification



import torch
from torch import Tensor
import scipy.sparse as sp
from torch_geometric.utils import sort_edge_index, to_scipy_sparse_matrix, from_scipy_sparse_matrix


def add_edges(edge_index: Tensor, edges_to_add: Tensor,
              symmetric: bool = True, sort_edges: bool = True) -> Tensor:
    """Add edges to the graph denoted as :obj:`edge_index`.

    Parameters
    ----------
    edge_index : Tensor
        the graph instance where edges will be added to.
    edges_to_add : torch.Tensor
        shape [2, M], the edges to be added into the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it would add the edges into the graph by:
        :obj:`edges_to_add = torch.cat([edges_to_add, edges_to_add.flip(0)], dim=1)`
    sort_edges : bool
        whether to sort the resulting edges with :obj:`sort_edge_index`,
        by default True. If False, the new edges are simply appended
        after the existing ones.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges added.

    Note
    ----
    Edges are concatenated as-is, so an edge that already exists in
    :obj:`edge_index` will appear more than once in the result.
    """
    if symmetric:
        edges_to_add = torch.cat([edges_to_add, edges_to_add.flip(0)], dim=1)

    # Match dtype and device of the existing graph so `torch.cat` succeeds.
    edges_to_add = edges_to_add.to(edge_index)
    edge_index = torch.cat([edge_index, edges_to_add], dim=1)

    if sort_edges:
        edge_index = sort_edge_index(edge_index)
    return edge_index
def remove_edges(edge_index: Tensor, edges_to_remove: Tensor,
                 symmetric: bool = True) -> Tensor:
    """Remove edges from the graph denoted as :obj:`edge_index`.

    Parameters
    ----------
    edge_index : Tensor
        the graph instance where edges will be removed from.
    edges_to_remove : torch.Tensor
        shape [2, M], the edges to be removed from the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it would remove the edges from the graph by:
        :obj:`edges_to_remove = torch.cat([edges_to_remove, edges_to_remove.flip(0)], dim=1)`

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges removed, as a
        ``torch.long`` tensor on the same device as the input. The graph is
        round-tripped through a SciPy CSR matrix, so duplicate edges are
        merged and the edges come out sorted by (row, column).
        Edges in :obj:`edges_to_remove` that are not in the graph are ignored.
    """
    device = edge_index.device

    if edge_index.size(1) == 0:
        # Nothing to remove from. Handle this up front: when both inputs are
        # empty, SciPy cannot infer the matrix shape from zero-sized index
        # arrays and would raise a ValueError.
        return torch.empty((2, 0), dtype=torch.long, device=device)

    if symmetric:
        edges_to_remove = torch.cat(
            [edges_to_remove, edges_to_remove.flip(0)], dim=1)
    edges_to_remove = edges_to_remove.to(edge_index)

    # It's not intuitive to remove edges from a graph represented as
    # `edge_index`. Instead, build a weighted sparse matrix in which existing
    # edges get weight +1 and edges to remove get -1e5. SciPy sums duplicate
    # coordinates, so a removed edge ends up with a negative total (unless it
    # occurs more than 1e5 times in `edge_index`).
    edge_weight_remove = torch.zeros(edges_to_remove.size(1)) - 1e5
    edge_weight = torch.cat(
        [torch.ones(edge_index.size(1)), edge_weight_remove], dim=0)
    coords = torch.cat([edge_index, edges_to_remove], dim=1).cpu().numpy()
    adj_matrix = sp.csr_matrix(
        (edge_weight.cpu().numpy(), (coords[0], coords[1])))

    # Zero out the negative entries, then drop them from the sparsity pattern.
    adj_matrix.data[adj_matrix.data < 0] = 0.
    adj_matrix.eliminate_zeros()

    edge_index, _ = from_scipy_sparse_matrix(adj_matrix)
    return edge_index.to(device)
def flip_edges(edge_index: Tensor, edges_to_flip: Tensor,
               symmetric: bool = True) -> Tensor:
    """Flip edges from the graph denoted as :obj:`edge_index`.

    Flipping toggles each listed edge. An edge that is present in the graph
    is removed, and an edge that is absent is added.

    Parameters
    ----------
    edge_index : Tensor
        the graph instance where edges will be flipped from.
    edges_to_flip : torch.Tensor
        shape [2, M], the edges to be flipped in the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it would flip the edges from the graph by:
        :obj:`edges_to_flip = torch.cat([edges_to_flip, edges_to_flip.flip(0)], dim=1)`

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges flipped, as a
        ``torch.long`` tensor on the same device as the input. Duplicate
        edges are merged and the edges are sorted by (row, column).

    Note
    ----
    Each distinct edge in :obj:`edges_to_flip` is flipped once, even if it is
    listed several times.
    """
    device = edge_index.device
    edge_index = edge_index.to(torch.long)
    edges_to_flip = edges_to_flip.to(device=device, dtype=torch.long)
    if symmetric:
        edges_to_flip = torch.cat([edges_to_flip, edges_to_flip.flip(0)], dim=1)

    all_edges = torch.cat([edge_index, edges_to_flip], dim=1)
    if all_edges.numel() == 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    # Encode every (row, col) pair as a single integer id so that the set
    # operations below become 1-D tensor operations.
    num_nodes = int(all_edges.max()) + 1
    existing = torch.unique(edge_index[0] * num_nodes + edge_index[1])
    flipped = torch.unique(edges_to_flip[0] * num_nodes + edges_to_flip[1])

    # Both id sets are already unique, so an id occurs exactly once in their
    # concatenation iff it belongs to exactly one of them. That is the
    # symmetric difference (XOR), which is exactly "flip".
    ids, counts = torch.unique(torch.cat([existing, flipped]),
                               return_counts=True)
    kept = ids[counts == 1]

    # Decode ids back to (row, col); `torch.unique` returns them sorted, so
    # the result is ordered by row, then column.
    rows = torch.div(kept, num_nodes, rounding_mode='floor')
    cols = kept - rows * num_nodes
    return torch.stack([rows, cols], dim=0)