Source code for farabio.core.convnettrainer

import torch
import sys
from farabio.core.basetrainer import BaseTrainer


class ConvnetTrainer(BaseTrainer):
    """Main trainer class for ConvNet related architectures.

    Inherits from ``BaseTrainer`` and drives the train / evaluate / test
    lifecycles as a fixed sequence of hook methods. Most hooks defined here
    are no-ops and are intended to be overridden by subclasses.

    Parameters
    ----------
    config
        Configuration object. It is stored unchanged as ``self.config``;
        this class itself does not read any field from it.
    """

    def __init__(self, config):
        """Initialize the trainer object.

        Stores ``config``, sets default attributes, runs the attribute
        definition hooks, calls the data loader hooks and finally builds the
        model (parallel or not, depending on ``self._data_parallel``).
        """
        super().__init__()
        self.config = config

        # Defaults first, so that the ``define_*_attr`` hooks can override.
        self.default_attr()
        self.init_attr()

        self.get_trainloader()
        self.get_testloader()

        if self._data_parallel:
            self.build_parallel_model()
        else:
            self.build_model()

    ##########################
    # Definition of attributes
    ##########################
    def default_attr(self, *args):
        """Set the default value of every attribute used by the lifecycles.

        Called from ``__init__`` before ``init_attr``, so any value set here
        may be replaced by the ``define_*_attr`` hooks.
        """
        self._num_epochs = 10
        self._mode = 'train'
        self._save_epoch = 1
        self._start_epoch = 1
        self._has_eval = True
        self._eval_interval = 1
        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None
        self.model = None
        self._model_path = None
        # Iterables consumed by train_epoch / evaluate_epoch / test_loop.
        # They are None here and must be set before those loops run.
        self.train_epoch_iter = None
        self.valid_epoch_iter = None
        self.test_loop_iter = None
        self._use_tqdm = False
        self._data_parallel = None
        # When set to True during on_evaluate_batch_start, evaluate_epoch
        # skips the current validation batch.
        self._next_loop = False
        self._epoch = 1
        self._backbone = False

    def init_attr(self, *args):
        """Initialize object attributes by calling every ``define_*_attr`` hook.
        """
        self.define_data_attr()
        self.define_model_attr()
        self.define_train_attr()
        self.define_test_attr()
        self.define_log_attr()
        self.define_compute_attr()
        self.define_misc_attr()

    def define_data_attr(self, *args):
        """Define data related attributes here
        """
        pass

    def define_model_attr(self, *args):
        """Define model related attributes here
        """
        pass

    def define_train_attr(self, *args):
        """Define training related attributes here
        """
        pass

    def define_test_attr(self, *args):
        """Define test related attributes here
        """
        pass

    def define_log_attr(self, *args):
        """Define log related attributes here
        """
        pass

    def define_compute_attr(self, *args):
        """Define compute related attributes here
        """
        pass

    def define_misc_attr(self, *args):
        """Define miscellaneous attributes here
        """
        pass

    ##########################
    # Building model
    ##########################
    def build_model(self, *args):
        """Hook: builds model (no-op here, called from ``__init__``)
        """
        pass

    def build_parallel_model(self, *args):
        """Hook: builds multi-GPU model in parallel (no-op here, called
        from ``__init__`` when ``self._data_parallel`` is truthy)
        """
        pass

    ##########################
    # Dataloaders
    ##########################
    def get_trainloader(self, *args):
        """Hook: Retrieves training set of torch.utils.data.DataLoader class
        """
        pass

    def get_testloader(self, *args):
        """Hook: Retrieves test set of torch.utils.data.DataLoader class
        """
        pass

    ##########################
    # Training related
    ##########################
    def train(self):
        """Training lifecycle with hooks.

        Calls, in order: ``on_train_start``, ``start_logger``,
        ``train_loop`` and ``on_train_end``.
        """
        self.on_train_start()
        self.start_logger()
        self.train_loop()
        self.on_train_end()

    def train_loop(self):
        """Hook: training loop.

        Runs ``train_epoch`` once for every epoch from ``self._start_epoch``
        to ``self._num_epochs`` inclusive.
        """
        # The loop target is the instance attribute itself, so the hooks
        # called from train_epoch can read the current epoch as self._epoch.
        for self._epoch in range(self._start_epoch, self._num_epochs + 1):
            self.train_epoch()

    def train_epoch(self):
        """Hook: epoch of training loop.

        Runs ``train_batch`` on every item of ``self.train_epoch_iter``,
        then, when evaluation is enabled and the current epoch is a multiple
        of ``self._eval_interval``, runs ``evaluate_epoch``.
        """
        self.on_train_epoch_start()

        for train_epoch_var in self.train_epoch_iter:
            self.train_batch(train_epoch_var)

        self.on_train_epoch_end()

        if self._has_eval and self._epoch % self._eval_interval == 0:
            self.evaluate_epoch()

        self.on_epoch_end()

    def train_batch(self, args):
        """Hook: batch of training loop.

        Parameters
        ----------
        args
            One item produced by ``self.train_epoch_iter``. It is passed to
            ``on_start_training_batch`` only; ``training_step`` and
            ``on_end_training_batch`` are called without arguments.
        """
        self.on_start_training_batch(args)
        self.training_step()
        self.on_end_training_batch()

    def on_train_start(self):
        """Hook: On start of training loop (resets ``batch_training_loss``)
        """
        self.batch_training_loss = 0

    def start_logger(self, *args):
        """Hook: Starts logger
        """
        pass

    def on_train_epoch_start(self):
        """Hook: On epoch start.

        Resets ``batch_training_loss`` and calls ``self.model.train()``, so
        ``self.model`` must have been set before the first epoch.
        """
        self.batch_training_loss = 0
        self.model.train()

    def on_start_training_batch(self, *args):
        """Hook: On training batch start
        """
        pass

    def optimizer_zero_grad(self, *args):
        """Hook: Zero gradients of optimizer
        """

    def training_step(self, *args):
        """Hook: During training batch
        """
        pass

    def loss_backward(self, *args):
        """Hook: Loss back-propagation
        """

    def optimizer_step(self):
        """Hook: Optimizer step
        """
        pass

    def on_end_training_batch(self, *args):
        """Hook: On end of training batch
        """
        pass

    def on_train_epoch_end(self, *args):
        """Hook: On end of training epoch
        """
        pass

    def on_train_end(self):
        """Hook: On end of training
        """
        pass

    def on_epoch_end(self, *args):
        """Hook: on epoch end
        """
        pass

    def stop_train(self, *args):
        """Stops training by exiting the interpreter via ``sys.exit()``
        """
        sys.exit()

    ##########################
    # Evaluation loop related
    ##########################
    def evaluate_epoch(self):
        """Hook: epoch of evaluation loop.

        Iterates ``self.valid_epoch_iter`` inside ``torch.no_grad()``. For
        each item, ``on_evaluate_batch_start`` runs first; if it left
        ``self._next_loop`` truthy the flag is reset and the item is skipped,
        otherwise ``evaluate_batch`` and ``on_evaluate_batch_end`` run.
        """
        with torch.no_grad():
            self.on_evaluate_epoch_start()

            for valid_epoch_var in self.valid_epoch_iter:
                self.on_evaluate_batch_start(valid_epoch_var)

                if self._next_loop:
                    # Skip request from on_evaluate_batch_start: clear the
                    # flag so that it only affects this one batch.
                    self._next_loop = False
                    continue

                self.evaluate_batch(valid_epoch_var)
                self.on_evaluate_batch_end()

            self.on_evaluate_epoch_end()

    def evaluate_batch(self, *args):
        """Hook: batch of evaluation loop
        """
        pass

    def on_evaluate_start(self, *args):
        """Hook: on evaluation start
        """
        pass

    def on_evaluate_epoch_start(self):
        """Hook: on evaluation epoch start.

        Unlike the other hooks this one is not a no-op: it must be
        overridden by any subclass whose ``evaluate_epoch`` is run.

        Raises
        ------
        NotImplementedError
            Always, unless overridden.
        """
        raise NotImplementedError

    def on_evaluate_batch_start(self, *args):
        """Hook: On evaluate batch start.

        Receives the current item of ``self.valid_epoch_iter``. Setting
        ``self._next_loop = True`` here makes ``evaluate_epoch`` skip it.
        """
        pass

    def on_evaluate_batch_end(self):
        """Hook: On evaluate batch end
        """
        pass

    def on_evaluate_epoch_end(self, *args):
        """Hook: On evaluation epoch end
        """
        pass

    def on_evaluate_end(self, *args):
        """Hook: on evaluation end
        """
        pass

    ##########################
    # Test loop related
    ##########################
    def test(self):
        """Hook: Test lifecycle.

        Loads the model (parallel or not, depending on
        ``self._data_parallel``), then calls ``on_test_start``,
        ``test_loop`` and ``on_test_end``.
        """
        if self._data_parallel:
            self.load_parallel_model()
        else:
            self.load_model()

        self.on_test_start()
        self.test_loop()
        self.on_test_end()

    def test_loop(self):
        """Hook: test loop.

        For every item of ``self.test_loop_iter`` calls
        ``on_start_test_batch``, ``test_step`` (with the item) and
        ``on_end_test_batch``.
        """
        for test_loop_var in self.test_loop_iter:
            self.on_start_test_batch()
            self.test_step(test_loop_var)
            self.on_end_test_batch()

    def on_test_start(self, *args):
        """Hook: on test start
        """
        pass

    def on_start_test_batch(self, *args):
        """Hook: on test batch start
        """
        pass

    def test_step(self, *args):
        """Test action (Put test here)
        """
        pass

    def on_end_test_batch(self, *args):
        """Hook: on end of batch test
        """
        pass

    def on_test_end(self, *args):
        """Hook: on end test
        """
        pass

    ##########################
    # Handle models
    ##########################
    def load_model(self, *args):
        """Hook: load model
        """
        pass

    def load_parallel_model(self, *args):
        """Hook: load parallel model
        """
        pass

    def save_model(self, *args):
        """Hook: saves model
        """
        pass

    def save_parallel_model(self, *args):
        """Hook: saves parallel model
        """
        pass

    ##########################
    # Miscellaneous
    ##########################
    def exit_trainer(self, *args):
        """Exits trainer by exiting the interpreter via ``sys.exit()``
        """
        sys.exit()