import sys
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from graphwar.training.callbacks import (Callback, CallbackList, Optimizer,
                                         Scheduler)
from graphwar.utils import BunchDict, Progbar
class Trainer:
    """A simple trainer to train graph neural network models conveniently.

    Parameters
    ----------
    model : nn.Module
        the model used for training
    device : Union[str, torch.device], optional
        the device used for training, by default 'cpu'
    **cfg : other keyword arguments, such as `lr` and `weight_decay`

    Example
    -------
    >>> from graphwar.training import Trainer
    >>> model = ...  # your model
    >>> trainer = Trainer(model, device='cuda')
    >>> data  # PyG-like data, e.g., Cora
    Data(x=[2485, 1433], edge_index=[2, 10138], y=[2485])
    >>> # simple training
    >>> trainer.fit({'data': data, 'mask': train_mask})
    >>> # train with model picking
    >>> from graphwar.training import ModelCheckpoint
    >>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
    >>> trainer.fit({'data': data, 'mask': train_mask},
    ...             {'data': data, 'mask': val_mask}, callbacks=[cb])
    >>> # get training logs
    >>> history = trainer.model.history
    >>> trainer.evaluate({'data': data, 'mask': data.test_mask})  # evaluation
    >>> predict = trainer.predict({'data': data, 'mask': your_mask})  # prediction
    """
def __init__(self, model: nn.Module, device: Union[str, torch.device] = 'cpu', **cfg):
self.device = torch.device(device)
self.model = model.to(self.device)
cfg.setdefault("lr", 1e-2)
cfg.setdefault("weight_decay", 5e-4)
self.cfg = BunchDict(cfg)
self.optimizer = self.config_optimizer()
self.scheduler = self.config_scheduler(self.optimizer)
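        # The ``lr``/``weight_decay`` defaults set above are consumed by
        # ``config_optimizer`` below; a minimal usage sketch (the
        # hyperparameter values here are illustrative only):
        #
        #     trainer = Trainer(model, device='cuda', lr=5e-3, weight_decay=1e-4)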
    def fit(self, train_inputs, val_inputs: Optional[dict] = None,
            callbacks: Optional[List[Callback]] = None,
            verbose: Optional[int] = 1, epochs: int = 100) -> "Trainer":
        """Simple training method designed for :attr:`model`.

        Parameters
        ----------
        train_inputs : dict-like or custom inputs
            the training data, passed to `train_step`.
        val_inputs : Optional[dict]
            the validation data, by default None
        callbacks : Optional[List[Callback]], optional
            callbacks used during training,
            see `graphwar.training.callbacks`, by default None
        verbose : Optional[int], optional
            verbosity during training, one of
            :obj:`None, 1, 2, 3, 4`, by default 1
        epochs : int, optional
            number of training epochs, by default 100

        Example
        -------
        >>> # simple training
        >>> trainer.fit({'data': data, 'mask': train_mask})
        >>> # train with model picking
        >>> from graphwar.training import ModelCheckpoint
        >>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
        >>> trainer.fit({'data': data, 'mask': train_mask},
        ...             {'data': data, 'mask': val_mask}, callbacks=[cb])
        """
model = self.model.to(self.device)
model.stop_training = False
validation = val_inputs is not None
# Setup callbacks
self.callbacks = callbacks = self.config_callbacks(
verbose, epochs, callbacks=callbacks)
logs = BunchDict()
if verbose:
print("Training...")
callbacks.on_train_begin()
try:
for epoch in range(epochs):
callbacks.on_epoch_begin(epoch)
train_logs = self.train_step(train_inputs)
logs.update(train_logs)
if validation:
val_logs = self.test_step(val_inputs)
val_logs = {f'val_{k}': v for k, v in val_logs.items()}
logs.update(val_logs)
callbacks.on_epoch_end(epoch, logs)
if model.stop_training:
print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
break
finally:
callbacks.on_train_end()
return self
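    # Any callback can request early termination by setting
    # ``model.stop_training = True``, which the loop above checks each epoch.
    # A minimal custom-callback sketch, assuming the Keras-style ``set_model``
    # wiring done in ``config_callbacks`` (``StopAtLoss`` and its threshold
    # are illustrative, not part of graphwar):
    #
    #     class StopAtLoss(Callback):
    #         def __init__(self, threshold: float = 0.1):
    #             super().__init__()
    #             self.threshold = threshold
    #
    #         def on_epoch_end(self, epoch, logs=None):
    #             if logs and logs.get('loss', float('inf')) < self.threshold:
    #                 self.model.stop_training = True
    #
    #     trainer.fit({'data': data, 'mask': train_mask},
    #                 callbacks=[StopAtLoss(0.05)])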
    def train_step(self, inputs: dict) -> dict:
        """One-step training on the inputs.

        Parameters
        ----------
        inputs : dict-like or custom inputs
            the training data.

        Returns
        -------
        dict
            the output logs, including `loss`, `acc`, etc.
        """
model = self.model
self.callbacks.on_train_batch_begin(0)
model.train()
data = inputs['data'].to(self.device)
mask = inputs.get('mask', None)
adj_t = getattr(data, 'adj_t', None)
y = data.y.squeeze()
if adj_t is None:
out = model(data.x, data.edge_index, data.edge_weight)
else:
out = model(data.x, adj_t)
if mask is not None:
out = out[mask]
y = y[mask]
loss = F.cross_entropy(out, y)
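        # Note: ``zero_grad()``/``step()`` are intentionally absent here; the
        # ``Optimizer`` callback appended in ``config_callbacks`` is expected
        # to drive the parameter update via the ``on_train_batch_*`` hooks.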
loss.backward()
self.callbacks.on_train_batch_end(0)
return dict(loss=loss.item(), acc=out.argmax(1).eq(y).float().mean().item())
    def evaluate(self, inputs: dict, verbose: Optional[int] = 1) -> BunchDict:
        """Simple evaluation step for :attr:`model`.

        Parameters
        ----------
        inputs : dict-like or custom inputs
            the test data, passed to `test_step`.
        verbose : Optional[int], optional
            verbosity during evaluation, by default 1

        Returns
        -------
        BunchDict
            the dict-like output logs

        Example
        -------
        >>> trainer.evaluate({'data': data, 'mask': data.test_mask})  # evaluation
        """
        if verbose:
            print("Evaluating...")
        self.model = self.model.to(self.device)
        progbar = Progbar(target=1, verbose=verbose)
        logs = BunchDict(**self.test_step(inputs))
        progbar.update(1, logs)
        return logs
    @torch.no_grad()
    def test_step(self, inputs: dict) -> dict:
        """One-step evaluation on the inputs.

        Parameters
        ----------
        inputs : dict-like or custom inputs
            the test data.

        Returns
        -------
        dict
            the output logs, including `loss`, `acc`, etc.
        """
model = self.model
model.eval()
data = inputs['data'].to(self.device)
mask = inputs.get('mask', None)
adj_t = getattr(data, 'adj_t', None)
y = data.y.squeeze()
if adj_t is None:
out = model(data.x, data.edge_index, data.edge_weight)
else:
out = model(data.x, adj_t)
if mask is not None:
out = out[mask]
y = y[mask]
loss = F.cross_entropy(out, y)
return dict(loss=loss.item(), acc=out.argmax(1).eq(y).float().mean().item())
    @torch.no_grad()
    def predict_step(self, inputs: dict) -> Tensor:
        """One-step prediction on the inputs.

        Parameters
        ----------
        inputs : dict-like or custom inputs
            the prediction data.

        Returns
        -------
        Tensor
            the output predictions.
        """
        model = self.model
        model.eval()
        data = inputs['data'].to(self.device)
        mask = inputs.get('mask', None)
adj_t = getattr(data, 'adj_t', None)
if adj_t is None:
out = model(data.x, data.edge_index, data.edge_weight)
else:
out = model(data.x, adj_t)
if mask is not None:
out = out[mask]
return out
    def predict(self, inputs: dict,
                transform: Callable = torch.nn.Softmax(dim=-1)) -> Tensor:
        """Simple prediction step for :attr:`model`.

        Parameters
        ----------
        inputs : dict-like or custom inputs
            the prediction data, passed to `predict_step`.
        transform : Callable
            a callable applied to the output predictions; defaults to
            :class:`torch.nn.Softmax` over the last dimension. Pass
            ``transform=None`` to obtain the raw logits.

        Returns
        -------
        Tensor
            the output predictions

        Example
        -------
        >>> predict = trainer.predict({'data': data, 'mask': mask})  # 'mask' is optional
        """
self.model.to(self.device)
out = self.predict_step(inputs).squeeze()
if transform is not None:
out = transform(out)
return out
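    # A minimal post-processing sketch (``data`` and ``test_mask`` as in the
    # class docstring above):
    #
    #     probs = trainer.predict({'data': data, 'mask': data.test_mask})
    #     labels = probs.argmax(dim=-1)  # hard class predictions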
    def config_optimizer(self) -> torch.optim.Optimizer:
        """Configure the optimizer, using `lr` and `weight_decay` from :attr:`cfg`."""
        lr = self.cfg.get('lr', 0.01)
        weight_decay = self.cfg.get('weight_decay', 5e-4)
        return torch.optim.Adam(self.model.parameters(), lr=lr,
                                weight_decay=weight_decay)
    def config_scheduler(self, optimizer: torch.optim.Optimizer):
        """Configure the learning rate scheduler; returns ``None`` by default."""
        return None
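    # A subclass can enable learning-rate scheduling by overriding
    # ``config_scheduler``; a minimal sketch (the scheduler choice and its
    # hyperparameters are illustrative, not part of graphwar):
    #
    #     class ScheduledTrainer(Trainer):
    #         def config_scheduler(self, optimizer):
    #             return torch.optim.lr_scheduler.StepLR(
    #                 optimizer, step_size=50, gamma=0.5)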
    def config_callbacks(self, verbose, epochs, callbacks=None) -> CallbackList:
        callbacks = CallbackList(callbacks=callbacks, add_history=True,
                                 add_progbar=bool(verbose))
        if self.optimizer is not None:
            callbacks.append(Optimizer(self.optimizer))
        if self.scheduler is not None:
            callbacks.append(Scheduler(self.scheduler))
        callbacks.set_model(self.model)
        callbacks.set_params(dict(verbose=verbose, epochs=epochs))
        return callbacks
@property
def model(self):
return self._model
@model.setter
def model(self, m):
assert m is None or isinstance(m, torch.nn.Module)
self._model = m
    def cache_clear(self) -> "Trainer":
        """Clear cached inputs or intermediate results of the model."""
        if hasattr(self.model, 'cache_clear'):
            self.model.cache_clear()
        return self
    def extra_repr(self) -> str:
        # Extra information appended to ``__repr__`` below.
        return f", device={self.device}"

    def __repr__(self) -> str:
        model = self.model
        return f"{self.__class__.__name__}(model={model.__class__.__name__}{self.extra_repr()})"

    __str__ = __repr__