cyclegan_trainer Module

../_images/cyclegan.png

The CycleganTrainer class implements Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (CycleGAN), originally proposed in the arXiv paper [1]. The trainer module uses the PyTorch-CycleGAN repository [2] as its reference implementation.

References

[1] https://arxiv.org/abs/1703.10593

[2] https://github.com/aitorzip/PyTorch-CycleGAN

CycleganTrainer class

class farabio.models.translation.cyclegan.cyclegan_trainer.CycleganTrainer(config)[source]

CycleganTrainer trainer class. Override with custom methods here.

Parameters
GanTrainer : farabio.core.basetrainer.BaseTrainer

Inherits GanTrainer class

define_data_attr()[source]

Define data related attributes here

define_model_attr()[source]

Define model related attributes here

define_train_attr()[source]

Define training related attributes here

define_test_attr()[source]

Define test related attributes here

define_log_attr()[source]

Define log related attributes here

define_compute_attr()[source]

Define compute related attributes here

define_misc_attr()[source]

Define miscellaneous attributes here

build_model()[source]

Abstract method that builds model

on_train_start()[source]

Hook: On start of training loop

start_logger()[source]

Hook: Starts logger

on_train_epoch_start()[source]

Hook: On epoch start

train_batch(args)[source]

Hook: batch of training loop
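The method order documented on this page suggests how a single training batch might proceed: a generator update first, then discriminator A, then discriminator B. The sketch below is an assumption drawn from that ordering, not the farabio source. `HookOrderSketch` is a hypothetical stand-in that only records which hook names would run.

```python
class HookOrderSketch:
    """Hypothetical stand-in that records hook names in call order."""

    # Names are copied from this page; the call order is an assumption.
    BATCH_HOOKS = [
        "on_start_training_batch",
        "generator_zero_grad", "generator_loss",
        "generator_backward", "generator_optim_step",
        "discriminatorA_zero_grad", "discriminatorA_loss",
        "discriminatorA_backward", "discriminatorA_optim_step",
        "discriminatorB_zero_grad", "discriminatorB_loss",
        "discriminatorB_backward", "discriminatorB_optim_step",
        "on_end_training_batch",
    ]

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Any documented hook name resolves to a function that records it.
        if name in self.BATCH_HOOKS:
            return lambda *args: self.calls.append(name)
        raise AttributeError(name)

    def train_batch(self, args):
        # Run every hook once, in the assumed order.
        for hook in self.BATCH_HOOKS:
            getattr(self, hook)(args)


sketch = HookOrderSketch()
sketch.train_batch(args=None)
print(sketch.calls[:5])
```

A real subclass would replace individual hooks with its own logic rather than recording names.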

on_start_training_batch(args)[source]

Hook: On training batch start

generator_zero_grad()[source]

Hook: Zero gradients of generator

generator_loss()[source]

Total generator loss
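For orientation, the generator objective in the CycleGAN paper (reference 1) combines adversarial, cycle-consistency, and optionally identity terms. The weights \lambda_{\text{cyc}} and \lambda_{\text{idt}} below are symbols assumed for illustration; this page does not state the values the trainer uses.

```latex
\mathcal{L}_G
  = \mathcal{L}_{\text{GAN}}(G_{A \to B}, D_B)
  + \mathcal{L}_{\text{GAN}}(G_{B \to A}, D_A)
  + \lambda_{\text{cyc}} \, \mathcal{L}_{\text{cyc}}(G_{A \to B}, G_{B \to A})
  + \lambda_{\text{idt}} \, \mathcal{L}_{\text{idt}}(G_{A \to B}, G_{B \to A})
```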

generator_backward()[source]

Hook: Backpropagates the generator loss

generator_optim_step()[source]

Generator optimizer step

identity_g_loss()[source]

Identity loss

Returns
scalar, scalar

torch.nn.L1Loss, torch.nn.L1Loss
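As a minimal sketch of how two identity losses could be computed with torch.nn.L1Loss: each generator is fed an image that already belongs to its output domain and is penalized for changing it. The generators here are small stand-in convolutions and all variable names are assumptions, not the trainer's attributes. Any loss weighting is omitted.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in generators (assumptions); the real ones are deeper networks.
gen_a2b = nn.Conv2d(3, 3, kernel_size=3, padding=1)
gen_b2a = nn.Conv2d(3, 3, kernel_size=3, padding=1)

criterion_identity = nn.L1Loss()

real_a = torch.rand(2, 3, 8, 8)
real_b = torch.rand(2, 3, 8, 8)

# G_A2B given a real B image should ideally return it unchanged.
loss_identity_b = criterion_identity(gen_a2b(real_b), real_b)
# G_B2A given a real A image should ideally return it unchanged.
loss_identity_a = criterion_identity(gen_b2a(real_a), real_a)
```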

gan_g_loss()[source]

GAN loss

Returns
scalar, scalar

torch.nn.MSELoss, torch.nn.MSELoss
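The two adversarial terms can be illustrated with torch.nn.MSELoss applied to discriminator scores for generated images, with a target of 1 ("real"). The scores below are made-up numbers and the variable names are assumptions, so this shows only the arithmetic.

```python
import torch
from torch import nn

criterion_gan = nn.MSELoss()

# Hypothetical discriminator scores for generated images, one per sample.
pred_fake_b = torch.tensor([0.2, 0.6])  # stands for D_B(G_A2B(real_A))
pred_fake_a = torch.tensor([1.0, 1.0])  # stands for D_A(G_B2A(real_B))

# The generator wants its outputs to be scored as real (target 1).
target_real = torch.ones(2)

loss_gan_a2b = criterion_gan(pred_fake_b, target_real)
loss_gan_b2a = criterion_gan(pred_fake_a, target_real)
```

The second loss is zero because the hypothetical discriminator is already fully fooled.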

cycle_g_loss()[source]

Cycle loss

Returns
scalar, scalar

torch.nn.L1Loss, torch.nn.L1Loss
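A minimal sketch of the two cycle-consistency terms, assuming the usual CycleGAN formulation: an image translated to the other domain and back should match the original under an L1 penalty. The stand-in generators and variable names are assumptions, and any loss weighting is omitted.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in generators (assumptions), not the trainer's networks.
gen_a2b = nn.Conv2d(3, 3, kernel_size=3, padding=1)
gen_b2a = nn.Conv2d(3, 3, kernel_size=3, padding=1)

criterion_cycle = nn.L1Loss()

real_a = torch.rand(2, 3, 8, 8)
real_b = torch.rand(2, 3, 8, 8)

# A -> B -> A should reconstruct the original A image.
recovered_a = gen_b2a(gen_a2b(real_a))
loss_cycle_aba = criterion_cycle(recovered_a, real_a)

# B -> A -> B should reconstruct the original B image.
recovered_b = gen_a2b(gen_b2a(real_b))
loss_cycle_bab = criterion_cycle(recovered_b, real_b)
```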

discriminatorA_zero_grad()[source]
discriminatorA_loss()[source]

Loss for discriminator A: fake and real.

real_dA_loss()[source]

Loss for discriminator A: real

Returns
scalar

torch.nn.MSELoss

fake_dA_loss()[source]

Loss for discriminator A: fake

Returns
scalar

torch.nn.MSELoss
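The relationship between discriminatorA_loss, real_dA_loss, and fake_dA_loss can be sketched as follows. The scores are made-up numbers, the variable names are assumptions, and averaging the two halves with a factor of 0.5 is also an assumption, since this page does not say how the real and fake terms are combined.

```python
import torch
from torch import nn

criterion_gan = nn.MSELoss()

# Hypothetical discriminator A scores, one per sample.
pred_real = torch.tensor([0.9, 0.7])  # stands for D_A(real_A)
pred_fake = torch.tensor([0.1, 0.3])  # stands for D_A(fake_A.detach())

# Real images should score 1, generated images should score 0.
loss_d_real = criterion_gan(pred_real, torch.ones(2))
loss_d_fake = criterion_gan(pred_fake, torch.zeros(2))

# Assumed combination: average of the real and fake halves.
loss_d_a = (loss_d_real + loss_d_fake) * 0.5
```

In an actual training step the generated image would typically be detached from the generator graph so that only the discriminator receives gradients from this loss.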

discriminatorA_backward()[source]
discriminatorA_optim_step()[source]
discriminatorB_zero_grad()[source]
discriminatorB_loss()[source]

Loss for discriminator B: fake and real.

real_dB_loss()[source]

Loss for discriminator B: real

Returns
scalar

torch.nn.MSELoss

fake_dB_loss()[source]

Loss for discriminator B: fake

Returns
scalar

torch.nn.MSELoss

discriminatorB_backward()[source]
discriminatorB_optim_step()[source]
on_end_training_batch()[source]

Hook: On end of training batch

on_train_epoch_end()[source]

Hook: On end of training epoch

on_epoch_end()[source]

Hook: on epoch end

get_trainloader()[source]

Hook: Retrieves the training set as a torch.utils.data.DataLoader

get_testloader()[source]

Hook: Retrieves the test set as a torch.utils.data.DataLoader
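To illustrate what a loader hook might return for unpaired data, the sketch below wraps two image tensors of different sizes in a torch.utils.data.DataLoader. `UnpairedTensorDataset` and the "A"/"B" keys are hypothetical; the dataset class farabio actually uses is not described on this page.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class UnpairedTensorDataset(Dataset):
    """Hypothetical dataset yielding one image from each domain."""

    def __init__(self, images_a, images_b):
        self.images_a = images_a
        self.images_b = images_b

    def __len__(self):
        # Cover the larger domain; the smaller one is reused.
        return max(len(self.images_a), len(self.images_b))

    def __getitem__(self, index):
        # Wrap indices so domains of different sizes can be combined.
        return {
            "A": self.images_a[index % len(self.images_a)],
            "B": self.images_b[index % len(self.images_b)],
        }


dataset = UnpairedTensorDataset(torch.rand(5, 3, 8, 8), torch.rand(3, 3, 8, 8))
loader = DataLoader(dataset, batch_size=2, shuffle=False)
batch = next(iter(loader))
```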

save_model()[source]

Save model

Parameters
epoch : int

current epoch
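One common way to save a model per epoch in PyTorch is to write its state_dict to a file whose name includes the epoch number. The helper `save_checkpoint` and the file naming scheme below are assumptions for illustration, not the trainer's actual behavior.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(model: nn.Module, save_dir: str, epoch: int) -> str:
    """Hypothetical helper: write the model weights for a given epoch."""
    path = os.path.join(save_dir, f"model_epoch_{epoch}.pth")
    torch.save(model.state_dict(), path)
    return path


net = nn.Linear(4, 2)  # stand-in for a generator or discriminator
with tempfile.TemporaryDirectory() as tmp:
    ckpt = save_checkpoint(net, tmp, epoch=3)
    # Load the weights back into a fresh module to confirm the round trip.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(ckpt))
    same = torch.equal(net.weight, restored.weight)
```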

load_model()[source]

Hook: load model

on_test_start()[source]

Hook: on test start

test_step(args)[source]

Test action (Put test here)

on_end_test_batch()[source]

Hook: On end of test batch