import os
import sys
import itertools
from PIL import Image
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torchvision.transforms as transforms
from farabio.core.gantrainer import GanTrainer
from farabio.data.datasets import ImageDataset
from farabio.utils.helpers import ReplayBuffer
from farabio.utils.visdom import CycleganViz
from farabio.utils.regul import weights_init_normal, LambdaLR
from farabio.models.translation.cyclegan.cyclegan import Generator, Discriminator
class CycleganTrainer(GanTrainer):
    """CycleGAN trainer class. Override with custom methods here.

    Parameters
    ----------
    GanTrainer : farabio.core.gantrainer.GanTrainer
        Inherits from the GanTrainer class.
    """
    def define_data_attr(self):
self._dataroot = self.config.dataroot
self._size = self.config.size
self._in_ch = self.config.in_ch
self._out_ch = self.config.out_ch
self._batch_size = self.config.batch_size
    def define_model_attr(self):
        self._model_save_dir = self.config.save_dir
        self._genA2B_path = self.config.generator_A2B
        self._genB2A_path = self.config.generator_B2A
    def define_train_attr(self):
self._num_epochs = self.config.num_epochs
self._start_epoch = self.config.start_epoch
self._learning_rate = self.config.learning_rate
self._decay_epoch = self.config.decay_epoch
self._has_eval = self.config.has_eval
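        # Adam is the only optimizer wired up here; its betas are set in build_model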
if self.config.optim == 'adam':
self._optim = torch.optim.Adam
self._scheduler = torch.optim.lr_scheduler.LambdaLR
    def define_test_attr(self):
self._output_dir = self.config.output_dir
    def define_log_attr(self):
self._save_epoch = self.config.save_epoch
    def define_compute_attr(self):
self._cuda = self.config.cuda
self._num_workers = self.config.num_workers
    def define_misc_attr(self):
self._mode = self.config.mode
    def build_model(self):
self.netG_A2B = Generator(self._in_ch, self._out_ch)
self.netG_B2A = Generator(self._out_ch, self._in_ch)
self.netD_A = Discriminator(self._in_ch)
self.netD_B = Discriminator(self._out_ch)
if self._cuda:
self.netG_A2B.cuda()
self.netG_B2A.cuda()
self.netD_A.cuda()
self.netD_B.cuda()
self.netG_A2B.apply(weights_init_normal)
self.netG_B2A.apply(weights_init_normal)
self.netD_A.apply(weights_init_normal)
self.netD_B.apply(weights_init_normal)
# Losses
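        # LSGAN-style MSE adversarial loss; L1 for the cycle and identity terms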
self.criterion_GAN = torch.nn.MSELoss()
self.criterion_cycle = torch.nn.L1Loss()
self.criterion_identity = torch.nn.L1Loss()
# Optimizers & LR schedulers
self.optimizer_G = self._optim(itertools.chain(self.netG_A2B.parameters(),
self.netG_B2A.parameters()),
lr=self._learning_rate, betas=(0.5, 0.999))
self.optimizer_D_A = self._optim(self.netD_A.parameters(),
lr=self._learning_rate,
betas=(0.5, 0.999))
self.optimizer_D_B = self._optim(self.netD_B.parameters(),
lr=self._learning_rate,
betas=(0.5, 0.999))
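        # Assuming farabio's LambdaLR follows the reference CycleGAN schedule:
        # constant LR until decay_epoch, then linear decay to zero by num_epochs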
self.lr_scheduler_G = self._scheduler(self.optimizer_G, lr_lambda=LambdaLR(
self._num_epochs, self._start_epoch, self._decay_epoch).step)
self.lr_scheduler_D_A = self._scheduler(self.optimizer_D_A, lr_lambda=LambdaLR(
self._num_epochs, self._start_epoch, self._decay_epoch).step)
self.lr_scheduler_D_B = self._scheduler(self.optimizer_D_B, lr_lambda=LambdaLR(
self._num_epochs, self._start_epoch, self._decay_epoch).step)
    def on_train_start(self):
Tensor = torch.cuda.FloatTensor if self._cuda else torch.Tensor
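        # Pre-allocate fixed-size input tensors; each batch is copied into
        # them in-place (copy_) at the start of every iteration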
self.input_A = Tensor(
self._batch_size, self._in_ch, self._size, self._size)
self.input_B = Tensor(
self._batch_size, self._out_ch, self._size, self._size)
self.target_real = Variable(
Tensor(self._batch_size).fill_(1.0), requires_grad=False)
self.target_fake = Variable(
Tensor(self._batch_size).fill_(0.0), requires_grad=False)
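        # Image pools of past generator outputs (assuming farabio's ReplayBuffer
        # implements the standard CycleGAN image-buffer trick, which feeds the
        # discriminators a mix of current and historical fakes)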
self.fake_A_buffer = ReplayBuffer()
self.fake_B_buffer = ReplayBuffer()
    def start_logger(self):
self.logger = CycleganViz(self._num_epochs, len(self.train_loader))
    def on_train_epoch_start(self):
self.train_epoch_iter = enumerate(self.train_loader)
    def train_batch(self, args):
self.on_start_training_batch(args)
###### Generators A2B and B2A ######
self.generator_zero_grad()
self.generator_loss()
self.generator_backward()
self.generator_optim_step()
###### Discriminator A ######
self.discriminatorA_zero_grad()
self.discriminatorA_loss()
self.discriminatorA_backward()
self.discriminatorA_optim_step()
###### Discriminator B ######
self.discriminatorB_zero_grad()
self.discriminatorB_loss()
self.discriminatorB_backward()
self.discriminatorB_optim_step()
self.on_end_training_batch()
    def on_start_training_batch(self, args):
# Set model input
self.i = args[0]
self.batch = args[-1]
self.real_A = Variable(self.input_A.copy_(self.batch['A']))
self.real_B = Variable(self.input_B.copy_(self.batch['B']))
    def generator_zero_grad(self):
self.optimizer_G.zero_grad()
    def generator_loss(self):
"""Total generator loss
"""
self.loss_identity_A, self.loss_identity_B = self.identity_g_loss()
self.loss_GAN_A2B, self.loss_GAN_B2A = self.gan_g_loss()
self.loss_cycle_ABA, self.loss_cycle_BAB = self.cycle_g_loss()
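        # L_G = L_id(A) + L_id(B) + L_GAN(A2B) + L_GAN(B2A) + L_cyc(ABA) + L_cyc(BAB)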
self.loss_G = self.loss_identity_A + self.loss_identity_B + \
self.loss_GAN_A2B + self.loss_GAN_B2A + \
self.loss_cycle_ABA + self.loss_cycle_BAB
    def generator_backward(self):
self.loss_G.backward()
    def generator_optim_step(self):
self.optimizer_G.step()
    def identity_g_loss(self):
        """Identity loss.

        Returns
        -------
        torch.Tensor, torch.Tensor
            L1 identity losses for domains A and B.
        """
# G_A2B(B) should equal B if real B is fed
same_B = self.netG_A2B(self.real_B)
loss_identity_B = self.criterion_identity(same_B, self.real_B)*5.0
# G_B2A(A) should equal A if real A is fed
same_A = self.netG_B2A(self.real_A)
loss_identity_A = self.criterion_identity(same_A, self.real_A)*5.0
return loss_identity_A, loss_identity_B
    def gan_g_loss(self):
        """GAN loss.

        Returns
        -------
        torch.Tensor, torch.Tensor
            Adversarial (MSE) losses for A2B and B2A.
        """
self.fake_B = self.netG_A2B(self.real_A)
pred_fake = self.netD_B(self.fake_B)
loss_GAN_A2B = self.criterion_GAN(pred_fake, self.target_real)
self.fake_A = self.netG_B2A(self.real_B)
pred_fake = self.netD_A(self.fake_A)
loss_GAN_B2A = self.criterion_GAN(pred_fake, self.target_real)
return loss_GAN_A2B, loss_GAN_B2A
    def cycle_g_loss(self):
        """Cycle-consistency loss.

        Returns
        -------
        torch.Tensor, torch.Tensor
            L1 cycle losses for A->B->A and B->A->B.
        """
recovered_A = self.netG_B2A(self.fake_B)
loss_cycle_ABA = self.criterion_cycle(recovered_A, self.real_A)*10.0
recovered_B = self.netG_A2B(self.fake_A)
loss_cycle_BAB = self.criterion_cycle(recovered_B, self.real_B)*10.0
return loss_cycle_ABA, loss_cycle_BAB
    def discriminatorA_zero_grad(self):
self.optimizer_D_A.zero_grad()
    def discriminatorA_loss(self):
"""Loss for discriminator A: fake and real.
"""
# Real loss
loss_D_real = self.real_dA_loss()
# Fake loss
loss_D_fake = self.fake_dA_loss()
# Total loss
self.loss_D_A = (loss_D_real + loss_D_fake)*0.5
    def real_dA_loss(self):
        """Loss for discriminator A: real.

        Returns
        -------
        torch.Tensor
            MSE loss on real samples.
        """
pred_real = self.netD_A(self.real_A)
loss_D_real = self.criterion_GAN(pred_real, self.target_real)
return loss_D_real
    def fake_dA_loss(self):
        """Loss for discriminator A: fake.

        Returns
        -------
        torch.Tensor
            MSE loss on buffered fake samples.
        """
self.fake_A = self.fake_A_buffer.push_and_pop(self.fake_A)
pred_fake = self.netD_A(self.fake_A.detach())
loss_D_fake = self.criterion_GAN(pred_fake, self.target_fake)
return loss_D_fake
    def discriminatorA_backward(self):
self.loss_D_A.backward()
    def discriminatorA_optim_step(self):
self.optimizer_D_A.step()
    def discriminatorB_zero_grad(self):
self.optimizer_D_B.zero_grad()
    def discriminatorB_loss(self):
"""Loss for discriminator B: fake and real.
"""
# Real loss
loss_D_real = self.real_dB_loss()
# Fake loss
loss_D_fake = self.fake_dB_loss()
# Total loss
self.loss_D_B = (loss_D_real + loss_D_fake)*0.5
    def real_dB_loss(self):
        """Loss for discriminator B: real.

        Returns
        -------
        torch.Tensor
            MSE loss on real samples.
        """
pred_real = self.netD_B(self.real_B)
loss_D_real = self.criterion_GAN(pred_real, self.target_real)
return loss_D_real
    def fake_dB_loss(self):
        """Loss for discriminator B: fake.

        Returns
        -------
        torch.Tensor
            MSE loss on buffered fake samples.
        """
self.fake_B = self.fake_B_buffer.push_and_pop(self.fake_B)
pred_fake = self.netD_B(self.fake_B.detach())
loss_D_fake = self.criterion_GAN(pred_fake, self.target_fake)
return loss_D_fake
    def discriminatorB_backward(self):
self.loss_D_B.backward()
    def discriminatorB_optim_step(self):
self.optimizer_D_B.step()
    def on_end_training_batch(self):
self.logger.log({'loss_G': self.loss_G,
'loss_G_identity': (self.loss_identity_A + self.loss_identity_B),
'loss_G_GAN': (self.loss_GAN_A2B + self.loss_GAN_B2A),
'loss_G_cycle': (self.loss_cycle_ABA + self.loss_cycle_BAB),
'loss_D': (self.loss_D_A + self.loss_D_B)},
images={'real_A': self.real_A,
'real_B': self.real_B,
'fake_A': self.fake_A,
'fake_B': self.fake_B})
    def on_train_epoch_end(self):
# Update learning rates
self.lr_scheduler_G.step()
self.lr_scheduler_D_A.step()
self.lr_scheduler_D_B.step()
    def on_epoch_end(self):
if self.epoch % self._save_epoch == 0:
self.save_model()
    def get_trainloader(self):
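        # Resize ~12% larger, then random-crop and flip: the usual CycleGAN
        # train-time augmentation; Normalize maps pixel values to [-1, 1]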
transforms_ = [transforms.Resize(int(self._size*1.12), Image.BICUBIC),
transforms.RandomCrop(self._size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
self.train_loader = DataLoader(ImageDataset(self._dataroot, transforms_=transforms_, unaligned=True, mode='train'),
batch_size=self._batch_size, shuffle=True, num_workers=self._num_workers)
    def get_testloader(self):
transforms_ = [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
self.test_loader = DataLoader(ImageDataset(self._dataroot, transforms_=transforms_, unaligned=False, mode='test'),
batch_size=self._batch_size, shuffle=False, num_workers=self._num_workers)
    def save_model(self):
        """Save generator and discriminator checkpoints to
        self._model_save_dir.
        """
netG_A2B_name = os.path.join(self._model_save_dir, "netG_A2B.pth")
netG_B2A_name = os.path.join(self._model_save_dir, "netG_B2A.pth")
netD_A_name = os.path.join(self._model_save_dir, "netD_A.pth")
netD_B_name = os.path.join(self._model_save_dir, "netD_B.pth")
# Save models checkpoints
torch.save(self.netG_A2B.state_dict(), netG_A2B_name)
torch.save(self.netG_B2A.state_dict(), netG_B2A_name)
torch.save(self.netD_A.state_dict(), netD_A_name)
torch.save(self.netD_B.state_dict(), netD_B_name)
    def load_model(self):
self.netG_A2B = Generator(self._in_ch, self._out_ch)
self.netG_B2A = Generator(self._out_ch, self._in_ch)
if self._cuda:
self.netG_A2B.cuda()
self.netG_B2A.cuda()
# Load state dicts
self.netG_A2B.load_state_dict(torch.load(self._genA2B_path))
self.netG_B2A.load_state_dict(torch.load(self._genB2A_path))
# Set model's test mode
self.netG_A2B.eval()
self.netG_B2A.eval()
        # Create output dirs if they don't exist
if not os.path.exists(os.path.join(self._output_dir, "A")):
os.makedirs(os.path.join(self._output_dir, "A"))
if not os.path.exists(os.path.join(self._output_dir, "B")):
os.makedirs(os.path.join(self._output_dir, "B"))
    def on_test_start(self):
# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if self._cuda else torch.Tensor
self.input_A = Tensor(
self._batch_size, self._in_ch, self._size, self._size)
self.input_B = Tensor(
self._batch_size, self._out_ch, self._size, self._size)
self.test_loop_iter = enumerate(self.test_loader)
    def test_step(self, args):
self.i = args[0]
batch = args[-1]
# Set model input
real_A = Variable(self.input_A.copy_(batch['A']))
real_B = Variable(self.input_B.copy_(batch['B']))
        # Generate outputs; rescale from the generator's [-1, 1] range to [0, 1]
self.fake_B = 0.5*(self.netG_A2B(real_A).data + 1.0)
self.fake_A = 0.5*(self.netG_B2A(real_B).data + 1.0)
    def on_end_test_batch(self):
idx = str(self.i+1).zfill(4)
# Save image files
save_image(self.fake_A, os.path.join(
self._output_dir, "A", idx+".png"))
save_image(self.fake_B, os.path.join(
self._output_dir, "B", idx+".png"))
sys.stdout.write('\rGenerated images %04d of %04d' %
(self.i+1, len(self.test_loader)))