pydgn.training
training.callback
training.event
training.engine
- class pydgn.training.engine.LinkPredictionSingleGraphEngine(engine_callback: Callable[[...], pydgn.training.callback.engine_callback.EngineCallback], model: pydgn.model.interface.ModelInterface, loss: pydgn.training.callback.metric.Metric, optimizer: pydgn.training.callback.optimizer.Optimizer, scorer: pydgn.training.callback.metric.Metric, scheduler: Optional[pydgn.training.callback.scheduler.Scheduler] = None, early_stopper: Optional[pydgn.training.callback.early_stopping.EarlyStopper] = None, gradient_clipper: Optional[pydgn.training.callback.gradient_clipping.GradientClipper] = None, device: str = 'cpu', plotter: Optional[pydgn.training.callback.plotter.Plotter] = None, exp_path: Optional[str] = None, evaluate_every: int = 1, store_last_checkpoint: bool = False)
Bases:
pydgn.training.engine.TrainingEngine
Specific engine for link prediction tasks. Here, we expect target values in the form of tuples (_, pos_edges, neg_edges), where pos_edges and neg_edges have been generated by the splitter and provided by the data provider.
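As an illustration of this target format, the sketch below builds such a tuple using plain Python lists in place of the edge-index tensors the splitter would produce (all values are hypothetical, for illustration only):

```python
# Illustrative sketch of the (_, pos_edges, neg_edges) target format
# expected by LinkPredictionSingleGraphEngine. In practice pos_edges and
# neg_edges would be tensors generated by the splitter; plain lists are
# used here only to show the structure.

def split_link_prediction_target(target):
    """Unpack a link-prediction target tuple, ignoring the first element."""
    _, pos_edges, neg_edges = target
    return pos_edges, neg_edges

# Edges in [sources, destinations] layout (hypothetical values).
pos_edges = [[0, 1, 2], [1, 2, 0]]  # edges known to exist
neg_edges = [[0, 2], [2, 1]]        # sampled non-existing edges

target = (None, pos_edges, neg_edges)
pos, neg = split_link_prediction_target(target)
```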
- class pydgn.training.engine.TrainingEngine(engine_callback: Callable[[...], pydgn.training.callback.engine_callback.EngineCallback], model: pydgn.model.interface.ModelInterface, loss: pydgn.training.callback.metric.Metric, optimizer: pydgn.training.callback.optimizer.Optimizer, scorer: pydgn.training.callback.metric.Metric, scheduler: Optional[pydgn.training.callback.scheduler.Scheduler] = None, early_stopper: Optional[pydgn.training.callback.early_stopping.EarlyStopper] = None, gradient_clipper: Optional[pydgn.training.callback.gradient_clipping.GradientClipper] = None, device: str = 'cpu', plotter: Optional[pydgn.training.callback.plotter.Plotter] = None, exp_path: Optional[str] = None, evaluate_every: int = 1, store_last_checkpoint: bool = False)
Bases:
pydgn.training.event.dispatcher.EventDispatcher
This is the most important class when it comes to training a model. It implements the EventDispatcher interface, which means that after registering some callbacks in a given order, it will proceed to trigger specific events that will result in the shared State object being updated by the callbacks. Callbacks implement the EventHandler interface, and they receive the shared State object when any event is triggered. Knowing the order in which callbacks are called is important. The order is:
- loss function
- score function
- gradient clipper
- optimizer
- early stopper
- scheduler
- plotter
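The dispatch pattern described above can be sketched with a minimal, self-contained mimic. This is not pydgn's actual implementation; class and method names below are illustrative:

```python
# Minimal sketch of the EventDispatcher pattern: handlers registered in a
# fixed order all receive the same shared state object on each event.

class State:
    """Shared mutable state passed to every callback."""
    def __init__(self):
        self.log = []

class EventHandler:
    name = "handler"
    def on_epoch_end(self, state):
        state.log.append(self.name)

class LossHandler(EventHandler):
    name = "loss"

class ScoreHandler(EventHandler):
    name = "score"

class PlotterHandler(EventHandler):
    name = "plotter"

class Dispatcher:
    def __init__(self):
        self.callbacks = []
    def register(self, cb):
        self.callbacks.append(cb)
    def dispatch(self, event, state):
        # Callbacks run in registration order, mirroring the ordering
        # guarantee documented for TrainingEngine.
        for cb in self.callbacks:
            getattr(cb, event)(state)

dispatcher = Dispatcher()
for cb in (LossHandler(), ScoreHandler(), PlotterHandler()):
    dispatcher.register(cb)

state = State()
dispatcher.dispatch("on_epoch_end", state)
# state.log is now ["loss", "score", "plotter"]
```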
- Parameters
engine_callback (Callable[..., EngineCallback]) – the engine callback object to be used for data fetching and checkpoints (or even other purposes if necessary)
model (ModelInterface) – the model to be trained
loss (Metric) – the loss to be used
optimizer (Optimizer) – the optimizer to be used
scorer (Metric) – the score to be used
scheduler (Scheduler) – the scheduler to be used. Default is None.
early_stopper (EarlyStopper) – the early stopper to be used. Default is None.
gradient_clipper (GradientClipper) – the gradient clipper to be used. Default is None.
device (str) – the device on which to train. Default is cpu.
plotter (Plotter) – the plotter to be used. Default is None.
exp_path (str) – the path of the experiment folder. Default is None, but it is always instantiated.
evaluate_every (int) – the frequency of logging epoch results. Default is 1.
store_last_checkpoint (bool) – whether to store a checkpoint at the end of each epoch, which allows training to resume from the last epoch. Default is False.
- infer(loader: torch_geometric.loader.dataloader.DataLoader, set: str) Tuple[dict, dict, List[torch_geometric.data.data.Data]]
Performs an evaluation step on the data.
- Parameters
loader (torch_geometric.loader.DataLoader) – the loader to be used
set (str) – the type of dataset being used; can be TRAINING, VALIDATION or TEST (as defined in pydgn.static)
- Returns
a tuple (loss dict, score dict, list of torch_geometric.data.Data objects with x and y attributes only). The data list can be used, for instance, in semi-supervised experiments or in incremental architectures.
- set_device()
Moves the model and the loss metric to the proper device.
- set_eval_mode()
Sets the model and the internal state in EVALUATION mode.
- set_training_mode()
Sets the model and the internal state in TRAINING mode.
- train(train_loader: torch_geometric.loader.dataloader.DataLoader, validation_loader: Optional[torch_geometric.loader.dataloader.DataLoader] = None, test_loader: Optional[torch_geometric.loader.dataloader.DataLoader] = None, max_epochs: int = 100, zero_epoch: bool = False, logger: Optional[pydgn.log.logger.Logger] = None) Tuple[dict, dict, List[torch_geometric.data.data.Data], dict, dict, List[torch_geometric.data.data.Data], dict, dict, List[torch_geometric.data.data.Data]]
Trains the model and regularly evaluates on validation and test data (if given). May perform early stopping and checkpointing.
- Parameters
train_loader (torch_geometric.loader.DataLoader) – the DataLoader associated with training data
validation_loader (torch_geometric.loader.DataLoader) – the DataLoader associated with validation data, if any
test_loader (torch_geometric.loader.DataLoader) – the DataLoader associated with test data, if any
max_epochs (int) – maximum number of training epochs. Default is 100
zero_epoch (bool) – if True, starts again from epoch 0 and resets optimizer and scheduler states. Default is False
logger – the logger
- Returns
a tuple (train_loss, train_score, train_embeddings, validation_loss, validation_score, validation_embeddings, test_loss, test_score, test_embeddings)
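Since train returns a flat 9-tuple, callers typically unpack it positionally. The sketch below shows the unpacking against a stub with the same return shape (the stub is illustrative, not pydgn code):

```python
# Illustrative stub matching train()'s 9-tuple return shape:
# (loss, score, embeddings) for each of train / validation / test.

def fake_train():
    # Each third of the tuple is (loss dict, score dict, data list).
    empty = ({"loss": 0.0}, {"score": 0.0}, [])
    return empty + empty + empty

(train_loss, train_score, train_embeddings,
 val_loss, val_score, val_embeddings,
 test_loss, test_score, test_embeddings) = fake_train()
```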
- pydgn.training.engine.log(msg, logger: pydgn.log.logger.Logger)
training.profiler
- class pydgn.training.profiler.Profiler(threshold: float)
Bases:
object
A decorator class that is applied to an EventHandler object implementing a set of callback functions. For each callback, the Profiler stores the average and total running time across epochs. When the experiment terminates (either correctly or abruptly), the Profiler can produce a report to be stored in the experiment's log file. The Profiler is used as a singleton, and it produces wrappers that update its own state.
- Parameters
threshold (float) – used to filter out callback functions that consume a negligible amount of time from the report
- Usage:
Instantiate a profiler, then register an event_handler with the syntax profiler(event_handler), which returns another object implementing the EventHandler interface.
- report() str
Builds a report string containing the statistics of the experiment accumulated so far.
- Returns
a string containing the report
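The wrapping behaviour can be sketched as follows. This minimal mimic times each callback call and filters negligible entries from the report, but it is not pydgn's actual Profiler; all names are illustrative:

```python
import time

class MiniProfiler:
    """Sketch of a profiler decorator: wraps an object's callbacks and
    accumulates per-callback running time in the profiler's own state."""

    def __init__(self, threshold: float):
        self.threshold = threshold
        self.totals = {}   # callback name -> total seconds
        self.calls = {}    # callback name -> number of calls

    def __call__(self, handler):
        profiler = self

        class Wrapped:
            def __getattr__(self, name):
                fn = getattr(handler, name)
                def timed(*args, **kwargs):
                    start = time.perf_counter()
                    result = fn(*args, **kwargs)
                    elapsed = time.perf_counter() - start
                    profiler.totals[name] = profiler.totals.get(name, 0.0) + elapsed
                    profiler.calls[name] = profiler.calls.get(name, 0) + 1
                    return result
                return timed

        return Wrapped()

    def report(self) -> str:
        lines = []
        for name, total in self.totals.items():
            if total < self.threshold:
                continue  # drop negligible callbacks, as the threshold parameter does
            avg = total / self.calls[name]
            lines.append(f"{name}: total={total:.4f}s avg={avg:.4f}s")
        return "\n".join(lines)

class Handler:
    def on_epoch_end(self):
        return "done"

profiler = MiniProfiler(threshold=0.0)
wrapped = profiler(Handler())
wrapped.on_epoch_end()
```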
training.util
- pydgn.training.util.atomic_save(data: dict, filepath: str)
Atomically stores a dictionary that can be serialized by torch.save(), exploiting the atomic os.replace().
- Parameters
data (dict) – the dictionary to be stored
filepath (str) – the absolute filepath where to store the dictionary
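The write-then-rename pattern behind atomic_save can be sketched in a self-contained way. The version below uses pickle instead of torch.save to stay dependency-free, but the os.replace() step is the same idea; the function name is illustrative, not pydgn's API:

```python
import os
import pickle
import tempfile

def atomic_save_sketch(data: dict, filepath: str) -> None:
    """Write to a temporary file in the target directory, then atomically
    move it into place with os.replace(), so readers never observe a
    partially written file. (Sketch: pickle stands in for torch.save.)"""
    directory = os.path.dirname(filepath) or "."
    # Create the temp file on the same filesystem so os.replace is atomic.
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_path, filepath)  # atomic rename into place
    except BaseException:
        os.remove(tmp_path)  # clean up the temp file on failure
        raise

checkpoint = {"epoch": 3, "best_score": 0.91}
path = os.path.join(tempfile.gettempdir(), "ckpt_demo.pt")
atomic_save_sketch(checkpoint, path)
with open(path, "rb") as f:
    restored = pickle.load(f)
```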