import os
import ipdb
import time
from collections import namedtuple
import matplotlib
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data as data_
from farabio.core.convnettrainer import ConvnetTrainer
from farabio.models.detection.faster_rcnn.dataset import Dataset, TestDataset, inverse_normalize
from farabio.models.detection.faster_rcnn.faster_rcnn_vgg16 import FasterRCNNVGG16
from farabio.models.detection.faster_rcnn.creator_tool import AnchorTargetCreator, ProposalTargetCreator
import farabio.utils.helpers as helpers
from farabio.utils.losses import Losses
from farabio.utils.metrics import eval_detection_voc
from farabio.utils.meters import ConfusionMeter, AverageValueMeter
from farabio.utils.visdom import FasterRCNNViz, visdom_bbox
# fix for ulimit
# https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
import resource
#from farabio.models.detection.faster_rcnn.config import opt
# Start train with: python train.py train --env='fasterrcnn' --plot-every=100
# Raise the soft limit on open file descriptors to (up to) 20480; see the
# PyTorch issue linked above. The requested value is clamped to the hard
# limit, because asking for a soft limit above the hard limit makes
# setrlimit() raise ValueError. A soft limit that is already at least as
# high (or unlimited) is left untouched rather than lowered.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
_nofile_soft, _nofile_hard = rlimit
if _nofile_soft != resource.RLIM_INFINITY and _nofile_soft < 20480:
    if _nofile_hard == resource.RLIM_INFINITY:
        _nofile_target = 20480
    else:
        _nofile_target = min(20480, _nofile_hard)
    resource.setrlimit(resource.RLIMIT_NOFILE, (_nofile_target, _nofile_hard))
# Select the non-interactive Agg backend so figures can be rendered without a
# display (e.g. on a headless training server).
matplotlib.use('agg')
# Record type bundling the two RPN losses, the two RoI-head losses and their
# total, in that field order.
LossTuple = namedtuple(
    'LossTuple',
    'rpn_loc_loss rpn_cls_loss roi_loc_loss roi_cls_loss total_loss',
)
class FasterRCNNTrainer(ConvnetTrainer):
    """Trainer for a VGG16-based Faster R-CNN detector.

    The methods below override hooks of ``ConvnetTrainer``; the base class is
    not shown in this file, so the order in which it invokes them is
    presumably the usual setup -> train epoch -> evaluate -> epoch end
    sequence (confirm against ``ConvnetTrainer``).

    The losses computed on every training step are:

    * :obj:`rpn_loc_loss`: The localization loss for the
      Region Proposal Network (RPN).
    * :obj:`rpn_cls_loss`: The classification loss for RPN.
    * :obj:`roi_loc_loss`: The localization loss for the head module.
    * :obj:`roi_cls_loss`: The classification loss for the head module.
    * :obj:`total_loss`: The sum of the 4 losses above.

    Configuration is read from ``self.config`` by attribute access.
    """

    ##########################
    # Attribute definitions
    ##########################
    def define_train_attr(self):
        """Copy training-loop settings from ``self.config`` onto the trainer."""
        self._start_epoch = self.config.start_epoch
        self._num_epochs = self.config.num_epochs
        self._has_eval = self.config.has_eval
        self._eval_interval = self.config.eval_interval
        self._test_num = self.config.test_num
        self.rpn_sigma = self.config.rpn_sigma
        self.roi_sigma = self.config.roi_sigma
        self._scale_epoch = self.config.scale_epoch

    def define_model_attr(self):
        """Create the target creators and read checkpoint-loading settings."""
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()
        self._backbone = True
        # Attribute access, consistent with every other config read in this
        # class (the same ``load_path`` is read as an attribute elsewhere).
        self.load_optimizer = self.config.load_optimizer
        # Path of the checkpoint to load; ``save`` overwrites it with the
        # most recently written checkpoint.
        self._load_path = self.config.load_path
        self._best_path = None

    def define_log_attr(self):
        """Create confusion/average meters and, if enabled, the visdom wrapper."""
        self._use_visdom = self.config.use_visdom
        # indicators for training status
        self.rpn_cm = ConfusionMeter(2)
        # 21 classes is hard-coded here -- presumably 20 foreground classes
        # plus background; TODO confirm against the dataset.
        self.roi_cm = ConfusionMeter(21)
        # one running-average meter per loss field
        self.meters = {k: AverageValueMeter() for k in LossTuple._fields}
        self._save_path = self.config.save_path
        self._save_optimizer = self.config.save_optimizer
        self._plot_every = self.config.plot_every
        if self._use_visdom:
            # visdom wrapper; ``self.vis`` only exists when visdom is enabled
            self.vis = FasterRCNNViz(env=self.config.env)

    def define_misc_attr(self):
        """Store the run mode from the config."""
        self._mode = self.config.mode

    ##########################
    # Data and model setup
    ##########################
    def get_trainloader(self):
        """Build the training dataset and a shuffled, batch-size-1 loader."""
        print('load data')
        self.dataset = Dataset(self.config)
        self.train_loader = data_.DataLoader(
            self.dataset,
            batch_size=1,
            shuffle=True,
            num_workers=self.config.num_workers)

    def get_testloader(self):
        """Build the test dataset and an unshuffled, batch-size-1 loader."""
        testset = TestDataset(self.config)
        self.test_loader = data_.DataLoader(
            testset,
            batch_size=1,
            num_workers=self.config.test_num_workers,
            shuffle=False,
            pin_memory=True)

    def build_model(self):
        """Construct the Faster R-CNN model and its optimizer."""
        self.faster_rcnn = FasterRCNNVGG16()
        print('model construct completed')
        if self.config.cuda:
            self.faster_rcnn.cuda()
        # Normalisation constants handed to the proposal target creator
        # when it encodes ground-truth box offsets in ``forward``.
        self.loc_normalize_mean = self.faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = self.faster_rcnn.loc_normalize_std
        self.optimizer = self.faster_rcnn.get_optimizer()

    ##########################
    # Checkpointing
    ##########################
    def _checkpoint_path(self):
        """Return the checkpoint path to load from, or ``None`` if unset.

        Prefers ``self._load_path`` (which ``save`` updates to the latest
        checkpoint) and falls back to ``self.config.load_path``.
        """
        load_path = getattr(self, '_load_path', None)
        if load_path is None:
            load_path = self.config.load_path
        return load_path

    def load_model(self):
        """Load model (and optionally optimizer) state from a checkpoint.

        Accepts both the current format (a dict with a ``'model'`` entry) and
        the legacy format where the file is the model state dict itself.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self._checkpoint_path())
        if 'model' in state_dict:
            self.faster_rcnn.load_state_dict(state_dict['model'])
        else:  # legacy way, for backward compatibility
            self.faster_rcnn.load_state_dict(state_dict)
        if 'optimizer' in state_dict and self.load_optimizer:
            self.optimizer.load_state_dict(state_dict['optimizer'])

    def save(self, **kwargs):
        """Serialize the model, config and extra info to a timestamped file.

        The file name is ``<save_path>fasterrcnn_<MMDDHHMM>`` followed by
        ``_<value>`` for every keyword argument. The written path is stored
        in ``self._load_path``.

        Args:
            **kwargs: Extra information stored under ``'other_info'``; the
                values are also appended to the file name.
        """
        save_dict = {
            'model': self.faster_rcnn.state_dict(),
            'config': helpers.state_dict_from_namespace(self.config),
            'other_info': kwargs,
        }
        if self._use_visdom:
            save_dict['vis_info'] = self.vis.state_dict()
        if self._save_optimizer:
            save_dict['optimizer'] = self.optimizer.state_dict()

        timestr = time.strftime('%m%d%H%M')
        suffix = ''.join(f'_{value}' for value in kwargs.values())
        save_path = f"{self._save_path}fasterrcnn_{timestr}{suffix}"

        save_dir = os.path.dirname(save_path)
        if save_dir:  # nothing to create when saving into the working dir
            helpers.makedirs(save_dir)
        print("saving")
        self.save_model(save_dict, save_path)
        if self._use_visdom:
            self.vis.save([self.vis.env])
        self._load_path = save_path

    def save_model(self, save_dict, save_path):
        """Write ``save_dict`` to ``save_path`` with ``torch.save``."""
        torch.save(save_dict, save_path)

    ##########################
    # Training hooks
    ##########################
    def on_train_start(self):
        """Optionally load a pretrained checkpoint and reset best mAP / lr."""
        if self.config.load_path is not None:
            self.load_model()
            print(f'load pretrained model from {self.config.load_path}')
        self.best_map = 0
        self.lr_ = self.config.lr

    def start_logger(self):
        """Show the dataset label names in visdom (no-op without visdom)."""
        if self._use_visdom:
            self.vis.text(self.dataset.db.label_names, win='labels')

    def on_train_epoch_start(self):
        """Reset the meters and wrap the train loader in a progress bar."""
        self.reset_meters()
        self.train_epoch_iter = tqdm(enumerate(self.train_loader))

    def on_start_training_batch(self, args):
        """Unpack one training batch onto the trainer.

        Args:
            args: Sequence whose first element is the batch index and whose
                last element is the batch, indexed as
                ``(img, bbox, label, scale)`` -- presumably one item of
                ``self.train_epoch_iter``; confirm against the base class.
        """
        batch = args[-1]
        self.ii = args[0]
        self.img = batch[0]
        self.bbox_ = batch[1]
        self.label_ = batch[2]
        self.scale = batch[3]

    def training_step(self):
        """Run one optimisation step on the current batch.

        Moves the batch to the GPU, computes the losses, back-propagates,
        steps the optimizer, updates the meters and plots every
        ``self._plot_every`` iterations.
        """
        self.scale = helpers.scalar(self.scale)
        # NOTE(review): tensors are moved to CUDA unconditionally here, even
        # though ``build_model`` only moves the model when ``config.cuda``.
        self.img = self.img.cuda().float()
        self.bbox = self.bbox_.cuda()
        self.label = self.label_.cuda()

        self.optimizer_zero_grad()
        self.forward()
        self.loss_backward()
        self.optimizer_step()
        self.update_meters()

        if (self.ii + 1) % self._plot_every == 0:
            self.visdom_plot()

    ##########################
    # Evaluation hooks
    ##########################
    def on_evaluate_epoch_start(self):
        """Clear prediction/ground-truth accumulators and start a progress bar."""
        self.pred_bboxes, self.pred_labels, self.pred_scores = [], [], []
        self.gt_bboxes, self.gt_labels, self.gt_difficults = [], [], []
        self.valid_epoch_iter = tqdm(enumerate(self.test_loader))

    def on_evaluate_batch_start(self, args):
        """Unpack one evaluation batch onto the trainer.

        Args:
            args: Sequence whose first element is the batch index and whose
                last element is the batch, indexed as
                ``(imgs, sizes, gt_bboxes, gt_labels, gt_difficults)``.
        """
        batch = args[-1]
        self.ii = args[0]
        self.imgs = batch[0]
        self.sizes = batch[1]
        self.gt_bboxes_ = batch[2]
        self.gt_labels_ = batch[3]
        self.gt_difficults_ = batch[4]

    def on_evaluate_epoch_end(self):
        """Score the accumulated predictions with ``eval_detection_voc``."""
        self.eval_result = eval_detection_voc(
            self.pred_bboxes, self.pred_labels, self.pred_scores,
            self.gt_bboxes, self.gt_labels, self.gt_difficults,
            use_07_metric=True)

    def visdom_plot(self):
        """Plot losses, ground truth, predictions and confusion meters.

        If the file named by ``config.debug_file`` exists, an ``ipdb``
        breakpoint is entered first. Plotting is skipped when visdom is
        disabled.
        """
        if os.path.exists(self.config.debug_file):
            ipdb.set_trace()
        if not self._use_visdom:
            return

        # plot loss
        self.vis.plot_many(self.get_meter_data())

        # plot ground truth bboxes
        ori_img_ = inverse_normalize(helpers.tonumpy(self.img[0]))
        gt_img = visdom_bbox(ori_img_,
                             helpers.tonumpy(self.bbox_[0]),
                             helpers.tonumpy(self.label_[0]))
        self.vis.img('gt_img', gt_img)

        # plot predicted bboxes
        _bboxes, _labels, _scores = self.faster_rcnn.predict(
            [ori_img_], visualize=True)
        pred_img = visdom_bbox(ori_img_,
                               helpers.tonumpy(_bboxes[0]),
                               helpers.tonumpy(_labels[0]).reshape(-1),
                               helpers.tonumpy(_scores[0]))
        self.vis.img('pred_img', pred_img)

        # rpn confusion matrix (meter)
        self.vis.text(str(self.rpn_cm.value().tolist()), win='rpn_cm')
        # roi confusion matrix
        self.vis.img('roi_cm',
                     helpers.totensor(self.roi_cm.conf, False).float())

    def on_epoch_end(self):
        """Log results, save on a new best mAP, decay the lr and maybe stop.

        At ``self._scale_epoch`` the latest saved checkpoint (if any) is
        reloaded and the learning rate is scaled by ``config.lr_decay``.
        Training is stopped once ``self._epoch`` reaches ``self._num_epochs``.
        """
        eval_map = self.eval_result['map']
        lr_ = self.faster_rcnn.optimizer.param_groups[0]['lr']

        if self._use_visdom:
            self.vis.plot('test_map', eval_map)
            log_info = 'lr:{}, map:{},loss:{}'.format(
                str(lr_), str(eval_map), str(self.get_meter_data()))
            self.vis.log(log_info)

        if eval_map > self.best_map:
            self.best_map = eval_map
            self.save(best_map=self.best_map)

        if self._epoch == self._scale_epoch:
            # Reload only when a checkpoint path is known; otherwise keep the
            # current weights and just decay the learning rate.
            if self._checkpoint_path() is not None:
                self.load_model()
            self.faster_rcnn.scale_lr(self.config.lr_decay)
            # Keep the trainer's recorded learning rate in step with the
            # decay (previously computed into an unused local).
            self.lr_ = lr_ * self.config.lr_decay

        if self._epoch == self._num_epochs:
            self.stop_train()

    def evaluate_batch(self, *args):
        """Predict on the current evaluation batch and accumulate results.

        Calls ``self.exit_trainer()`` once the batch index equals
        ``self._test_num``. ``args`` is accepted but not used.
        """
        sizes = [self.sizes[0][0].item(), self.sizes[1][0].item()]
        pred_bboxes_, pred_labels_, pred_scores_ = self.faster_rcnn.predict(
            self.imgs, [sizes])

        self.gt_bboxes.extend(self.gt_bboxes_.numpy())
        self.gt_labels.extend(self.gt_labels_.numpy())
        self.gt_difficults.extend(self.gt_difficults_.numpy())
        self.pred_bboxes.extend(pred_bboxes_)
        self.pred_labels.extend(pred_labels_)
        self.pred_scores.extend(pred_scores_)

        if self.ii == self._test_num:
            self.exit_trainer()

    def forward(self):
        """Forward Faster R-CNN on the current batch and compute the losses.

        Reads ``self.img``, ``self.bbox``, ``self.label`` and ``self.scale``
        and stores a :class:`LossTuple` of the five losses in
        ``self.all_losses``. The RPN and RoI confusion meters are updated as
        a side effect.

        Raises:
            ValueError: If the first dimension of ``self.bbox`` is not 1;
                only batch size 1 is supported.
        """
        n = self.bbox.shape[0]
        if n != 1:
            raise ValueError('Currently only batch size 1 is supported.')

        _, _, H, W = self.img.shape
        img_size = (H, W)

        features = self.faster_rcnn.extractor(self.img)
        rpn_locs, rpn_scores, rois, roi_indices, anchor = \
            self.faster_rcnn.rpn(features, img_size, self.scale)

        # Since batch size is one, convert variables to singular form
        bbox = self.bbox[0]
        label = self.label[0]
        rpn_score = rpn_scores[0]
        rpn_loc = rpn_locs[0]
        roi = rois

        # Sample RoIs and forward.
        # It's fine to break the computation graph of rois;
        # consider them as constant input.
        sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(
            roi,
            helpers.tonumpy(bbox),
            helpers.tonumpy(label),
            self.loc_normalize_mean,
            self.loc_normalize_std)
        # NOTE: all zeros because only batch size 1 is supported
        sample_roi_index = torch.zeros(len(sample_roi))
        roi_cls_loc, roi_score = self.faster_rcnn.head(
            features,
            sample_roi,
            sample_roi_index)

        # ------------------ RPN losses -------------------#
        gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
            helpers.tonumpy(bbox),
            anchor,
            img_size)
        gt_rpn_label = helpers.totensor(gt_rpn_label).long()
        gt_rpn_loc = helpers.totensor(gt_rpn_loc)
        rpn_loc_loss = Losses()._fast_rcnn_loc_loss(
            rpn_loc,
            gt_rpn_loc,
            gt_rpn_label.data,
            self.rpn_sigma)

        # NOTE: default value of ignore_index is -100; anchors labelled -1
        # are excluded from the classification loss.
        rpn_cls_loss = F.cross_entropy(
            rpn_score, gt_rpn_label.cuda(), ignore_index=-1)
        # Feed only the non-ignored anchors to the RPN confusion meter.
        _gt_rpn_label = gt_rpn_label[gt_rpn_label > -1]
        _rpn_score = helpers.tonumpy(
            rpn_score)[helpers.tonumpy(gt_rpn_label) > -1]
        self.rpn_cm.add(helpers.totensor(_rpn_score, False),
                        _gt_rpn_label.data.long())

        # ------------------ ROI losses (fast rcnn loss) -------------------#
        n_sample = roi_cls_loc.shape[0]
        roi_cls_loc = roi_cls_loc.view(n_sample, -1, 4)
        # Convert the labels once and reuse them for both the gather below
        # and the losses.
        gt_roi_label = helpers.totensor(gt_roi_label).long()
        gt_roi_loc = helpers.totensor(gt_roi_loc)
        # For each sampled RoI keep the 4 offsets predicted for its
        # ground-truth class.
        roi_loc = roi_cls_loc[torch.arange(0, n_sample).long().cuda(),
                              gt_roi_label]

        roi_loc_loss = Losses()._fast_rcnn_loc_loss(
            roi_loc.contiguous(),
            gt_roi_loc,
            gt_roi_label.data,
            self.roi_sigma)

        roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label.cuda())

        self.roi_cm.add(helpers.totensor(roi_score, False),
                        gt_roi_label.data.long())

        losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
        losses = losses + [sum(losses)]
        self.all_losses = LossTuple(*losses)

    def optimizer_zero_grad(self):
        """Clear the optimizer's accumulated gradients."""
        self.optimizer.zero_grad()

    def loss_backward(self):
        """Back-propagate the total loss of the last ``forward`` call."""
        self.all_losses.total_loss.backward()

    def optimizer_step(self):
        """Apply one optimizer update."""
        self.optimizer.step()

    ##########################
    # Native methods
    ##########################
    def update_meters(self):
        """Add the latest losses (as scalars) to their average meters."""
        loss_values = {name: helpers.scalar(loss)
                       for name, loss in self.all_losses._asdict().items()}
        for name, meter in self.meters.items():
            meter.add(loss_values[name])

    def reset_meters(self):
        """Reset all loss meters and both confusion meters."""
        for meter in self.meters.values():
            meter.reset()
        self.roi_cm.reset()
        self.rpn_cm.reset()

    def get_meter_data(self):
        """Return ``{loss name: first element of meter.value()}``."""
        return {name: meter.value()[0] for name, meter in self.meters.items()}
"""Forward Faster R-CNN and calculate losses.
Here are notations used.
* :math:`N` is the batch size.
* :math:`R` is the number of bounding boxes per image.
Currently, only :math:`N=1` is supported.
(~torch.autograd.Variable)
Args:
imgs : A variable with a batch of images.
bboxes : A batch of bounding boxes.
Its shape is :math:`(N, R, 4)`.
labels : A batch of labels.
Its shape is :math:`(N, R)`. The background is excluded from
the definition, which means that the range of the value
is :math:`[0, L - 1]`. :math:`L` is the number of foreground
classes.
scale (float): Amount of scaling applied to
the raw image during preprocessing.
Returns:
namedtuple of 5 losses
"""
"""Serialize models include optimizer and other info
return path where the model-file is stored.
Args:
save_optimizer (bool): whether save optimizer.state_dict().
save_path (string): where to save model, if it's None, save_path
is generate using time str and info from kwargs.
Returns:
save_path(str): the path to save models.
"""