gantrainer Module

The GanTrainer class makes use of hooks. Hooks are methods that provide direct access to specific entry points in a loop. This lets us override these methods with custom functionality in the training, evaluation, or test loops.

Lifecycle hooks

Training loop

digraph Trainloop {
   // Training-loop lifecycle of GanTrainer. Each node is a method/hook
   // listed in the Docs section; edges show the order in which they are
   // reached. Labeled back-edges ("_num_epochs", "train_epoch_iter") mark
   // repetition -- presumably once per epoch and once per training batch
   // respectively; confirm against train_loop / train_epoch.

   graph [fontsize=8];
   edge [fontsize=8];
   // colorscheme=set36 is the 6-color ColorBrewer "Set3" palette, so each
   // fillcolor=N below picks color N from that palette.
   node [fontsize=8, shape=box, colorscheme=set36, style=rounded];

   // Top level: train -> setup hooks -> train_loop (which ends in on_train_end).
   train -> on_train_start;
   on_train_start -> start_logger;
   start_logger -> train_loop;
   train_loop -> on_train_end;
   // Epoch level: train_epoch, then per-epoch hooks with evaluation at the end.
   train_loop -> train_epoch;
   train_epoch -> on_train_epoch_start;
   on_epoch_end -> on_train_epoch_start [label = "_num_epochs"];
   on_train_epoch_start -> train_batch;
   train_batch -> on_train_epoch_end;
   on_train_epoch_end -> evaluate_epoch;
   evaluate_epoch -> on_epoch_end;
   // Batch level: one GAN step per batch -- the discriminator is updated
   // first (zero_grad, loss, backward, optim_step), then the generator.
   train_batch -> on_start_training_batch;
   on_start_training_batch -> discriminator_zero_grad;
   discriminator_zero_grad -> discriminator_loss;
   discriminator_loss -> discriminator_backward;
   discriminator_backward -> discriminator_optim_step;
   discriminator_optim_step -> generator_zero_grad;
   generator_zero_grad -> generator_loss;
   generator_loss -> generator_backward;
   generator_backward -> generator_optim_step;
   generator_optim_step -> on_end_training_batch;
   on_end_training_batch -> on_start_training_batch [label = "train_epoch_iter"];

   // Node styling. Fill colors appear to group nodes by nesting level
   // (1: train, 2: top-level hooks, 3: train_epoch, 4: per-epoch hooks,
   // 5: per-batch hooks) -- NOTE(review): inferred from the assignments below.
   // evaluate_epoch is dashed; presumably because it is expanded in the
   // separate Evalloop diagram -- confirm.
   train  [fillcolor=1, style="rounded,filled"]
   on_train_start  [fillcolor=2, style="rounded,filled"]
   start_logger  [fillcolor=2, style="rounded,filled"]
   train_loop  [fillcolor=2, style="rounded,filled"]
   on_train_end  [fillcolor=2, style="rounded,filled"]
   train_epoch  [fillcolor=3, style="rounded,filled"]
   on_train_epoch_start  [fillcolor=4, style="rounded,filled"]
   train_batch  [fillcolor=4, style="rounded,filled"]
   on_train_epoch_end  [fillcolor=4, style="rounded,filled"]
   evaluate_epoch  [fillcolor=4, style="rounded,filled,dashed"]
   on_epoch_end  [fillcolor=4, style="rounded,filled"]
   on_start_training_batch  [fillcolor=5, style="rounded,filled"]
   discriminator_zero_grad  [fillcolor=5, style="rounded,filled"]
   discriminator_loss  [fillcolor=5, style="rounded,filled"]
   discriminator_backward  [fillcolor=5, style="rounded,filled"]
   discriminator_optim_step  [fillcolor=5, style="rounded,filled"]
   generator_zero_grad  [fillcolor=5, style="rounded,filled"]
   generator_loss  [fillcolor=5, style="rounded,filled"]
   generator_backward  [fillcolor=5, style="rounded,filled"]
   generator_optim_step  [fillcolor=5, style="rounded,filled"]
   on_end_training_batch  [fillcolor=5, style="rounded,filled"]
   }

Evaluation Loop

digraph Evalloop {
   // Evaluation-loop lifecycle: expansion of the evaluate_epoch node that
   // appears (dashed) in the training-loop diagram. The labeled back-edge
   // "valid_epoch_iter" marks repetition, presumably once per validation
   // batch -- confirm against evaluate_epoch.

   graph [fontsize=8];
   edge [fontsize=8];
   // Same 6-color ColorBrewer "Set3" palette as the training diagram, so
   // evaluate_epoch keeps its color (fillcolor=4) across both pictures.
   node [fontsize=8, shape=box, colorscheme=set36, style=rounded];

   // Epoch-level hooks wrap a per-batch cycle:
   // batch_start -> evaluate_batch -> batch_end.
   evaluate_epoch -> on_evaluate_epoch_start
   on_evaluate_epoch_start -> on_evaluate_batch_start;
   on_evaluate_batch_start -> evaluate_batch;
   evaluate_batch -> on_evaluate_batch_end;
   on_evaluate_batch_end -> on_evaluate_epoch_end;
   on_evaluate_batch_end -> on_evaluate_batch_start [label = "valid_epoch_iter"];

   // Node styling: 4 = evaluate_epoch (dashed, matching the training diagram),
   // 5 = epoch-level evaluation hooks, 6 = per-batch evaluation steps.
   evaluate_epoch  [fillcolor=4, style="rounded,filled,dashed"]
   on_evaluate_epoch_start  [fillcolor=5, style="rounded,filled"]
   on_evaluate_epoch_end  [fillcolor=5, style="rounded,filled"]
   on_evaluate_batch_start  [fillcolor=6, style="rounded,filled"]
   evaluate_batch  [fillcolor=6, style="rounded,filled"]
   on_evaluate_batch_end  [fillcolor=6, style="rounded,filled"]

   }

Test loop

digraph Testloop {
   // Test-loop lifecycle of GanTrainer. test reaches on_test_start through
   // either load_model or load_parallel_model (two alternative branches;
   // presumably chosen by configuration -- confirm in GanTrainer.test).
   // The labeled back-edge "test_loop_iter" marks the repeated per-batch step.

   graph [fontsize=8];
   edge [fontsize=8];
   // colorscheme=set33 is the 3-color ColorBrewer "Set3" palette; only
   // fillcolor values 1-3 are used below.
   node [fontsize=8, shape=box, colorscheme=set33 , style=rounded];

   test -> load_model;
   test -> load_parallel_model;
   load_model -> on_test_start;
   load_parallel_model -> on_test_start;
   on_test_start -> test_loop;
   test_loop -> on_test_end;
   // Per-batch cycle: on_start_test_batch -> test_step -> on_end_test_batch.
   test_loop -> on_start_test_batch;
   on_start_test_batch -> test_step;
   test_step -> on_end_test_batch;
   on_end_test_batch -> on_start_test_batch [label = "test_loop_iter"];

   // Node styling: 1 = test entry point, 2 = loading and top-level hooks,
   // 3 = per-batch test steps.
   test  [fillcolor=1, style="rounded,filled"]
   load_model  [fillcolor=2, style="rounded,filled"]
   load_parallel_model  [fillcolor=2, style="rounded,filled"]
   on_test_start  [fillcolor=2, style="rounded,filled"]
   test_loop  [fillcolor=2, style="rounded,filled"]
   on_test_end  [fillcolor=2, style="rounded,filled"]
   on_start_test_batch  [fillcolor=3, style="rounded,filled"]
   test_step  [fillcolor=3, style="rounded,filled"]
   on_end_test_batch  [fillcolor=3, style="rounded,filled"]

   }

Docs

class farabio.core.gantrainer.GanTrainer(config)[source]
__init__(config)[source]

Initializes trainer object

default_attr(*args)[source]
init_attr(*args)[source]

Abstract method that initializes object attributes

define_data_attr(*args)[source]

Define data related attributes here

define_model_attr(*args)[source]

Define model related attributes here

define_train_attr(*args)[source]

Define training related attributes here

define_test_attr(*args)[source]

Define test related attributes here

define_log_attr(*args)[source]

Define log related attributes here

define_compute_attr(*args)[source]

Define compute related attributes here

define_misc_attr(*args)[source]

Define miscellaneous attributes here

build_model(*args)[source]

Abstract method that builds model

get_trainloader(*args)[source]

Hook: Retrieves the training set as a torch.utils.data.DataLoader

get_testloader(*args)[source]

Hook: Retrieves the test set as a torch.utils.data.DataLoader

train()[source]

Training loop with hooks

train_loop()[source]

Hook: training loop

train_epoch()[source]

Hook: epoch of training loop

train_batch(args)[source]

Hook: batch of training loop

on_train_start()[source]

Hook: On start of training loop

start_logger(*args)[source]

Hook: Starts logger

on_train_epoch_start()[source]

Hook: On epoch start

on_start_training_batch(*args)[source]

Hook: On training batch start

on_end_training_batch(*args)[source]

Hook: On end of training batch

on_train_epoch_end(*args)[source]

Hook: On end of training epoch

on_train_end()[source]

Hook: On end of training

stop_train(*args)[source]

On end of training

evaluate_epoch()[source]

Hook: epoch of evaluation loop

Parameters
epoch (int)

Current epoch

evaluate_batch(*args)[source]

Hook: batch of evaluation loop

on_evaluate_start(*args)[source]

Hook: on evaluation start

on_evaluate_epoch_start()[source]

Hook: on evaluation epoch start

on_evaluate_batch_start()[source]
on_evaluate_batch_end()[source]

Hook: On evaluate batch end

on_evaluate_epoch_end(*args)[source]
on_evaluate_end(*args)[source]

Hook: on evaluation end

on_epoch_end(*args)[source]

Hook: on epoch end

test()[source]

Hook: Test lifecycle

test_loop()[source]

Hook: test loop

get_dataloader()[source]

Hook: Retrieves torch.utils.data.DataLoader object

on_test_start(*args)[source]

Hook: on test start

on_start_test_batch(*args)[source]

Hook: on test batch start

test_step(*args)[source]

Test action (Put test here)

on_end_test_batch(*args)[source]

Hook: on end of test batch

on_test_end(*args)[source]

Hook: on test end

load_model(*args)[source]

Hook: load model

save_model(*args)[source]

Hook: saves model

discriminator_zero_grad()[source]

Hook: Zero gradients of discriminator

discriminator_loss(*args)[source]

Hook: Training action (Put training here)

discriminator_backward()[source]

Hook: Discriminator back-propagation

discriminator_optim_step()[source]

Hook: Discriminator optimizer step

generator_zero_grad()[source]

Hook: Zero gradients of generator

generator_loss(*args)[source]

Hook: Training action (Put training here)

generator_backward()[source]

Hook: Generator back-propagation

generator_optim_step()[source]

Hook: Generator optimizer step