import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from skimage import img_as_ubyte
from skimage.io import imsave
from torchsummary import summary

from farabio.core.convnettrainer import ConvnetTrainer
from farabio.data.biodatasets import DSB18Dataset
from farabio.models.segmentation.unet.unet import Unet
from farabio.utils.helpers import makedirs, parallel_state_dict
from farabio.utils.losses import (DiceBCELoss, DiceLoss, FocalLoss,
                                  FocalTverskyLoss, IoULoss, Losses,
                                  LovaszHingeLoss, TverskyLoss)
from farabio.utils.regul import EarlyStopping
from farabio.utils.tensorboard import TensorBoard

class UnetTrainer(ConvnetTrainer):
    """U-Net trainer class. Override the hooks below with custom methods.

    Inherits from :class:`ConvnetTrainer` (itself a ``BaseTrainer``) and
    reads all of its settings from ``self.config``.
    """
    def define_data_attr(self, *args):
        self._in_ch = self.config.in_ch
        self._out_ch = self.config.out_ch
        self._data_path = self.config.data_path
        self._in_shape = self.config.shape
    def define_model_attr(self, *args):
        self._semantic = self.config.semantic
        self._model_save_name = self.config.model_save_name
        self._model_save_dir = self.config.model_save_dir
        self._criterion = self.config.criterion
    def define_train_attr(self, *args):
        self._epoch = self.config.start_epoch
        self._patience = self.config.patience
        self._early_stop = self.config.early_stop
        self._save_epoch = self.config.save_epoch
        self._has_eval = self.config.has_eval
        self.early_stop = False  # flipped by EarlyStopping during evaluation
        if self.config.optim == 'adam':
            self.optim = torch.optim.Adam
        else:
            # Fail early instead of leaving self.optim undefined.
            raise ValueError(f"Unsupported optimizer: {self.config.optim}")
        if self._early_stop:
            self.early_stopping = EarlyStopping(
                patience=self._patience, verbose=True)
        makedirs(self._model_save_dir)
    def define_test_attr(self, *args):
        self._output_img_dir = self.config.output_img_dir
        self._output_mask_dir = self.config.output_mask_dir
        self._output_overlay_dir = self.config.output_overlay_dir
        self._model_load_dir = self.config.model_load_dir
    def define_log_attr(self, *args):
        self._use_visdom = self.config.use_visdom
        self._use_tensorboard = self.config.use_tensorboard
        if self._use_tensorboard:
            self.tb = TensorBoard(os.path.join(self._model_save_dir, "logs"))
        else:
            self.tb = None
    def define_compute_attr(self, *args):
        self._cuda = self.config.cuda
        self._device = torch.device(self.config.device)
        self._num_gpu = self.config.num_gpu
        self._num_workers = self.config.num_workers
        self._data_parallel = self.config.data_parallel
    def define_misc_attr(self, *args):
        self._train_losses = []
        self._val_losses = []
    def get_trainloader(self):
        train_dataset = DSB18Dataset(
            root=self._data_path, transform=None, download=False)
        # Hold out 25% for validation. Derive the validation size from the
        # remainder so the two splits always sum to the dataset length,
        # as random_split requires.
        split_ratio = 0.25
        train_size = int(round(len(train_dataset) * (1 - split_ratio)))
        valid_size = len(train_dataset) - train_size
        train_data, valid_data = random_split(
            train_dataset, [train_size, valid_size])
        self.train_loader = DataLoader(
            dataset=train_data, batch_size=10, shuffle=True,
            num_workers=self._num_workers)
        self.valid_loader = DataLoader(
            dataset=valid_data, batch_size=10,
            num_workers=self._num_workers)
    def get_testloader(self):
        # The validation split doubles as the test set.
        self.test_loader = self.valid_loader
    def build_model(self):
        self.model = Unet(self._in_ch, self._out_ch)
        if self._cuda:
            self.model.to(self._device)
        _losses = {
            "segmentation": {
                "Dice": DiceLoss,
                "DiceBCE": DiceBCELoss,
                "IoU": IoULoss,
                "Focal": FocalLoss,
                "Tversky": TverskyLoss,
                "FocalTversky": FocalTverskyLoss,
                "Lovasz": LovaszHingeLoss
            }
        }
        # config.criterion must match one of the keys above, e.g. "Dice".
        self.loss_type = _losses["segmentation"][self._criterion]()
        self.optimizer = self.optim(self.model.parameters(),
                                    lr=self.config.learning_rate)
    def build_parallel_model(self):
        self.model = Unet(self._in_ch, self._out_ch)
        self.model = nn.DataParallel(self.model)
        self.model.to(self._device)
        _losses = {
            "segmentation": {
                "Dice": DiceLoss,
                "DiceBCE": DiceBCELoss,
                "IoU": IoULoss,
                "Focal": FocalLoss,
                "Tversky": TverskyLoss,
                "FocalTversky": FocalTverskyLoss,
                "Lovasz": LovaszHingeLoss
            }
        }
        self.loss_type = _losses["segmentation"][self._criterion]()
        self.optimizer = self.optim(self.model.parameters(),
                                    lr=self.config.learning_rate)
    def show_model_summary(self, *args):
        # torchsummary prints its table directly; wrapping it in print()
        # would only emit a trailing "None".
        summary(self.model, (self._in_ch, self._in_shape, self._in_shape))
    def load_model(self):
        self.model.load_state_dict(torch.load(self._model_load_dir))

    def load_parallel_model(self):
        state_dict = torch.load(self._model_load_dir)
        _par_state_dict = parallel_state_dict(state_dict)
        self.model.load_state_dict(_par_state_dict)
    def start_logger(self):
        # Placeholder: a Visdom logger would be created here when enabled.
        # Always define the attribute so the batch hooks can guard on it.
        self.logger = None
    def on_train_epoch_start(self):
        self.batch_tloss = 0
        self.model.train()
        self.train_epoch_iter = enumerate(self.train_loader)
    def on_start_training_batch(self, args):
        self.iteration = args[0]
        self.batch = args[-1]
    def optimizer_zero_grad(self):
        self.optimizer.zero_grad()

    def loss_backward(self):
        self.train_loss.backward()

    def optimizer_step(self):
        self.optimizer.step()
    def training_step(self):
        self.optimizer_zero_grad()
        self.imgs = self.batch[0]
        self.masks = self.batch[1]
        if self._cuda:
            self.imgs = self.imgs.to(self._device, dtype=torch.float32)
            self.masks = self.masks.to(self._device, dtype=torch.float32)
        self.outputs = self.model(self.imgs)
        if self._semantic:
            self.train_loss = Losses().extract_loss(self.outputs, self.masks)
        else:
            self.train_loss = self.loss_type(self.outputs, self.masks)
        self.batch_tloss += self.train_loss.item()
        self.loss_backward()
        self.optimizer_step()
    def on_end_training_batch(self):
        if self._use_visdom and self.logger is not None:
            self.logger.log(
                images={
                    'imgs': self.imgs,
                    'masks': self.masks,
                    'outputs': self.outputs
                }
            )
        print(
            f"===> Epoch [{self._epoch}]({self.iteration}/{len(self.train_loader)}): Loss: {self.train_loss.item():.4f}")
    def on_train_epoch_end(self):
        epoch_train_loss = round(self.batch_tloss / len(self.train_loader), 4)
        self._train_losses.append(epoch_train_loss)
        print(
            f"===> Epoch {self._epoch} Complete: Avg. Train Loss: {epoch_train_loss}")
    def on_evaluate_epoch_start(self):
        self.batch_vloss = 0
        self.model.eval()
        self.valid_epoch_iter = self.valid_loader
    def evaluate_batch(self, args):
        self.batch = args
        imgs = self.batch[0]
        masks = self.batch[1]
        imgs = imgs.to(device=self._device, dtype=torch.float32)
        with torch.no_grad():  # no gradients needed during validation
            outputs = self.model(imgs)
            if self._semantic:
                masks = masks.to(device=self._device, dtype=torch.long)
                self.val_loss = Losses().extract_loss(outputs, masks, self._device)
            else:
                masks = masks.to(device=self._device, dtype=torch.float32)
                self.val_loss = self.loss_type(outputs, masks)
    def on_evaluate_batch_end(self):
        self.batch_vloss += self.val_loss.item()
    def on_evaluate_epoch_end(self):
        epoch_val_loss = round(self.batch_vloss / len(self.valid_loader), 4)
        self._val_losses.append(epoch_val_loss)
        print(f"===> Epoch {self._epoch} Valid Loss: {epoch_val_loss}")
        if self._use_tensorboard:
            self.tb.scalar_summary('val_loss', epoch_val_loss, self._epoch)
        if self._early_stop:
            self.early_stopping(epoch_val_loss, self.model,
                                self._model_save_dir)
            self.early_stop = self.early_stopping.early_stop
    def on_epoch_end(self):
        if self._epoch % self._save_epoch == 0:
            if self._data_parallel:
                self.save_parallel_model()
            else:
                self.save_model()
        if self.early_stop:
            print("Early stopping")
            self._model_save_name = "unet_es.pt"
            if self._data_parallel:
                self.save_parallel_model()
            else:
                self.save_model()
            self.stop_train()
    def save_model(self):
        torch.save(self.model.state_dict(), os.path.join(
            self._model_save_dir, self._model_save_name))
    def save_parallel_model(self):
        torch.save(self.model.module.state_dict(), os.path.join(
            self._model_save_dir, self._model_save_name))
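    # Why .module above: nn.DataParallel wraps the network and prefixes all
    # parameter names with "module.", so saving self.model.module.state_dict()
    # keeps the checkpoint loadable by a plain single-GPU Unet as well.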
    def on_test_start(self):
        self.model.eval()
        self.test_loop_iter = enumerate(self.test_loader)
    def test_step(self, args):
        self.cur_batch = args[0]
        self.imgs = args[-1][0]
        self.fname = args[-1][-1]
        if self._cuda:
            self.imgs = self.imgs.to(device=self._device, dtype=torch.float32)
        with torch.no_grad():  # inference only
            outputs = self.model(self.imgs)
        self.pred = torch.sigmoid(outputs) > 0.5  # binarize at 0.5
        self.generate_result_img()
    def generate_result_img(self, *args):
        """Save each image/prediction pair in the current batch one by one."""
        # Iterate over the actual batch size, which can be smaller than
        # test_loader.batch_size for the final batch.
        for i in range(self.imgs.size(0)):
            img_fname = self.fname[i]
            in_img = format_image(self.imgs[i].cpu().numpy())
            out_img = format_mask(self.pred[i].cpu().numpy())
            imsave(os.path.join(self._output_img_dir, img_fname),
                   img_as_ubyte(in_img), check_contrast=False)
            imsave(os.path.join(self._output_mask_dir, img_fname),
                   img_as_ubyte(out_img), check_contrast=False)
    def on_end_test_batch(self):
        print(f"{self.cur_batch} / {len(self.test_loader)}")

def format_image(img):
    """Undo ImageNet normalization and return an HWC uint8 image."""
    img = np.transpose(img, (1, 2, 0))
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    img = std * img + mean
    # Clip before casting so out-of-range values do not wrap around.
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    return img
def format_mask(mask):
    """Move the channel axis last and squeeze a CHW mask down to HW."""
    return np.squeeze(np.transpose(mask, (1, 2, 0)))
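

if __name__ == "__main__":
    # Minimal usage sketch, not confirmed farabio API: SimpleNamespace stands
    # in for the library's real config object, its fields mirror the
    # attributes read by the define_*_attr hooks above, and train() is
    # assumed to be the entry point inherited from ConvnetTrainer.
    from types import SimpleNamespace

    config = SimpleNamespace(
        in_ch=3, out_ch=1, shape=256,
        data_path="data/dsb18", model_save_dir="checkpoints",
        model_save_name="unet.pt", criterion="DiceBCE", optim="adam",
        learning_rate=1e-3, start_epoch=1, save_epoch=5, patience=7,
        early_stop=True, has_eval=True, semantic=False,
        use_visdom=False, use_tensorboard=True,
        cuda=torch.cuda.is_available(),
        device="cuda" if torch.cuda.is_available() else "cpu",
        num_gpu=1, num_workers=4, data_parallel=False,
        output_img_dir="results/imgs", output_mask_dir="results/masks",
        output_overlay_dir="results/overlays",
        model_load_dir="checkpoints/unet.pt")

    trainer = UnetTrainer(config)  # assumes ConvnetTrainer(config) signature
    trainer.train()                # assumed base-class training loop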