Source code for farabio.models.segmentation.attunet.attunet_trainer

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import albumentations as A
from albumentations.pytorch import ToTensor
from farabio.core.convnettrainer import ConvnetTrainer
from farabio.models.segmentation.attunet.attunet import AttUnet
from farabio.utils.regul import EarlyStopping
from farabio.utils.losses import Losses
from farabio.utils.tensorboard import TensorBoard
from farabio.utils.helpers import makedirs, parallel_state_dict
import skimage
from skimage import io, transform, img_as_ubyte
from skimage.io import imsave
from torchsummary import summary


class AttunetTrainer(ConvnetTrainer):
    """Attention U-Net trainer class. Override with custom methods here.

    Attention U-Net cannot fit into one GPU. Supports only parallel mode.

    Parameters
    ----------
    ConvnetTrainer : BaseTrainer
        Inherits ConvnetTrainer class
    """

    def define_data_attr(self, *args):
        self._in_ch = self.config.in_ch
        self._out_ch = self.config.out_ch
        self._data_path = self.config.data_path
        self._in_shape = self.config.shape

    def define_model_attr(self, *args):
        self._semantic = self.config.semantic
        self._model_save_name = self.config.model_save_name
        self._model_save_dir = os.path.join(self.config.model_save_dir)

    def define_train_attr(self, *args):
        self._epoch = self.config.start_epoch
        self._patience = self.config.patience
        self._early_stop = self.config.early_stop
        self._save_epoch = self.config.save_epoch
        self._has_eval = self.config.has_eval

        if self.config.optim == 'adam':
            self.optim = torch.optim.Adam

        if self._early_stop:
            self.early_stopping = EarlyStopping(
                patience=self._patience, verbose=True)

        makedirs(self._model_save_dir)

    def define_test_attr(self, *args):
        self._output_img_dir = self.config.output_img_dir
        self._output_mask_dir = self.config.output_mask_dir
        self._output_overlay_dir = self.config.output_overlay_dir
        self._model_load_dir = self.config.model_load_dir

    def define_log_attr(self, *args):
        self._use_visdom = self.config.use_visdom
        self._use_tensorboard = self.config.use_tensorboard

        if self._use_tensorboard:
            self.tb = TensorBoard(os.path.join(self._model_save_dir, "logs"))
        else:
            self.tb = None

    def define_compute_attr(self, *args):
        self._cuda = self.config.cuda
        self._device = torch.device(self.config.device)
        self._num_gpu = self.config.num_gpu
        self._num_workers = self.config.num_workers
        self._data_parallel = self.config.data_parallel

    def define_misc_attr(self, *args):
        self._train_losses = []
        self._val_losses = []

    def get_trainloader(self):
        train_dataset = LoadDataSet(
            self._data_path, shape=self._in_shape,
            transform=get_train_transform(self._in_shape))

        split_ratio = 0.25
        train_size = int(np.round(len(train_dataset) * (1 - split_ratio), 0))
        valid_size = int(np.round(len(train_dataset) * split_ratio, 0))

        train_data, valid_data = random_split(
            train_dataset, [train_size, valid_size])

        self.train_loader = DataLoader(
            dataset=train_data, batch_size=10, shuffle=True,
            num_workers=self._num_workers)
        self.valid_loader = DataLoader(
            dataset=valid_data, batch_size=10,
            num_workers=self._num_workers)

    def get_testloader(self):
        self.test_loader = self.valid_loader

    def build_model(self):
        self.model = AttUnet(self._in_ch, self._out_ch)

        if self._cuda:
            self.model.to(self._device)

        self.optimizer = self.optim(
            self.model.parameters(), lr=self.config.learning_rate)

    def build_parallel_model(self):
        self.model = AttUnet(self._in_ch, self._out_ch)
        self.model = nn.DataParallel(self.model)
        self.model.to(self._device)
        self.optimizer = self.optim(
            list(self.model.parameters()), lr=self.config.learning_rate)

    def show_model_summary(self, *args):
        print(summary(
            self.model, [(self._in_ch, self._in_shape, self._in_shape)]))

    def load_model(self):
        self.model.load_state_dict(torch.load(self._model_load_dir))

    def load_parallel_model(self):
        state_dict = torch.load(self._model_load_dir)
        _par_state_dict = parallel_state_dict(state_dict)
        self.model.load_state_dict(_par_state_dict)

    def start_logger(self):
        if self._use_visdom:
            self.logger = None

    def on_train_epoch_start(self):
        self.model.train()
        self.batch_tloss = 0
        self.train_epoch_iter = enumerate(self.train_loader)

    def on_start_training_batch(self, args):
        self.iteration = args[0]
        self.batch = args[-1]

    def optimizer_zero_grad(self):
        self.optimizer.zero_grad()

    def loss_backward(self):
        self.train_loss.backward()

    def optimizer_step(self):
        self.optimizer.step()

    def training_step(self):
        self.optimizer_zero_grad()

        self.imgs = self.batch[0]
        self.masks = self.batch[1]

        if self._cuda:
            self.imgs = self.imgs.to(self._device, dtype=torch.float32)
            self.masks = self.masks.to(self._device, dtype=torch.float32)

        self.outputs = self.model(self.imgs)

        if self._semantic:
            self.train_loss = Losses().extract_loss(self.outputs, self.masks)
        else:
            self.train_loss = Losses().calc_loss(self.outputs, self.masks)

        self.batch_tloss += self.train_loss.item()

        self.loss_backward()
        self.optimizer_step()

    def on_end_training_batch(self):
        if self._use_visdom:
            self.logger.log(
                images={
                    'imgs': self.imgs,
                    'masks': self.masks,
                    'outputs': self.outputs
                }
            )

        print(
            f"===> Epoch [{self._epoch}]({self.iteration}/{len(self.train_loader)}): "
            f"Loss: {self.train_loss.item():.4f}")

    def on_train_epoch_end(self):
        epoch_train_loss = round(self.batch_tloss / len(self.train_loader), 4)
        self._train_losses.append(epoch_train_loss)
        print(
            f"===> Epoch {self._epoch} Complete: Avg. Train Loss: {epoch_train_loss}")

    def on_evaluate_epoch_start(self):
        self.batch_vloss = 0
        self.model.eval()
        self.valid_epoch_iter = self.valid_loader

    def evaluate_batch(self, args):
        self.batch = args
        imgs = self.batch[0]
        masks = self.batch[1]

        imgs = imgs.to(device=self._device, dtype=torch.float32)
        outputs = self.model(imgs)

        if self._semantic:
            masks = masks.to(device=self._device, dtype=torch.long)
            self.val_loss = Losses().extract_loss(outputs, masks, self._device)
        else:
            masks = masks.to(device=self._device, dtype=torch.float32)
            self.val_loss = Losses().calc_loss(outputs, masks)

    def on_evaluate_batch_end(self):
        self.batch_vloss += self.val_loss.item()

    def on_evaluate_epoch_end(self):
        epoch_val_loss = round(self.batch_vloss / len(self.valid_loader), 4)
        self._val_losses.append(epoch_val_loss)
        print(f"===> Epoch {self._epoch} Valid Loss: {epoch_val_loss}")

        if self._use_tensorboard:
            self.tb.scalar_summary('val_loss', epoch_val_loss, self._epoch)

        if self._early_stop:
            self.early_stopping(epoch_val_loss, self.model, self._model_save_dir)
            self.early_stop = self.early_stopping.early_stop

    def on_epoch_end(self):
        if self._epoch % self._save_epoch == 0:
            if self._data_parallel:
                self.save_parallel_model()
            else:
                self.save_model()

        if self.early_stop:
            print("Early stopping")
            self._model_save_name = "attunet_es.pt"

            if self._data_parallel:
                self.save_parallel_model()
            else:
                self.save_model()

            self.stop_train()

    def save_model(self):
        torch.save(self.model.state_dict(),
                   os.path.join(self._model_save_dir, self._model_save_name))

    def save_parallel_model(self):
        torch.save(self.model.module.state_dict(),
                   os.path.join(self._model_save_dir, self._model_save_name))

    def on_test_start(self):
        self.model.eval()
        self.test_loop_iter = enumerate(self.test_loader)

    def test_step(self, args):
        self.cur_batch = args[0]
        self.imgs = args[-1][0]
        self.fname = args[-1][-1]

        if self._cuda:
            self.imgs = self.imgs.to(device=self._device, dtype=torch.float32)

        outputs = self.model(self.imgs)
        self.pred = torch.sigmoid(outputs)
        self.pred = (self.pred > 0.5).bool()

        self.generate_result_img()

    def generate_result_img(self, *args):
        """Generate image from batch: one by one
        """
        for i in range(self.test_loader.batch_size):
            img_fname = self.fname[i]
            in_img = format_image(self.imgs[i].cpu().numpy())
            out_img = format_mask(self.pred[i].cpu().numpy())

            imsave(os.path.join(self._output_img_dir, img_fname),
                   img_as_ubyte(in_img), check_contrast=False)
            imsave(os.path.join(self._output_mask_dir, img_fname),
                   img_as_ubyte(out_img), check_contrast=False)

    def on_end_test_batch(self):
        print(f"{self.cur_batch} / {len(self.test_loader)}")


def format_image(img):
    """Undo ImageNet normalization and convert a CHW tensor to an HWC uint8 image."""
    img = np.array(np.transpose(img, (1, 2, 0)))
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    img = std * img + mean
    img = img * 255
    img = img.astype(np.uint8)
    return img


def format_mask(mask):
    mask = np.squeeze(np.transpose(mask, (1, 2, 0)))
    return mask


def get_train_transform(img_shape):
    """Albumentations transform
    """
    return A.Compose(
        [
            A.Resize(img_shape, img_shape),
            A.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            A.HorizontalFlip(p=0.25),
            A.VerticalFlip(p=0.25),
            ToTensor()
        ])


class LoadDataSet(Dataset):
    """Dataset loader
    """
    def __init__(self, path, shape, transform=None):
        self.path = path
        self.folders = os.listdir(path)
        # Use the provided transform if given; fall back to the default train transform.
        self.transforms = transform if transform is not None else get_train_transform(shape)
        self.shape = shape

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        image_folder = os.path.join(self.path, self.folders[idx], 'images/')
        mask_folder = os.path.join(self.path, self.folders[idx], 'masks/')
        fname = os.listdir(image_folder)[0]
        image_path = os.path.join(image_folder, fname)

        img = io.imread(image_path)[:, :, :3].astype('float32')
        img = transform.resize(img, (self.shape, self.shape))

        mask = self.get_mask(mask_folder, self.shape, self.shape).astype('float32')

        augmented = self.transforms(image=img, mask=mask)
        img = augmented['image']
        mask = augmented['mask']
        mask = mask[0].permute(2, 0, 1)

        return (img, mask, fname)

    def get_mask(self, mask_folder, IMG_HEIGHT, IMG_WIDTH):
        # Merge all per-instance masks into a single binary mask.
        # np.bool was removed in recent NumPy releases; use the built-in bool dtype.
        mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=bool)

        for mask_ in os.listdir(mask_folder):
            mask_ = io.imread(os.path.join(mask_folder, mask_))
            mask_ = transform.resize(mask_, (IMG_HEIGHT, IMG_WIDTH))
            mask_ = np.expand_dims(mask_, axis=-1)
            mask = np.maximum(mask, mask_)

        return mask
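

A minimal usage sketch (not part of the module): it builds a plain namespace exposing the config attributes read in the define_*_attr methods above. The SimpleNamespace stand-in, the example values, the AttunetTrainer(config) constructor signature, and the train() entry point inherited from ConvnetTrainer are assumptions for illustration only.

# Hypothetical usage sketch -- field names mirror the attributes read in
# define_*_attr(); values, the SimpleNamespace config, the constructor
# signature, and trainer.train() are assumptions, not the confirmed API.
# LoadDataSet expects data_path/<sample>/images/ and data_path/<sample>/masks/.
from types import SimpleNamespace

config = SimpleNamespace(
    in_ch=3, out_ch=1, data_path="data/nuclei", shape=256,
    semantic=False, model_save_name="attunet.pt", model_save_dir="runs/attunet",
    start_epoch=1, patience=13, early_stop=True, save_epoch=5, has_eval=True,
    optim="adam", learning_rate=1e-3,
    use_visdom=False, use_tensorboard=False,
    cuda=True, device="cuda", num_gpu=2, num_workers=4, data_parallel=True,
    output_img_dir="out/imgs", output_mask_dir="out/masks",
    output_overlay_dir="out/overlays", model_load_dir="runs/attunet/attunet.pt",
)

trainer = AttunetTrainer(config)  # assumed: ConvnetTrainer takes a config object
trainer.train()                   # assumed: training loop provided by ConvnetTrainer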