Source code for farabio.models.superres.srgan.srgan_trainer

import os
import time
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.utils as vutils
from farabio.data.transforms import display_transform
from farabio.core.gantrainer import GanTrainer
from farabio.utils.losses import GeneratorLoss
from farabio.utils.metrics import ssim
from farabio.utils.helpers import x, makedirs
from farabio.models.superres.srgan.srgan import Generator, Discriminator
from farabio.data.datasets import TrainDatasetFromFolder, ValDatasetFromFolder, TestDatasetFromFolder


class SrganTrainer(GanTrainer):
    """Trainer for the SRGAN super-resolution model.

    Implements the configuration, data, model, train, evaluate and test
    hooks on top of ``GanTrainer``.  The order in which the parent class
    invokes these hooks is not visible from this file, so every hook is
    documented only in terms of the attributes it reads and writes.

    Parameters
    ----------
    GanTrainer : parent object
        Parent object of SrganTrainer
    """

    # ------------------------------------------------------------------
    # Configuration hooks: copy values from ``self.config`` onto ``self``.
    # ------------------------------------------------------------------

    def define_data_attr(self):
        """Read dataset locations, batch sizes and image geometry from config."""
        self._trainset_dir = self.config.train_set
        self._validset_dir = self.config.valid_set
        self._testset_dir = self.config.test_set
        self._batch_size_train = self.config.batch_size_train
        self._batch_size_valid = self.config.batch_size_valid
        self._batch_size_test = self.config.batch_size_test
        self._upscale_factor = self.config.upscale_factor
        self._crop_size = self.config.crop_size

    def define_model_attr(self):
        """Read the checkpoint path and the output directory from config."""
        self._model_path = self.config.model_path
        self._model_save_dir = self.config.model_save_dir

    def define_train_attr(self):
        """Read the epoch range and select the optimizer class."""
        self._num_epochs = self.config.num_epochs
        self._start_epoch = self.config.start_epoch
        self._has_eval = True
        # NOTE(review): only 'adam' is handled here; for any other value
        # ``self.optim`` is left untouched, so ``build_model`` relies on it
        # being provided elsewhere (or fails with AttributeError).
        if self.config.optim == 'adam':
            self.optim = torch.optim.Adam

    def define_log_attr(self):
        """Read the checkpoint / CSV saving intervals from config."""
        self._save_epoch = self.config.save_epoch
        self._save_csv_epoch = self.config.save_csv_epoch

    def define_compute_attr(self):
        """Read the device flag and the data-loader worker count."""
        self._cuda = self.config.cuda
        self._num_workers = self.config.num_workers

    def define_misc_attr(self):
        """Read the run mode (``'train'`` or ``'test'``)."""
        self._mode = self.config.mode

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------

    def get_trainloader(self):
        """Create ``train_loader`` and ``valid_loader`` when in train mode."""
        if self._mode != 'train':
            return

        train_set = TrainDatasetFromFolder(
            self._trainset_dir,
            crop_size=self._crop_size,
            upscale_factor=self._upscale_factor)
        valid_set = ValDatasetFromFolder(
            self._validset_dir,
            upscale_factor=self._upscale_factor)

        self.train_loader = DataLoader(
            dataset=train_set,
            num_workers=self._num_workers,
            batch_size=self._batch_size_train,
            shuffle=True)
        # Use the configured worker count (was hard-coded to 4) so that
        # e.g. ``num_workers=0`` for debugging is honoured here as well.
        self.valid_loader = DataLoader(
            dataset=valid_set,
            num_workers=self._num_workers,
            batch_size=self._batch_size_valid,
            shuffle=False)

    def get_testloader(self):
        """Create ``test_loader`` when in test mode."""
        if self._mode != 'test':
            return

        test_set = TestDatasetFromFolder(
            self._testset_dir,
            upscale_factor=self._upscale_factor)
        self.test_loader = DataLoader(
            dataset=test_set,
            num_workers=self._num_workers,
            batch_size=self._batch_size_test,
            shuffle=False)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------

    def build_model(self):
        """Instantiate generator, discriminator, loss and both optimizers.

        The networks and the generator criterion are moved to the GPU when
        ``self._cuda`` is set.  Optimizers are built from ``self.optim``
        with that optimizer's default hyper-parameters.
        """
        self.netG = Generator(self._upscale_factor)
        self.netD = Discriminator()
        self.generator_criterion = GeneratorLoss()

        if self._cuda:
            self.netG.cuda()
            self.netD.cuda()
            self.generator_criterion.cuda()

        self.optimizerG = self.optim(self.netG.parameters())
        self.optimizerD = self.optim(self.netD.parameters())

    def start_logger(self):
        """Reset the per-epoch training statistics."""
        self.results = {'d_loss': [], 'g_loss': [], 'd_score': [],
                        'g_score': [], 'psnr': [], 'ssim': []}
        # Epoch number of every row appended to ``self.results``; used as
        # the CSV index so it stays valid when rows are not logged every
        # epoch.
        self._result_epochs = []

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def on_train_epoch_start(self):
        """Wrap the train loader in a progress bar and reset running sums."""
        self.train_epoch_iter = tqdm(self.train_loader)
        self.running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0,
                                'd_score': 0, 'g_score': 0}
        self.netG.train()
        self.netD.train()

    def train_batch(self, args):
        """Run one discriminator update followed by one generator update.

        Parameters
        ----------
        args : sequence
            ``args[0]`` is the generator input batch and ``args[1]`` the
            target batch (see ``on_start_training_batch``).
        """
        self.on_start_training_batch(args)

        # ---- Discriminator ----
        self.discriminator_zero_grad()
        self.discriminator_loss()
        self.discriminator_backward()
        self.discriminator_optim_step()

        # ---- Generator ----
        self.generator_zero_grad()
        self.generator_loss()
        self.generator_backward()
        self.generator_optim_step()

        self.on_end_training_batch()

    def on_start_training_batch(self, args):
        """Store the current batch and add its size to the running count."""
        self.data = args[0]
        self.target = args[1]
        self.batch_size = self.data.size(0)
        self.running_results['batch_sizes'] += self.batch_size

    def discriminator_zero_grad(self):
        """Clear the discriminator gradients."""
        self.netD.zero_grad()

    def discriminator_loss(self):
        """Compute the discriminator loss ``1 - D(x) + D(G(z))``.

        Stores ``self.real_img``, ``self.z`` and ``self.real_out`` for the
        generator step and for logging.
        """
        # ``Variable`` wrappers were dropped: they are no-ops on tensors.
        self.real_img = self.target
        self.z = self.data
        if self._cuda:
            self.real_img = self.real_img.cuda()
            self.z = self.z.cuda()

        # Detach the fake batch: the discriminator update does not need
        # generator gradients (they are zeroed before the generator step
        # anyway), so this only avoids a useless backward pass through G.
        fake_img = self.netG(self.z).detach()

        self.real_out = self.netD(self.real_img).mean()
        fake_out = self.netD(fake_img).mean()
        self._discriminator_loss = 1 - self.real_out + fake_out

    def discriminator_backward(self):
        """Back-propagate the discriminator loss."""
        self._discriminator_loss.backward(retain_graph=True)

    def discriminator_optim_step(self):
        """Discriminator optimizer step."""
        self.optimizerD.step()

    def generator_zero_grad(self):
        """Clear the generator gradients."""
        self.netG.zero_grad()

    def generator_loss(self):
        """Compute the generator loss on a fresh forward pass.

        The criterion receives the mean discriminator score of the fake
        batch, the fake batch itself and the real batch.  The loss
        composition named in the original comment (adversarial +
        perception + image + TV) lives in ``GeneratorLoss``, which is not
        visible here.
        """
        fake_img = self.netG(self.z)
        self.fake_out = self.netD(fake_img).mean()
        self._generator_loss = self.generator_criterion(
            self.fake_out, fake_img, self.real_img)

    def generator_backward(self):
        """Back-propagate the generator loss."""
        self._generator_loss.backward()

    def generator_optim_step(self):
        """Generator optimizer step."""
        self.optimizerG.step()

    def optimizer_zero_grad(self):
        """Clear the gradients of both networks."""
        self.netG.zero_grad()
        self.netD.zero_grad()

    def on_end_training_batch(self):
        """Accumulate batch statistics and refresh the progress-bar text."""
        stats = self.running_results
        weight = self.batch_size

        # Sums are weighted by batch size so the running means are
        # per-sample averages even when the last batch is smaller.
        stats['g_loss'] += self._generator_loss.item() * weight
        stats['d_loss'] += self._discriminator_loss.item() * weight
        stats['d_score'] += self.real_out.item() * weight
        stats['g_score'] += self.fake_out.item() * weight

        seen = stats['batch_sizes']
        self.train_epoch_iter.set_description(
            desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f '
                 'D(x): %.4f D(G(z)): %.4f' % (
                     self.epoch, self._num_epochs,
                     stats['d_loss'] / seen,
                     stats['g_loss'] / seen,
                     stats['d_score'] / seen,
                     stats['g_score'] / seen))

    def on_epoch_end(self):
        """Save checkpoints and statistics every ``save_csv_epoch`` epochs."""
        if self.epoch % self._save_csv_epoch == 0 and self.epoch != 0:
            self.save_model()
            self.save_csv()

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------

    def on_evaluate_epoch_start(self):
        """Switch the generator to eval mode and reset validation state."""
        self.netG.eval()
        self.out_img_path = os.path.join(
            self._model_save_dir,
            "SRF_" + str(self._upscale_factor),
            "output")
        makedirs(self.out_img_path)

        self.valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0,
                               'ssim': 0, 'batch_sizes': 0}
        self.val_images = []
        self.valid_epoch_iter = tqdm(self.valid_loader)

    def evaluate_batch(self, args):
        """Super-resolve one validation batch and accumulate MSE / SSIM.

        Parameters
        ----------
        args : sequence
            ``args[0]`` low-resolution batch, ``args[1]`` restored
            high-resolution batch (kept for display only), ``args[2]``
            ground-truth high-resolution batch.
        """
        val_lr = args[0]
        val_hr_restore = args[1]
        val_hr = args[2]

        # Weight by the real batch size; the configured size over-counts
        # the final batch when the dataset is not a multiple of it.
        batch_size = val_lr.size(0)
        self.valing_results['batch_sizes'] += batch_size

        lr = val_lr
        self.hr = val_hr
        if self._cuda:
            lr = lr.cuda()
            self.hr = self.hr.cuda()

        # No gradients are needed for validation.
        with torch.no_grad():
            self.sr = self.netG(lr)
        self.val_hr_restore = val_hr_restore

        self.batch_mse = ((self.sr - self.hr) ** 2).mean().item()
        self.valing_results['mse'] += self.batch_mse * batch_size
        self.batch_ssim = ssim(self.sr, self.hr).item()

    def on_evaluate_batch_end(self):
        """Update running PSNR / SSIM and collect images for the grid."""
        res = self.valing_results
        batch_size = self.hr.size(0)
        res['ssims'] += self.batch_ssim * batch_size

        mean_mse = float(res['mse']) / res['batch_sizes']
        peak = float(self.hr.max())
        # A perfect reconstruction has zero error; report infinite PSNR
        # rather than dividing by zero.
        if mean_mse > 0:
            res['psnr'] = 10 * math.log10(peak ** 2 / mean_mse)
        else:
            res['psnr'] = float('inf')
        res['ssim'] = res['ssims'] / res['batch_sizes']

        self.valid_epoch_iter.set_description(
            desc='[converting LR images to SR images] '
                 'PSNR: %.4f dB SSIM: %.4f' % (res['psnr'], res['ssim']))

        # ``squeeze(0)`` presumably relies on a validation batch size of
        # 1 -- TODO confirm against the configs in use.
        to_display = display_transform()
        self.val_images.extend([
            to_display(self.val_hr_restore.squeeze(0)),
            to_display(self.hr.data.cpu().squeeze(0)),
            to_display(self.sr.data.cpu().squeeze(0)),
        ])

    def on_evaluate_epoch_end(self):
        """Write the collected validation images as PNG grids.

        Images are grouped into chunks of roughly 15 (five rows of
        restored / ground-truth / super-resolved triples) and each chunk
        is saved inside ``self.out_img_path``.
        """
        if not self.val_images:
            # Nothing was evaluated; ``torch.stack`` would fail on [].
            return

        stacked = torch.stack(self.val_images)
        # ``torch.chunk`` rejects 0 chunks, which happened with fewer than
        # 15 collected images.
        num_chunks = max(1, stacked.size(0) // 15)
        self.val_images = torch.chunk(stacked, num_chunks)

        val_save_bar = tqdm(self.val_images, desc='[saving training results]')
        for index, image in enumerate(val_save_bar, start=1):
            grid = vutils.make_grid(image, nrow=3, padding=5)
            # Join with a separator so the file lands inside the output
            # directory rather than next to it as 'output' + filename.
            file_name = 'epoch_%d_index_%d.png' % (self.epoch, index)
            vutils.save_image(
                grid,
                os.path.join(self.out_img_path, file_name),
                padding=5)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_model(self):
        """Save both networks and append the current epoch's statistics.

        Checkpoints are written to ``<model_save_dir>/epochs``.  The
        running training means and the latest validation PSNR / SSIM are
        appended to ``self.results``.
        """
        out_model_path = os.path.join(self._model_save_dir, "epochs")
        makedirs(out_model_path)

        g_name = f'netG_epoch_{self._upscale_factor}_{self.epoch}.pth'
        d_name = f'netD_epoch_{self._upscale_factor}_{self.epoch}.pth'

        torch.save(self.netG.state_dict(),
                   os.path.join(out_model_path, g_name))
        # Was saving the generator weights under the discriminator name.
        torch.save(self.netD.state_dict(),
                   os.path.join(out_model_path, d_name))

        seen = self.running_results['batch_sizes']
        for key in ('d_loss', 'g_loss', 'd_score', 'g_score'):
            self.results[key].append(self.running_results[key] / seen)
        self.results['psnr'].append(self.valing_results['psnr'])
        self.results['ssim'].append(self.valing_results['ssim'])

        # Remember which epoch this row belongs to (see ``save_csv``).
        if not hasattr(self, '_result_epochs'):
            self._result_epochs = []
        self._result_epochs.append(self.epoch)

    def save_csv(self):
        """Write the accumulated training statistics to a CSV file."""
        out_stat_path = os.path.join(
            self._model_save_dir,
            "SRF_" + str(self._upscale_factor),
            "statistics")
        makedirs(out_stat_path)

        print("saving .csv file")
        print(self.results)

        # The index must have exactly one entry per logged row.  Rows are
        # only appended when ``save_model`` runs, so ``range(1, epoch + 1)``
        # mismatched the data whenever the save interval was > 1.
        num_rows = len(self.results['d_loss'])
        epochs = getattr(self, '_result_epochs', None)
        if epochs is None or len(epochs) != num_rows:
            epochs = range(1, num_rows + 1)

        data_frame = pd.DataFrame(
            data={'Loss_D': self.results['d_loss'],
                  'Loss_G': self.results['g_loss'],
                  'Score_D': self.results['d_score'],
                  'Score_G': self.results['g_score'],
                  'PSNR': self.results['psnr'],
                  'SSIM': self.results['ssim']},
            index=epochs)

        csv_name = 'srf_' + str(self._upscale_factor) + '_train_results.csv'
        data_frame.to_csv(os.path.join(out_stat_path, csv_name),
                          index_label='Epoch')

    def load_model(self):
        """Load generator weights from ``self._model_path`` in eval mode."""
        self.netG.eval()
        # Map to CPU when CUDA is disabled so GPU-saved checkpoints load on
        # CPU-only machines; ``None`` keeps torch's default placement.
        map_location = None if self._cuda else 'cpu'
        self.netG.load_state_dict(
            torch.load(self._model_path, map_location=map_location))

    # ------------------------------------------------------------------
    # Testing
    # ------------------------------------------------------------------

    def test_batch(self, model_name):
        """No-op placeholder; per-sample testing is done in ``test_step``."""
        pass

    def on_test_start(self):
        """Reset benchmark results and wrap the test loader in a progress bar."""
        self.results = {'Set5': {'psnr': [], 'ssim': []},
                        'Set14': {'psnr': [], 'ssim': []},
                        'BSD100': {'psnr': [], 'ssim': []},
                        'Urban100': {'psnr': [], 'ssim': []},
                        'SunHays80': {'psnr': [], 'ssim': []}}

        self.out_bench_path = os.path.join(
            self._model_save_dir,
            "SRF_" + str(self._upscale_factor),
            "benchmark_results")
        makedirs(self.out_bench_path)

        test_bar = tqdm(self.test_loader, desc='[testing benchmark datasets]')
        self.test_loop_iter = test_bar

    def test_step(self, test_arg):
        """Super-resolve one test sample, save a comparison grid, log metrics.

        Parameters
        ----------
        test_arg : sequence
            ``test_arg[0]`` batch of image names (only the first is used),
            ``test_arg[1]`` low-resolution image, ``test_arg[2]`` restored
            high-resolution image (display only), ``test_arg[3]``
            ground-truth high-resolution image.
        """
        image_name = test_arg[0][0]
        lr_image = test_arg[1]
        hr_restore_img = test_arg[2]
        hr_image = test_arg[3]

        if self._cuda:
            lr_image = lr_image.cuda()
            hr_image = hr_image.cuda()

        # No gradients are needed for inference.
        with torch.no_grad():
            sr_image = self.netG(lr_image)

        mse = ((hr_image - sr_image) ** 2).mean().item()
        # PSNR with a peak value of 1; infinite for a perfect match.
        psnr = 10 * math.log10(1 / mse) if mse > 0 else float('inf')
        _ssim = ssim(sr_image, hr_image).item()

        to_display = display_transform()
        test_images = torch.stack([
            to_display(hr_restore_img.squeeze(0)),
            to_display(hr_image.data.cpu().squeeze(0)),
            to_display(sr_image.data.cpu().squeeze(0)),
        ])
        image = vutils.make_grid(test_images, nrow=3, padding=5)

        # Build the path with a separator (it used to be glued onto the
        # directory name) and keep the original extension.
        stem, ext = os.path.splitext(image_name)
        out_name = stem + '_psnr_%.4f_ssim_%.4f' % (psnr, _ssim) + ext
        vutils.save_image(
            image, os.path.join(self.out_bench_path, out_name), padding=5)

        # NOTE(review): every sample is filed under 'SunHays80' regardless
        # of its source dataset; confirm whether the dataset prefix of
        # ``image_name`` should select the bucket instead.
        self.results['SunHays80']['psnr'].append(psnr)
        self.results['SunHays80']['ssim'].append(_ssim)

    def on_test_end(self):
        """Average the per-dataset metrics and write them to a CSV file."""
        out_stat_path = os.path.join(
            self._model_save_dir,
            "SRF_" + str(self._upscale_factor),
            "statistics")
        makedirs(out_stat_path)

        saved_results = {'psnr': [], 'ssim': []}
        for item in self.results.values():
            psnr = np.array(item['psnr'])
            _ssim = np.array(item['ssim'])
            if psnr.size == 0 or _ssim.size == 0:
                # Datasets without any sample get a textual placeholder.
                psnr = 'No data'
                _ssim = 'No data'
            else:
                psnr = psnr.mean()
                _ssim = _ssim.mean()
            saved_results['psnr'].append(psnr)
            saved_results['ssim'].append(_ssim)

        data_frame = pd.DataFrame(saved_results, self.results.keys())
        csv_name = 'srf_' + str(self._upscale_factor) + '_test_results.csv'
        data_frame.to_csv(os.path.join(out_stat_path, csv_name),
                          index_label='DataSet')