# Source code for farabio.models.classification.transformer_trainer
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from tqdm.std import tqdm
from farabio.core.convnettrainer import ConvnetTrainer
from farabio.models.classification.vit.linformer import Linformer
from farabio.models.classification.vit.efficient import ViT
from farabio.utils.losses import Losses
from farabio.utils.loggers import Logger
from farabio.data.biodatasets import RANZCRDataset
class TransformerTrainer(ConvnetTrainer):
    """Classification trainer class. Override with custom methods here.

    Trains a ``ViT`` model backed by a ``Linformer`` transformer. The hooks
    below configure data loading, build the model / loss / optimizer /
    scheduler, run one training or evaluation step per batch and print an
    end-of-epoch summary.

    Several attributes are read here but never assigned in this class
    (``self.config``, ``self._seed``, ``self._title``, ``self._cuda``,
    ``self._device``, ``self._epoch``, ``self._data``, ``self._label``) and
    several methods are called but not defined here
    (``optimizer_zero_grad``, ``loss_backward``, ``optimizer_step``).
    They are presumably provided by ``ConvnetTrainer`` -- TODO confirm.

    Parameters
    -----------
    ConvnetTrainer : BaseTrainer
        Inherits ConvnetTrainer class
    """

    def define_data_attr(self, *args):
        """Set the dataset root, batch size and dataset name.

        Batch size and dataset name are read from ``self.config``.
        """
        # NOTE(review): machine-specific absolute path; consider moving it
        # into the configuration object.
        self._root = "/home/data/02_SSD4TB/suzy/datasets/public"
        self._batch_size = self.config.batch_size
        self._dataset = self.config.dataset

    def define_train_attr(self):
        """Read the learning rate and scheduler decay from ``self.config``."""
        self._lr = self.config.learning_rate
        self._gamma = self.config.gamma

    def seed_everything(self):
        """Seed the Python, NumPy and PyTorch random generators.

        All generators are seeded with ``self._seed`` and cuDNN is switched
        to deterministic mode.
        """
        random.seed(self._seed)
        # The interpreter reads PYTHONHASHSEED (not "PYTHONSEED"). It is only
        # consulted at interpreter start-up, so this affects child processes
        # rather than the current one.
        os.environ["PYTHONHASHSEED"] = str(self._seed)
        np.random.seed(self._seed)
        # Seed the CPU generator as well; the CUDA-only calls below do not
        # cover CPU-side randomness such as weight initialisation.
        torch.manual_seed(self._seed)
        torch.cuda.manual_seed(self._seed)
        torch.cuda.manual_seed_all(self._seed)
        torch.backends.cudnn.deterministic = True

    def _build_dataset(self, train):
        """Instantiate the configured dataset for the requested split.

        Parameters
        ----------
        train : bool
            ``True`` for the training split, ``False`` for the validation
            split.

        Raises
        ------
        ValueError
            If ``self._dataset`` names a dataset this trainer does not
            support.
        """
        if self._dataset == 'RANZCRDataset':
            return RANZCRDataset(
                root=self._root, train=train, transform=None, download=False)
        # Fail with a clear message instead of an unrelated NameError or
        # AttributeError further down the line.
        raise ValueError(f"Unsupported dataset: {self._dataset!r}")

    def get_trainloader(self):
        """Create ``self.train_loader`` over the shuffled training split."""
        train_dataset = self._build_dataset(train=True)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self._batch_size,
                                       shuffle=True)

    def get_testloader(self):
        """Create ``self.valid_loader`` over the validation split."""
        valid_dataset = self._build_dataset(train=False)
        # Evaluation does not benefit from shuffling; a fixed order keeps the
        # per-batch averages reproducible between runs.
        self.valid_loader = DataLoader(dataset=valid_dataset,
                                       batch_size=self._batch_size,
                                       shuffle=False)

    def build_model(self):
        """Build the model, loss, optimizer and learning-rate scheduler."""
        print(f"==> creating model {self._title}")

        efficient_transformer = Linformer(
            dim=128,
            seq_len=49+1,  # 7x7 patches + 1 cls-token
            depth=12,
            heads=8,
            k=64
        )

        self.model = ViT(
            dim=128,
            image_size=224,
            patch_size=32,
            num_classes=11,
            transformer=efficient_transformer,
            channels=3,
        )

        if self._cuda:
            self.model.to(self._device)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self._lr)
        # step_size=1: the learning rate is multiplied by gamma on every
        # scheduler step.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=self._gamma)

    def on_train_epoch_start(self):
        """Switch to train mode and reset the per-epoch accumulators."""
        print(f'\nEpoch: {self._epoch}')
        self.model.train()
        self.epoch_loss = 0.0
        self.epoch_accuracy = 0.0
        self._i = 0
        self.train_epoch_iter = tqdm(self.train_loader)

    def training_step(self):
        """Run forward, loss, backward and optimizer step for one batch.

        Reads the batch from ``self._data`` / ``self._label`` and stores the
        logits in ``self._output`` and the loss in ``self.loss``.
        """
        self._i = self._i + 1

        if self._cuda:
            self._data = self._data.to(self._device)
            self._label = self._label.to(self._device)

        self._output = self.model(self._data)
        # .long() keeps the tensor on its current device, so this also works
        # on CPU (torch.cuda.LongTensor would force a CUDA tensor).
        self._label = self._label.long()
        self.loss = self.criterion(self._output, self._label)

        self.optimizer_zero_grad()
        self.loss_backward()
        self.optimizer_step()

    def on_end_training_batch(self):
        """Add this batch's accuracy and loss to the epoch running means."""
        num_batches = len(self.train_loader)
        self.acc = (self._output.argmax(dim=1) == self._label).float().mean()
        self.epoch_accuracy += self.acc / num_batches
        # Detach so the running mean does not keep a reference to every
        # batch's autograd history.
        self.epoch_loss += self.loss.detach() / num_batches

    def on_evaluate_epoch_start(self):
        """Switch to eval mode and reset the validation accumulators."""
        self.model.eval()
        self._j = 0
        self.epoch_val_accuracy = 0
        self.epoch_val_loss = 0
        # get_testloader() creates ``valid_loader``; this class never defines
        # a ``test_loader``.
        self.valid_epoch_iter = enumerate(self.valid_loader)

    def evaluate_batch(self, args):
        """Run the forward pass and loss for one validation batch.

        Parameters
        ----------
        args
            Not used by this method; the batch is read from ``self._data``
            and ``self._label`` (presumably populated by the caller -- TODO
            confirm).
        """
        self._j = self._j + 1

        if self._cuda:
            self._data = self._data.to(self._device)
            self._label = self._label.to(self._device)

        # No gradients are needed during evaluation.
        with torch.no_grad():
            self.val_output = self.model(self._data)
            self._label = self._label.long()
            # Stored as ``val_loss`` because on_evaluate_batch_end() reads
            # that attribute.
            self.val_loss = self.criterion(self.val_output, self._label)

    def on_evaluate_batch_end(self):
        """Add this batch's accuracy and loss to the validation means."""
        num_batches = len(self.valid_loader)
        self.acc = (
            self.val_output.argmax(dim=1) == self._label).float().mean()
        self.epoch_val_accuracy += self.acc / num_batches
        self.epoch_val_loss += self.val_loss / num_batches

    def on_epoch_end(self):
        """Print the training and validation metrics for the epoch."""
        # NOTE(review): on_train_epoch_start() prints ``self._epoch`` while
        # this prints ``self._epoch + 1`` -- confirm which numbering is
        # intended.
        print(
            f"Epoch: {self._epoch+1} - loss: {self.epoch_loss:.4f}"
            f" - acc : {self.epoch_accuracy:.4f}"
            f" - val_loss : {self.epoch_val_loss:.4f}"
            f" - val_acc: {self.epoch_val_accuracy: .4f}\n"
        )