Source code for farabio.core.gantrainer

import torch
import sys
from farabio.core.basetrainer import BaseTrainer


class GanTrainer(BaseTrainer):
    """Skeleton trainer for GAN-style models built from overridable hooks.

    The class wires together a train / evaluate / test life cycle in which
    every training batch performs a discriminator update followed by a
    generator update. Almost every method is a no-op hook meant to be
    overridden; ``discriminator_loss``, ``generator_loss`` and
    ``on_evaluate_epoch_start`` raise ``NotImplementedError`` until a
    subclass provides them.
    """

    def __init__(self, config):
        """Initialize the trainer.

        Stores ``config``, sets default attributes, runs the attribute
        definition hooks, then calls the data-loader hooks and
        ``build_model`` (in that order).

        Parameters
        ----------
        config : object
            Configuration object; stored as ``self.config`` and not
            otherwise inspected by this class.
        """
        super().__init__()
        self.config = config
        self.default_attr()
        self.init_attr()
        self.get_trainloader()
        self.get_testloader()
        self.build_model()

    ##########################
    # Definition of attributes
    ##########################

    def default_attr(self, *args):
        """Set default values for the attributes used by the life cycle."""
        self._mode = "train"
        self._num_epochs = 10
        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None
        self.model = None
        self.model_load_dir = None
        self.optimizerD = None
        self.optimizerG = None
        self._model_path = None
        # Iterables consumed by train_epoch / evaluate_epoch / test_loop;
        # subclasses are expected to assign them before those loops run.
        self.train_epoch_iter = None
        self.valid_epoch_iter = None
        self.test_loop_iter = None
        self._save_epoch = 1
        self._start_epoch = 1
        # When True, train_epoch runs evaluate_epoch after every epoch.
        self._has_eval = True

    def init_attr(self, *args):
        """Run every ``define_*_attr`` hook in a fixed order."""
        self.define_data_attr()
        self.define_model_attr()
        self.define_train_attr()
        self.define_test_attr()
        self.define_log_attr()
        self.define_compute_attr()
        self.define_misc_attr()

    def define_data_attr(self, *args):
        """Hook: define data related attributes here."""
        pass

    def define_model_attr(self, *args):
        """Hook: define model related attributes here."""
        pass

    def define_train_attr(self, *args):
        """Hook: define training related attributes here."""
        pass

    def define_test_attr(self, *args):
        """Hook: define test related attributes here."""
        pass

    def define_log_attr(self, *args):
        """Hook: define log related attributes here."""
        pass

    def define_compute_attr(self, *args):
        """Hook: define compute related attributes here."""
        pass

    def define_misc_attr(self, *args):
        """Hook: define miscellaneous attributes here."""
        pass

    ##########################
    # Building model
    ##########################

    def build_model(self, *args):
        """Hook: build the model(s). Called at the end of ``__init__``."""
        pass

    ##########################
    # Dataloaders
    ##########################

    def get_trainloader(self, *args):
        """Hook: retrieve the training set as a torch.utils.data.DataLoader."""
        pass

    def get_testloader(self, *args):
        """Hook: retrieve the test set as a torch.utils.data.DataLoader."""
        pass

    ##########################
    # Training related
    ##########################

    def train(self):
        """Run the training life cycle.

        Calls ``on_train_start``, ``start_logger``, ``train_loop`` and
        ``on_train_end`` in that order.
        """
        self.on_train_start()
        self.start_logger()
        self.train_loop()
        self.on_train_end()

    def train_loop(self):
        """Hook: training loop over epochs.

        Iterates from ``self._start_epoch`` to ``self._num_epochs``
        inclusive. The current epoch number is stored on the instance as
        ``self.epoch`` so that other hooks can read it.
        """
        for self.epoch in range(self._start_epoch, self._num_epochs + 1):
            self.train_epoch()

    def train_epoch(self):
        """Hook: one epoch of the training loop.

        Feeds every item of ``self.train_epoch_iter`` to ``train_batch``,
        optionally evaluates (when ``self._has_eval`` is truthy) and
        finally calls ``on_epoch_end``.
        """
        self.on_train_epoch_start()

        for train_epoch_var in self.train_epoch_iter:
            self.train_batch(train_epoch_var)

        self.on_train_epoch_end()

        if self._has_eval:
            self.evaluate_epoch()

        self.on_epoch_end()

    def train_batch(self, args):
        """Hook: one batch of the training loop.

        Parameters
        ----------
        args : object
            The item yielded by ``self.train_epoch_iter``. It is forwarded
            only to ``on_start_training_batch``; the discriminator and
            generator hooks are called without arguments.
        """
        self.on_start_training_batch(args)

        # ---- Discriminator update ----
        self.discriminator_zero_grad()
        self.discriminator_loss()
        self.discriminator_backward()
        self.discriminator_optim_step()

        # ---- Generator update ----
        self.generator_zero_grad()
        self.generator_loss()
        self.generator_backward()
        self.generator_optim_step()

        self.on_end_training_batch()

    def on_train_start(self):
        """Hook: on start of training; resets ``batch_training_loss``."""
        self.batch_training_loss = 0

    def start_logger(self, *args):
        """Hook: start the logger."""
        pass

    def on_train_epoch_start(self):
        """Hook: on training epoch start.

        Resets ``batch_training_loss`` and, when ``self.model`` is set,
        switches it to training mode.
        """
        self.batch_training_loss = 0
        # default_attr leaves self.model as None; a subclass that only
        # keeps separate generator/discriminator objects must not crash
        # here with an AttributeError.
        if self.model is not None:
            self.model.train()

    def on_start_training_batch(self, *args):
        """Hook: on training batch start."""
        pass

    def on_end_training_batch(self, *args):
        """Hook: on end of training batch."""
        pass

    def on_train_epoch_end(self, *args):
        """Hook: on end of training epoch."""
        pass

    def on_train_end(self):
        """Hook: on end of training."""
        pass

    def stop_train(self, *args):
        """Stop training by terminating the interpreter.

        Raises
        ------
        SystemExit
            Always, via ``sys.exit()``.
        """
        sys.exit()

    ##########################
    # Evaluation loop related
    ##########################

    def evaluate_epoch(self):
        """Hook: one epoch of the evaluation loop.

        Runs entirely under ``torch.no_grad()``: calls
        ``on_evaluate_epoch_start``, then for every item of
        ``self.valid_epoch_iter`` the batch start / ``evaluate_batch`` /
        batch end hooks, and finally ``on_evaluate_epoch_end``.
        """
        with torch.no_grad():
            self.on_evaluate_epoch_start()

            for valid_epoch_var in self.valid_epoch_iter:
                self.on_evaluate_batch_start()
                self.evaluate_batch(valid_epoch_var)
                self.on_evaluate_batch_end()

            self.on_evaluate_epoch_end()

    def evaluate_batch(self, *args):
        """Hook: one batch of the evaluation loop."""
        pass

    def on_evaluate_start(self, *args):
        """Hook: on evaluation start."""
        pass

    def on_evaluate_epoch_start(self):
        """Hook: on evaluation epoch start.

        Must be overridden whenever evaluation is enabled
        (``self._has_eval`` is truthy, which is the default).

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def on_evaluate_batch_start(self):
        """Hook: on evaluation batch start."""
        pass

    def on_evaluate_batch_end(self):
        """Hook: on evaluation batch end."""
        pass

    def on_evaluate_epoch_end(self, *args):
        """Hook: on evaluation epoch end."""
        pass

    def on_evaluate_end(self, *args):
        """Hook: on evaluation end."""
        pass

    def on_epoch_end(self, *args):
        """Hook: on epoch end (after training and optional evaluation)."""
        pass

    ##########################
    # Test loop related
    ##########################

    def test(self):
        """Run the test life cycle.

        Calls ``load_model`` (without arguments), ``on_test_start``,
        ``test_loop`` and ``on_test_end`` in that order.
        """
        self.load_model()
        self.on_test_start()
        self.test_loop()
        self.on_test_end()

    def test_loop(self):
        """Hook: test loop over every item of ``self.test_loop_iter``."""
        for test_loop_var in self.test_loop_iter:
            self.on_start_test_batch()
            self.test_step(test_loop_var)
            self.on_end_test_batch()

    def get_dataloader(self):
        """Hook: retrieve a torch.utils.data.DataLoader object."""
        pass

    def on_test_start(self, *args):
        """Hook: on test start."""
        pass

    def on_start_test_batch(self, *args):
        """Hook: on test batch start."""
        pass

    def test_step(self, *args):
        """Hook: test action for a single batch (put test code here)."""
        pass

    def on_end_test_batch(self, *args):
        """Hook: on end of test batch."""
        pass

    def on_test_end(self, *args):
        """Hook: on end of test."""
        pass

    ##########################
    # Handle models
    ##########################

    def load_model(self, *args):
        """Hook: load model."""
        pass

    def save_model(self, *args):
        """Hook: save model."""
        pass

    ##########################
    # GAN specific
    ##########################

    def discriminator_zero_grad(self):
        """Hook: zero the gradients of the discriminator."""
        pass

    def discriminator_loss(self, *args):
        """Hook: discriminator training action (compute its loss here).

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def discriminator_backward(self):
        """Hook: discriminator back-propagation."""
        pass

    def discriminator_optim_step(self):
        """Hook: discriminator optimizer step."""
        pass

    def generator_zero_grad(self):
        """Hook: zero the gradients of the generator."""
        pass

    def generator_loss(self, *args):
        """Hook: generator training action (compute its loss here).

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def generator_backward(self):
        """Hook: generator back-propagation."""
        pass

    def generator_optim_step(self):
        """Hook: generator optimizer step."""
        pass