# Source code for farabio.models.classification.class_trainer

import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from farabio.core.convnettrainer import ConvnetTrainer
from farabio.utils.helpers import makedirs, parallel_state_dict
from farabio.utils.losses import Losses
from farabio.utils.loggers import Logger, savefig, progress_bar
from farabio.models.classification.conv import *


class ClassTrainer(ConvnetTrainer):
    """Classification trainer class.

    Override with custom methods here.

    Parameters
    ----------
    ConvnetTrainer : BaseTrainer
        Inherits ConvnetTrainer class.
    """
[docs] def define_data_attr(self, *args): self._title = self.config.title + self.config.arch self._train_batch_size = self.config.batch_size_train self._test_batch_size = self.config.batch_size_test self._classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
[docs] def define_model_attr(self, *args): self._arch = self.config.arch self._cardinality = self.config.cardinality self._num_classes = self.config.num_classes self._depth = self.config.depth self._widen_factor = self.config.widen_factor self._drop_rate = self.config.dropout self._growth_rate = self.config.growth_rate self._compression_rate = self.config.compression_rate self._block_name = self.config.block_name
[docs] def define_train_attr(self): self._lr = self.config.learning_rate self._momentum = self.config.momentum self._weight_decay = self.config.weight_decay self._schedule = self.config.schedule self._gamma = self.config.gamma self._num_epochs = self._num_epochs self._resume = self.config.resume
[docs] def define_compute_attr(self, *args): self._cuda = self.config.cuda self._device = self.config.device self._num_workers = self.config.num_workers self._data_parallel = self.config.data_parallel
[docs] def define_log_attr(self): self.best_accuracy = 0 self._checkpoint = self.config.checkpoint
[docs] def get_trainloader(self): transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) trainset = torchvision.datasets.CIFAR10( root='/home/data/02_SSD4TB/suzy/datasets/public', train=True, download=False, transform=transform_train) self.train_loader = DataLoader( trainset, batch_size=self._train_batch_size, shuffle=True, num_workers=self._num_workers)
[docs] def get_testloader(self): transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) testset = torchvision.datasets.CIFAR10( root='/home/data/02_SSD4TB/suzy/datasets/public', train=False, download=False, transform=transform_test) self.test_loader = DataLoader( testset, batch_size=self._test_batch_size, shuffle=False, num_workers=self._num_workers)
[docs] def build_model(self): print("==> creating model '{}'".format(self._arch)) class_models = { "densenet": DenseNet121(), "dpn92": DPN92(), "efficientnet": EfficientNetB0(), "googlenet": GoogLeNet(), "mobilenet": MobileNet(), "mobilenet2": MobileNetV2(), "preactresnet": PreActResNet18(), "regnet": RegNetX_200MF(), "resnet": ResNet18(), "resnext": ResNeXt29_2x64d(), "senet": SENet18(), "shufflenet2": ShuffleNetV2(1), "simpledla": SimpleDLA(), "vgg": VGG('VGG19') } self.model = class_models[self._arch] print(self._arch) if self._cuda: self.model.to(self._device) cudnn.benchmark = True self.criterion = nn.CrossEntropyLoss() self.optimizer = torch.optim.SGD(self.model.parameters( ), lr=self._lr, momentum=self._momentum, weight_decay=self._weight_decay) self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=200)
[docs] def build_parallel_model(self): print("==> creating parallel model '{}'".format(self._arch)) class_models = { "densenet": DenseNet121(), "dpn92": DPN92(), "efficientnet": EfficientNetB0(), "googlenet": GoogLeNet(), "mobilenet": MobileNet(), "mobilenet2": MobileNetV2(), "preactresnet": PreActResNet18(), "regnet": RegNetX_200MF(), "resnet": ResNet18(), "resnext": ResNeXt29_2x64d(), "senet": SENet18(), "shufflenet2": ShuffleNetV2(1), "simpledla": SimpleDLA(), "vgg": VGG('VGG19') } self.model = class_models[self._arch] self.model = torch.nn.DataParallel(self.model).cuda() cudnn.benchmark = True self.criterion = nn.CrossEntropyLoss() self.optimizer = torch.optim.SGD(self.model.parameters( ), lr=self._lr, momentum=self._momentum, weight_decay=self._weight_decay)
[docs] def on_train_epoch_start(self): print(f'\nEpoch: {self._epoch}') self.model.train() self.train_loss = 0 self.correct = 0 self.total = 0 self.train_epoch_iter = enumerate(self.train_loader)
[docs] def on_start_training_batch(self, args): self.batch_idx = args[0] self.inputs = args[-1][0] self.targets = args[-1][-1]
[docs] def training_step(self): if self._cuda: self.inputs = self.inputs.to(self._device) self.targets = self.targets.to(self._device) self.optimizer_zero_grad() self.outputs = self.model(self.inputs) self.loss = self.criterion(self.outputs, self.targets) self.loss_backward() self.optimizer_step()
[docs] def on_end_training_batch(self): self.train_loss += self.loss.item() _, predicted = self.outputs.max(1) self.total += self.targets.size(0) self.correct += predicted.eq(self.targets).sum().item() progress_bar(self.batch_idx, len(self.train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (self.train_loss/(self.batch_idx+1), 100.*self.correct/self.total, self.correct, self.total))
[docs] def optimizer_zero_grad(self): self.optimizer.zero_grad()
[docs] def optimizer_step(self): self.optimizer.step()
[docs] def loss_backward(self): self.loss.backward()
[docs] def on_evaluate_epoch_start(self): self.model.eval() self.test_loss = 0 self.correct = 0 self.total = 0 self.valid_epoch_iter = enumerate(self.test_loader)
[docs] def on_evaluate_batch_start(self, args): self.batch_idx = args[0] self.inputs = args[-1][0] self.targets = args[-1][-1]
[docs] def evaluate_batch(self, args): if self._cuda: self.inputs = self.inputs.to(self._device) self.targets = self.targets.to(self._device) # async? # compute output self.outputs = self.model(self.inputs) self.loss = self.criterion(self.outputs, self.targets)
[docs] def on_evaluate_batch_end(self): self.test_loss += self.loss.item() _, predicted = self.outputs.max(1) self.total += self.targets.size(0) self.correct += predicted.eq(self.targets).sum().item() progress_bar(self.batch_idx, len(self.test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (self.test_loss/(self.batch_idx+1), 100.*self.correct/self.total, self.correct, self.total))
[docs] def on_evaluate_epoch_end(self): # Save checkpoint. acc = 100.*self.correct/self.total if acc > self.best_accuracy: print('Saving..') state = { 'net': self.model.state_dict(), 'acc': acc, 'epoch': self._epoch, } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, './checkpoint/ckpt.pth') self.best_acc = acc
[docs] def on_epoch_end(self): self.scheduler.step()