faster_rcnn_trainer Module

[Figure: Faster R-CNN architecture]

The FasterRCNNTrainer class is built around the Faster R-CNN model architecture, originally proposed in [1]. The trainer module is implemented using the code in [2] as a reference.

References

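[1] S. Ren, K. He, R. Girshick, J. Sun. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks. https://arxiv.org/abs/1506.01497

[2] simple-faster-rcnn-pytorch. https://github.com/chenyuntc/simple-faster-rcnn-pytorch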

FasterRCNNTrainer class

class farabio.models.detection.faster_rcnn.faster_rcnn_trainer.FasterRCNNTrainer(config)

Trainer class for Faster R-CNN. Override its methods to customize training.

The losses include:

  • rpn_loc_loss: The localization loss for Region Proposal Network (RPN).

  • rpn_cls_loss: The classification loss for RPN.

  • roi_loc_loss: The localization loss for the head module.

  • roi_cls_loss: The classification loss for the head module.

  • total_loss: The sum of the four losses above.
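As a minimal sketch of how the five logged values relate, the snippet below packs the four component losses into a named tuple and appends their sum as `total_loss`. The `LossTuple` and `make_loss_tuple` names are assumptions modeled loosely on the reference code [2], not confirmed farabio API; plain floats stand in for the scalar loss tensors, and `sum()` behaves the same way on 0-dim PyTorch tensors.

```python
from collections import namedtuple

# Hypothetical container for the five logged losses (field names follow the list above).
LossTuple = namedtuple(
    "LossTuple",
    ["rpn_loc_loss", "rpn_cls_loss", "roi_loc_loss", "roi_cls_loss", "total_loss"],
)


def make_loss_tuple(rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss):
    """Bundle the four component losses and append their sum as total_loss."""
    losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
    # sum() only relies on "+", so it also works for scalar torch tensors.
    return LossTuple(*losses, sum(losses))


if __name__ == "__main__":
    losses = make_loss_tuple(0.25, 0.5, 0.125, 1.0)
    print(losses.total_loss)  # 1.875
```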

Args:
faster_rcnn (model.FasterRCNN):

A Faster R-CNN model to be trained.

define_train_attr()

Defines training-related attributes.

define_model_attr()

Defines model-related attributes.

define_log_attr()

Defines logging-related attributes.

define_misc_attr()

Defines miscellaneous attributes.

get_trainloader()

Hook: Retrieves the training set as a torch.utils.data.DataLoader.

get_testloader()

Hook: Retrieves the test set as a torch.utils.data.DataLoader.
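One plausible shape for what `get_trainloader()` and `get_testloader()` return, assuming a PyTorch `Dataset` already exists for each split. `TensorDataset` filled with random tensors stands in for a real detection dataset, and `build_loaders` is a hypothetical helper. A batch size of 1 is assumed, which avoids padding when images differ in size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loaders(train_set, test_set, num_workers=0):
    """Hypothetical helper mirroring get_trainloader() / get_testloader()."""
    # Shuffle only the training data; keep evaluation order deterministic.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False,
                             num_workers=num_workers)
    return train_loader, test_loader


if __name__ == "__main__":
    # Random tensors stand in for images and per-image labels.
    images = torch.randn(4, 3, 32, 32)
    labels = torch.zeros(4, dtype=torch.long)
    train_loader, test_loader = build_loaders(
        TensorDataset(images, labels), TensorDataset(images[:2], labels[:2])
    )
    print(len(train_loader), len(test_loader))  # 4 2
```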

build_model()

Abstract method that builds the model.

load_model()

Hook: Loads the model.

save(**kwargs)
save_model(save_dict, save_path)

Hook: Saves the model.
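A minimal sketch of how `save(**kwargs)` and `save_model(save_dict, save_path)` might divide the work: `save` gathers state into a dictionary and `save_model` writes it with `torch.save`. The function bodies, the extra `model`/`optimizer` parameters on `save`, and the dictionary keys are assumptions for illustration; the farabio implementation may store different fields.

```python
import os
import tempfile

import torch
from torch import nn


def save_model(save_dict, save_path):
    """Write a checkpoint dictionary to disk, creating the folder if needed."""
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    torch.save(save_dict, save_path)


def save(model, optimizer, save_path, **kwargs):
    """Gather state dicts (plus any extra metadata) and delegate to save_model()."""
    save_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        **kwargs,  # e.g. epoch number or best metric (hypothetical fields)
    }
    save_model(save_dict, save_path)
    return save_path


if __name__ == "__main__":
    net = nn.Linear(2, 1)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    with tempfile.TemporaryDirectory() as tmp:
        path = save(net, opt, os.path.join(tmp, "ckpt", "fasterrcnn.pth"), epoch=3)
        ckpt = torch.load(path)
        print(sorted(ckpt))  # ['epoch', 'model', 'optimizer']
```

Restoring is the mirror image: `torch.load(path)` followed by `model.load_state_dict(ckpt["model"])`.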

on_train_start()

Hook: Called at the start of the training loop.

start_logger()

Hook: Starts the logger.

on_train_epoch_start()

Hook: Called at the start of each training epoch.

on_start_training_batch(args)

Hook: Called at the start of each training batch.

training_step()

Hook: Called during each training batch.
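The hooks above follow a template-method pattern: a driver loop calls them in a fixed order, and a subclass overrides only the ones it needs. The sketch below uses a hypothetical stand-in base class, `MiniTrainer`, whose `train()` loop calls a subset of the documented hook names in an assumed order; it is not farabio's actual loop, and the real call order may differ.

```python
class MiniTrainer:
    """Hypothetical stand-in base class; the hook order in train() is assumed."""

    def __init__(self, loader, num_epochs=1):
        self.loader = loader
        self.num_epochs = num_epochs
        self.calls = []  # records which hooks ran, for illustration

    # Default hooks do little more than record that they were called.
    def on_train_start(self):
        self.calls.append("on_train_start")

    def on_train_epoch_start(self):
        self.calls.append("on_train_epoch_start")

    def on_start_training_batch(self, args):
        self.batch = args  # keep the current batch for training_step()

    def training_step(self):
        self.calls.append("training_step")

    def on_epoch_end(self):
        self.calls.append("on_epoch_end")

    def train(self):
        self.on_train_start()
        for epoch in range(self.num_epochs):
            self.epoch = epoch
            self.on_train_epoch_start()
            for args in self.loader:
                self.on_start_training_batch(args)
                self.training_step()
            self.on_epoch_end()


class SquareTrainer(MiniTrainer):
    """Overrides a single hook, as the class docstring suggests."""

    def training_step(self):
        super().training_step()
        self.last_output = self.batch ** 2  # placeholder for a real forward pass


if __name__ == "__main__":
    trainer = SquareTrainer(loader=[1, 2, 3], num_epochs=2)
    trainer.train()
    print(trainer.calls.count("training_step"), trainer.last_output)  # 6 9
```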

on_evaluate_epoch_start()

Hook: Called at the start of evaluation.

on_evaluate_batch_start(args)
on_evaluate_epoch_end()
visdom_plot()
on_epoch_end()

Hook: Called at the end of each epoch.

evaluate_batch(*args)

Hook: Processes one batch of the evaluation loop.
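A sketch of an evaluation pass in the spirit of `on_evaluate_epoch_start()` and `evaluate_batch()`: the model is switched to eval mode and batches are processed under `torch.no_grad()`, so no autograd graph is built. The `evaluate` helper and its toy accuracy metric are assumptions; a detection trainer would typically compute a detection metric such as mean average precision instead.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader):
    """Hypothetical evaluation loop: eval mode, no gradients, one step per batch."""
    model.eval()  # e.g. disables dropout and freezes batch-norm statistics
    correct, total = 0, 0
    with torch.no_grad():  # no autograd graph is recorded during evaluation
        for images, labels in loader:
            preds = model(images).argmax(dim=1)  # stand-in for evaluate_batch()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    model.train()  # restore training mode for the next epoch
    return correct / total


if __name__ == "__main__":
    torch.manual_seed(0)
    data = TensorDataset(torch.randn(8, 4), torch.randint(0, 3, (8,)))
    acc = evaluate(nn.Linear(4, 3), DataLoader(data, batch_size=4))
    print(0.0 <= acc <= 1.0)  # True
```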

forward()
optimizer_zero_grad()

Hook: Zeroes the optimizer's gradients.

loss_backward()

Hook: Back-propagates the loss.

optimizer_step()

Hook: Performs an optimizer step.
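`optimizer_zero_grad()`, `loss_backward()` and `optimizer_step()` correspond to the three standard calls of a PyTorch parameter update. The sketch below performs such updates on a toy linear model; the model, data, loss and learning rate are placeholders, and in the trainer the loss being back-propagated is presumably `total_loss`.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(5, 3), torch.randn(5, 1)


def train_step():
    """One update split into the three hook-sized pieces."""
    optimizer.zero_grad()  # optimizer_zero_grad(): clear gradients from the last step
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()  # loss_backward(): populate .grad on every parameter
    optimizer.step()  # optimizer_step(): apply the SGD update
    return loss.item()
```

Zeroing first matters because PyTorch accumulates gradients across `backward()` calls; skipping it would mix gradients from consecutive batches.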

update_meters()
reset_meters()
get_meter_data()
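`update_meters()`, `reset_meters()` and `get_meter_data()` carry no description here. A common design, which the reference code [2] implements with torchnet's `AverageValueMeter`, keeps one running-average meter per loss name. The pure-Python `AverageMeter` and `LossMeters` classes below are hypothetical stand-ins for that idea, not farabio's implementation.

```python
class AverageMeter:
    """Running mean of a stream of scalar values."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def add(self, value):
        self.total += float(value)  # float() also accepts 0-dim tensors
        self.count += 1

    @property
    def mean(self):
        # Return NaN before the first update instead of dividing by zero.
        return self.total / self.count if self.count else float("nan")


class LossMeters:
    """One AverageMeter per loss name (hypothetical helper)."""

    NAMES = ("rpn_loc_loss", "rpn_cls_loss", "roi_loc_loss",
             "roi_cls_loss", "total_loss")

    def __init__(self):
        self.meters = {name: AverageMeter() for name in self.NAMES}

    def update_meters(self, losses):
        # losses: mapping from loss name to a scalar value
        for name, value in losses.items():
            self.meters[name].add(value)

    def reset_meters(self):
        for meter in self.meters.values():
            meter.reset()

    def get_meter_data(self):
        return {name: meter.mean for name, meter in self.meters.items()}
```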