Source code for graphwar.training.simpgcn_trainer

import torch.nn.functional as F
from graphwar.training import Trainer


class SimPGCNTrainer(Trainer):
    """Custom trainer for :class:`graphwar.nn.models.SimPGCN`.

    Parameters
    ----------
    model : nn.Module
        the model used for training
    device : Union[str, torch.device], optional
        the device used for training, by default 'cpu'
    cfg : other keyword arguments, such as `lr` and `weight_decay`.

    Note
    ----
    :class:`graphwar.training.SimPGCNTrainer` accepts the following
    additional arguments:

    * :obj:`lambda_`: trade-off parameter (weight) for the regression
      loss, by default 5.0
    """

    def train_step(self, inputs: dict) -> dict:
        """Run one training step on a single set of inputs.

        Performs a forward pass and computes the total loss. The total loss
        is the cross-entropy on the (optionally masked) logits plus
        ``lambda_`` times ``model.regression_loss`` on the embeddings of all
        nodes. It then calls ``loss.backward()``.

        This method does not zero gradients or step an optimizer. That is
        presumably done by the caller (the base ``Trainer``); verify there.

        Parameters
        ----------
        inputs : dict
            the training data. ``inputs['data']`` is the graph data object,
            which is moved to ``self.device``. ``inputs['mask']`` is optional.
            When present, it is used to index the logits and labels, selecting
            the nodes that contribute to the classification loss and accuracy.

        Returns
        -------
        dict
            the output logs: ``loss`` (total training loss as a float) and
            ``acc`` (accuracy over the selected nodes as a float).
        """
        model = self.model
        self.callbacks.on_train_batch_begin(0)

        model.train()
        data = inputs['data'].to(self.device)
        mask = inputs.get('mask', None)
        # Use a precomputed sparse adjacency if the data object carries one;
        # otherwise fall back to edge_index/edge_weight.
        adj_t = getattr(data, 'adj_t', None)
        y = data.y

        # The model returns both logits and node embeddings. The embeddings
        # feed the regression loss below.
        if adj_t is None:
            out, embeddings = model(data.x, data.edge_index, data.edge_weight)
        else:
            out, embeddings = model(data.x, adj_t)

        # Only logits/labels are masked. The regression loss below still
        # receives the embeddings of every node.
        if mask is not None:
            out = out[mask]
            y = y[mask]

        # ================= add regression loss here ====================
        lambda_ = self.cfg.get("lambda_", 5.0)
        loss = F.cross_entropy(out, y) + lambda_ * \
            model.regression_loss(embeddings)
        # ===============================================================
        loss.backward()
        self.callbacks.on_train_batch_end(0)
        return dict(loss=loss.item(),
                    acc=out.argmax(1).eq(y).float().mean().item())