class_trainer Module

The ClassTrainer class offers several classification models.

Available architectures

CIFAR10-classification

Model                     Accuracy for CIFAR10
VGG16 [1]                 92.64%
ResNet18 [2]              93.02%
ResNet50 [2]              93.62%
ResNet101 [2]             93.75%
RegNetX_200MF [3]         94.24%
RegNetY_400MF [3]         94.29%
MobileNetV2 [4]           94.43%
ResNeXt29(32x4d) [5]      94.73%
ResNeXt29(2x64d) [5]      94.82%
SimpleDLA [9]             94.89%
DenseNet121 [6]           95.04%
PreActResNet18 [7]        95.11%
DPN92 [8]                 95.16%
DLA [10]                  95.47%

Architecture classes:

arch.vgg.VGG
arch.resnet.ResNet
arch.resnet.RegNet
arch.mobilenetv2.MobileNetV2
arch.resnext.ResNeXt
arch.densenet.DenseNet
arch.preact_resnet.PreActResNet
arch.dpn.DPN
arch.dla_simple.SimpleDLA
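To compare the reported numbers programmatically, the table can be held in a plain dictionary. This is a minimal sketch; the names CIFAR10_ACCURACY, top_k and gap are illustrative only and are not part of farabio. The values are copied from the table above.

```python
# Reported CIFAR10 accuracies (%) copied from the table above.
# CIFAR10_ACCURACY, top_k and gap are illustrative names, not farabio API.
CIFAR10_ACCURACY = {
    "VGG16": 92.64,
    "ResNet18": 93.02,
    "ResNet50": 93.62,
    "ResNet101": 93.75,
    "RegNetX_200MF": 94.24,
    "RegNetY_400MF": 94.29,
    "MobileNetV2": 94.43,
    "ResNeXt29(32x4d)": 94.73,
    "ResNeXt29(2x64d)": 94.82,
    "SimpleDLA": 94.89,
    "DenseNet121": 95.04,
    "PreActResNet18": 95.11,
    "DPN92": 95.16,
    "DLA": 95.47,
}


def top_k(accuracies, k=3):
    """Return the k (model, accuracy) pairs with the highest accuracy."""
    return sorted(accuracies.items(), key=lambda item: item[1], reverse=True)[:k]


def gap(accuracies, better, worse):
    """Difference in percentage points between two models, rounded to 2 decimals."""
    return round(accuracies[better] - accuracies[worse], 2)


print(top_k(CIFAR10_ACCURACY))
print(gap(CIFAR10_ACCURACY, "DLA", "VGG16"))
```

With these numbers, DLA, DPN92 and PreActResNet18 rank highest, and DLA is 2.83 percentage points above VGG16.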

References

[1] https://arxiv.org/abs/1409.1556
[2] https://arxiv.org/abs/1512.03385
[3] https://arxiv.org/abs/2003.13678
[4] https://arxiv.org/abs/1801.04381
[5] https://arxiv.org/abs/1611.05431
[6] https://arxiv.org/abs/1608.06993
[7] https://arxiv.org/abs/1603.05027
[8] https://arxiv.org/abs/1707.01629
[9] https://arxiv.org/abs/1707.064
[10] https://arxiv.org/pdf/1707.06484.pdf

ClassTrainer class

class farabio.models.classification.class_trainer.ClassTrainer(config)

Classification trainer class. Override its methods to customize training.

Parameters
ConvnetTrainer : BaseTrainer

Inherits the ConvnetTrainer class.

define_data_attr(*args)

Define data-related attributes here.

define_model_attr(*args)

Define model-related attributes here.

define_train_attr()

Define training-related attributes here.

define_compute_attr(*args)

Define compute-related attributes here.

define_log_attr()

Define log-related attributes here.
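The define_*_attr methods are where a subclass sets up its state. The sketch below shows that override pattern under stated assumptions: StandInTrainer is a hypothetical stand-in for the real base class, it is assumed to call each define_*_attr method once from its constructor, and the config fields (batch_size, arch, num_epochs, lr, device, log_dir) are invented for illustration rather than taken from farabio's config schema.

```python
from types import SimpleNamespace


class StandInTrainer:
    """Hypothetical stand-in for the farabio base trainer (not the real API).

    Assumption: the constructor calls each define_*_attr hook once, in order.
    """

    def __init__(self, config):
        self.config = config
        self.define_data_attr()
        self.define_model_attr()
        self.define_train_attr()
        self.define_compute_attr()
        self.define_log_attr()

    # Default hooks do nothing; subclasses override the ones they need.
    def define_data_attr(self, *args):
        pass

    def define_model_attr(self, *args):
        pass

    def define_train_attr(self):
        pass

    def define_compute_attr(self, *args):
        pass

    def define_log_attr(self):
        pass


class MyClassTrainer(StandInTrainer):
    """Copies hypothetical config fields onto the trainer instance."""

    def define_data_attr(self, *args):
        self._batch_size = self.config.batch_size

    def define_model_attr(self, *args):
        self._arch = self.config.arch

    def define_train_attr(self):
        self._num_epochs = self.config.num_epochs
        self._lr = self.config.lr

    def define_compute_attr(self, *args):
        self._device = self.config.device

    def define_log_attr(self):
        self._log_dir = self.config.log_dir


config = SimpleNamespace(batch_size=128, arch="ResNet18", num_epochs=2,
                         lr=0.1, device="cpu", log_dir="logs")
trainer = MyClassTrainer(config)
print(trainer._arch, trainer._batch_size)
```

Grouping setup by concern (data, model, training, compute, logging) keeps each override small, so a subclass only needs to replace the group it changes.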

get_trainloader()

Hook: Retrieves the training set as a torch.utils.data.DataLoader.

get_testloader()

Hook: Retrieves the test set as a torch.utils.data.DataLoader.

build_model()

Abstract method that builds the model.

build_parallel_model()

Abstract method that builds the model for parallel multi-GPU training.

on_train_epoch_start()

Hook: On training epoch start.

on_start_training_batch(args)

Hook: On training batch start.

training_step()

Hook: During a training batch.

on_end_training_batch()

Hook: On training batch end.

optimizer_zero_grad()

Hook: Zeroes the optimizer's gradients.

optimizer_step()

Hook: Performs an optimizer step.

loss_backward()

Hook: Back-propagates the loss.
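optimizer_zero_grad, loss_backward and optimizer_step map onto the usual PyTorch update cycle: clear the gradients, back-propagate, then apply the update. PyTorch adds new gradients to existing ones on each backward pass, which is why zeroing comes first. The sketch below reproduces that cycle in plain Python on a single scalar weight so it runs without PyTorch. ScalarParam, ToySGD and compute_loss_and_backward are hypothetical stand-ins (not torch.optim classes), and the call order inside the loop is the conventional one, assumed rather than taken from farabio's source.

```python
class ScalarParam:
    """One trainable scalar whose gradient accumulates across backward
    passes, mimicking PyTorch's behavior."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToySGD:
    """Minimal plain SGD over ScalarParam objects (not torch.optim.SGD)."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


def compute_loss_and_backward(w, x, y):
    """Squared error (w*x - y)**2; adds d(loss)/dw to w.grad and returns the loss."""
    err = w.value * x - y
    w.grad += 2.0 * err * x
    return err * err


w = ScalarParam(0.0)
opt = ToySGD([w], lr=0.1)
losses = []
for _ in range(50):
    opt.zero_grad()                                          # optimizer_zero_grad
    losses.append(compute_loss_and_backward(w, 1.0, 3.0))    # forward + loss_backward
    opt.step()                                               # optimizer_step
print(w.value)  # converges toward 3.0
```

Skipping zero_grad would make each step use the sum of all past gradients instead of the current one.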

on_evaluate_epoch_start()

Hook: On evaluation epoch start.

on_evaluate_batch_start(args)

Hook: On evaluation batch start.

evaluate_batch(args)

Hook: Processes one batch of the evaluation loop.

on_evaluate_batch_end()

Hook: On evaluation batch end.

on_evaluate_epoch_end()

Hook: On evaluation epoch end.
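One way the evaluation hooks could accumulate a classification accuracy is sketched below, assuming top-1 accuracy computed from plain lists of class scores. AccuracyEvaluator is hypothetical; its method names mirror the hooks above, but the bodies are illustrative and are not farabio's implementation.

```python
class AccuracyEvaluator:
    """Hypothetical top-1 accuracy bookkeeping spread across evaluation hooks."""

    def on_evaluate_epoch_start(self):
        # Reset counters at the start of every evaluation pass.
        self.correct = 0
        self.total = 0

    def evaluate_batch(self, logits, labels):
        # logits: one list of class scores per sample; labels: true class ids.
        for scores, label in zip(logits, labels):
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self.correct += int(predicted == label)
            self.total += 1

    def on_evaluate_epoch_end(self):
        # Accuracy in percent over every sample seen in this pass.
        return 100.0 * self.correct / self.total


ev = AccuracyEvaluator()
ev.on_evaluate_epoch_start()
ev.evaluate_batch([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 2])
ev.evaluate_batch([[0.0, 0.1, 0.9]], [2])
print(ev.on_evaluate_epoch_end())  # 2 of 3 correct, about 66.67
```

Counting correct predictions per batch and dividing only at epoch end keeps the result exact when the last batch is smaller than the others.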
on_epoch_end()

Hook: On epoch end.
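The listing above names the hooks but not the order in which the trainer calls them. The sketch below shows one plausible ordering for a train-then-evaluate epoch. The ordering, the SketchTrainer class and its train driver are assumptions for illustration, not farabio's actual loop; each hook only records its name.

```python
class SketchTrainer:
    """Hypothetical driver calling the documented hooks in an assumed order."""

    def __init__(self, train_batches, eval_batches, num_epochs):
        self.train_batches = train_batches
        self.eval_batches = eval_batches
        self.num_epochs = num_epochs
        self.calls = []  # hook names in the order they were called

    def _log(self, name):
        self.calls.append(name)

    # Training hooks: a real trainer would do work here.
    def on_train_epoch_start(self):
        self._log("on_train_epoch_start")

    def on_start_training_batch(self, args):
        self._log("on_start_training_batch")

    def optimizer_zero_grad(self):
        self._log("optimizer_zero_grad")

    def training_step(self):
        self._log("training_step")

    def loss_backward(self):
        self._log("loss_backward")

    def optimizer_step(self):
        self._log("optimizer_step")

    def on_end_training_batch(self):
        self._log("on_end_training_batch")

    # Evaluation hooks.
    def on_evaluate_epoch_start(self):
        self._log("on_evaluate_epoch_start")

    def on_evaluate_batch_start(self, args):
        self._log("on_evaluate_batch_start")

    def evaluate_batch(self, args):
        self._log("evaluate_batch")

    def on_evaluate_batch_end(self):
        self._log("on_evaluate_batch_end")

    def on_evaluate_epoch_end(self):
        self._log("on_evaluate_epoch_end")

    def on_epoch_end(self):
        self._log("on_epoch_end")

    def train(self):
        for _ in range(self.num_epochs):
            self.on_train_epoch_start()
            for batch in self.train_batches:
                self.on_start_training_batch(batch)
                self.optimizer_zero_grad()
                self.training_step()
                self.loss_backward()
                self.optimizer_step()
                self.on_end_training_batch()
            self.on_evaluate_epoch_start()
            for batch in self.eval_batches:
                self.on_evaluate_batch_start(batch)
                self.evaluate_batch(batch)
                self.on_evaluate_batch_end()
            self.on_evaluate_epoch_end()
            self.on_epoch_end()


t = SketchTrainer(train_batches=[0, 1], eval_batches=[0], num_epochs=1)
t.train()
print(len(t.calls), t.calls[:3])
```

With two training batches and one evaluation batch, one epoch records 19 hook calls: one epoch-start call, six per training batch, five evaluation calls (three of them per batch), and the final on_epoch_end.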