import os
import time
import datetime
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator
from terminaltables import AsciiTable
import torch
from torch.autograd import Variable
from farabio.core.convnettrainer import ConvnetTrainer
from farabio.data.datasets import ListDataset, ImageFolder
from farabio.models.detection.yolov3.darknet import Darknet
from farabio.models.detection.yolov3.parsers import parse_data_config
from farabio.utils.helpers import load_classes, makedirs
from farabio.utils.regul import weights_init_normal
from farabio.utils.tensorboard import Logger
from farabio.utils.losses import get_batch_statistics, non_max_suppression, ap_per_class, xywh2xyxy
from farabio.utils.bboxtools import rescale_boxes
class YoloTrainer(ConvnetTrainer):
    """YOLOv3 trainer. Overrides ``ConvnetTrainer`` hooks with YOLO logic.

    The class covers three modes, selected by ``config.mode``:
    ``'train'``, ``'test'`` (mAP evaluation) and ``'detect'`` (inference on
    an image folder followed by plotting of the predicted boxes).

    Parameters
    ----------
    ConvnetTrainer : BaseTrainer
        Inherits ConvnetTrainer class
    """

    def get_trainloader(self):
        """Build ``self.train_loader`` and ``self.valid_loader``.

        The training loader is always created. The second loader depends on
        the mode: an ``ImageFolder`` loader for ``'detect'``, or a
        ``ListDataset`` loader over the ``"valid"`` list for ``'train'`` and
        ``'test'``. Also sets ``self.class_names`` from the data config.
        """
        data_config = parse_data_config(self.data_config)
        self.class_names = load_classes(data_config["names"])
        # Read the worker count from one place so every loader agrees.
        n_workers = self.config.n_cpu

        train_dataset = ListDataset(
            data_config["train"],
            augment=True,
            multiscale=self.config.multiscale_training,
        )
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=n_workers,
            pin_memory=True,
            collate_fn=train_dataset.collate_fn,
        )

        if self._mode == 'detect':
            self.valid_loader = torch.utils.data.DataLoader(
                ImageFolder(self.image_folder, img_size=self.img_size),
                batch_size=self.dbatch_size,
                shuffle=False,
                num_workers=n_workers,
            )
        elif self._mode in ('train', 'test'):
            valid_dataset = ListDataset(
                data_config["valid"],
                img_size=self.img_size,
                augment=False,
                multiscale=False,
            )
            self.valid_loader = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=n_workers,
                collate_fn=valid_dataset.collate_fn,
            )

    def define_data_attr(self):
        """Copy data-related settings from ``self.config`` onto the trainer."""
        self.batch_size = self.config.batch_size
        self.dbatch_size = self.config.dbatch_size
        self.data_config = self.config.data_config
        self.img_size = self.config.img_size
        self.image_folder = self.config.image_folder
        self.checkpoint_interval = self.config.checkpoint_interval

    def define_model_attr(self):
        """Copy the model definition setting from ``self.config``."""
        self.model_def = self.config.model_def

    def define_train_attr(self):
        """Copy training settings from ``self.config`` and create the checkpoint dir."""
        self._eval_interval = self.config.evaluation_interval
        self.iou_thres = self.config.iou_thres
        self.conf_thres = self.config.conf_thres
        self.econf_thres = self.config.econf_thres
        self.nms_thres = self.config.nms_thres
        self._num_epochs = self.config.num_epochs
        self.chckpt_dir = self.config.chckpt_dir
        self.grad_acc = self.config.gradient_accumulations

        # NOTE(review): only 'adam' is handled here; any other value leaves
        # self.optim unset, and build_model reads self.optim.
        if self.config.optim == 'adam':
            self.optim = torch.optim.Adam

        makedirs(self.config.chckpt_dir)

    def define_test_attr(self):
        """Copy detection settings from ``self.config`` and create the output dir."""
        self.dconf_thres = self.config.dconf_thres
        self.dnms_thres = self.config.dnms_thres
        self.detect = self.config.detect
        self.output_dir = self.config.output_dir
        makedirs(self.config.output_dir)

    def define_log_attr(self):
        """Define the per-layer metric names to report and create the logger."""
        self.metrics = [
            "grid_size",
            "loss",
            "x",
            "y",
            "w",
            "h",
            "conf",
            "cls",
            "cls_acc",
            "recall50",
            "recall75",
            "precision",
            "conf_obj",
            "conf_noobj",
        ]
        self.logger = Logger(self.config.logdir)

    def define_compute_attr(self):
        """Set the worker count and pick CUDA when available, else CPU."""
        self.n_cpu = self.config.n_cpu
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def define_misc_attr(self):
        """Store the run mode (``'train'``, ``'test'`` or ``'detect'``)."""
        self._mode = self.config.mode

    def build_model(self):
        """Create the Darknet model, load weights for the mode, build the optimizer.

        In ``'train'`` mode the weights are first initialised with
        ``weights_init_normal``; if ``config.pretrained_weights`` is set, a
        ``.pth`` file is loaded as a state dict and anything else is loaded
        through ``load_darknet_weights``. In ``'test'`` and ``'detect'`` modes
        ``config.weights_path`` is loaded: ``.weights`` files through
        ``load_darknet_weights``, anything else as a state dict.
        """
        self.model = Darknet(self.model_def).to(self.device)

        if self._mode == 'train':
            self.model.apply(weights_init_normal)
            pretrained_weights = self.config.pretrained_weights
            # An unset value (None or "") means: keep the random initialisation.
            if pretrained_weights:
                if pretrained_weights.endswith(".pth"):
                    # Custom Darknet weights
                    self._load_state_dict(pretrained_weights)
                else:
                    # Darknet-53 on ImageNet
                    self.model.load_darknet_weights(pretrained_weights)
        elif self._mode in ('test', 'detect'):
            weights_path = self.config.weights_path
            if weights_path.endswith(".weights"):
                # Load yolov3 coco weights
                self.model.load_darknet_weights(weights_path)
            else:
                # Load custom data checkpoint weights
                self._load_state_dict(weights_path)

        self.optimizer = self.optim(self.model.parameters())

    def _load_state_dict(self, path):
        """Load a ``torch.save``-d state dict from ``path`` into the model."""
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state)

    def on_train_epoch_start(self):
        """Switch to train mode, start the epoch timer, create the batch iterator."""
        self.model.train()
        self.start_time = time.time()
        self.train_epoch_iter = enumerate(self.train_loader)

    def on_start_training_batch(self, args):
        """Unpack one training item into ``batch_i``, ``imgs`` and ``targets``.

        Parameters
        ----------
        args : sequence
            First element is the batch index; the last element is the batch,
            whose items 1 and 2 are stored as images and targets.
            Presumably an item of ``self.train_epoch_iter`` - confirm against
            the base class.
        """
        self.batch_i = args[0]
        batch = args[-1]
        self.imgs = batch[1]
        self.targets = batch[2]

    def training_step(self):
        """Run forward/backward on the current batch and step the optimizer.

        The optimizer steps (and gradients are zeroed) only when the global
        batch counter is a multiple of ``self.grad_acc``, so gradients
        accumulate across the batches in between.
        """
        # self._epoch is not set in this class; presumably maintained by the
        # base trainer.
        self.batches_done = len(self.train_loader) * self._epoch + self.batch_i

        self.imgs = self.imgs.to(self.device)
        self.targets = self.targets.to(self.device)

        self.loss, outputs = self.model(self.imgs, self.targets)
        self.loss.backward()

        if self.batches_done % self.grad_acc == 0:
            # Accumulates gradient before each step
            self.optimizer.step()
            self.optimizer.zero_grad()

    def on_end_training_batch(self):
        """Print a per-YOLO-layer metric table, log scalars, and print an ETA."""
        yolo_layers = self.model.yolo_layers
        n_batches = len(self.train_loader)
        total_loss = self.loss.item()

        # ----------------
        #   Log progress
        # ----------------
        log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
            self._epoch, self._num_epochs, self.batch_i, n_batches)

        # Per-metric format strings; built once, not once per table row.
        formats = {m: "%.6f" for m in self.metrics}
        formats["grid_size"] = "%2d"
        formats["cls_acc"] = "%.2f%%"

        metric_table = [
            ["Metrics", *[f"YOLO Layer {i}" for i in range(len(yolo_layers))]]]
        # One row per metric, one column per YOLO layer
        for metric in self.metrics:
            row_metrics = [formats[metric] % yolo.metrics.get(metric, 0)
                           for yolo in yolo_layers]
            metric_table.append([metric, *row_metrics])

        # Tensorboard logging: written once per batch, since it does not
        # depend on the table row being built above.
        tensorboard_log = []
        for j, yolo in enumerate(yolo_layers):
            for name, value in yolo.metrics.items():
                if name != "grid_size":
                    tensorboard_log.append((f"{name}_{j+1}", value))
        tensorboard_log.append(("loss", total_loss))
        self.logger.list_of_scalars_summary(
            tensorboard_log, self.batches_done)

        log_str += AsciiTable(metric_table).table
        log_str += f"\nTotal loss {total_loss}"

        # Determine approximate time left for epoch from the mean batch time
        batches_seen = self.batch_i + 1
        epoch_batches_left = n_batches - batches_seen
        time_left = datetime.timedelta(
            seconds=epoch_batches_left * (time.time() - self.start_time) / batches_seen)
        log_str += f"\n---- ETA {time_left}"

        print(log_str)

        self.model.seen += self.imgs.size(0)

    def on_epoch_end(self):
        """Save a checkpoint every ``self.checkpoint_interval`` epochs."""
        if self._epoch % self.checkpoint_interval == 0:
            self.save_model()

    def save_model(self):
        """Save the model state dict as ``yolov3_ckpt_<epoch>.pth`` in the checkpoint dir."""
        ckpt_path = os.path.join(
            f"{self.chckpt_dir}", f"yolov3_ckpt_{self._epoch}.pth")
        torch.save(self.model.state_dict(), ckpt_path)

    def on_evaluate_epoch_start(self):
        """Switch to eval mode and reset the per-epoch evaluation accumulators."""
        self.model.eval()
        self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        self.labels = []
        self.sample_metrics = []  # List of tuples (TP, confs, pred)
        self.valid_epoch_iter = enumerate(
            tqdm(self.valid_loader, desc="Detecting objects"))

    def on_evaluate_batch_start(self, args):
        """Unpack one validation item and flag batches that have no targets.

        Parameters
        ----------
        args : sequence
            First element is the batch index; the last element is the batch,
            whose items 1 and 2 are stored as images and targets.
        """
        self.batch_i = args[0]
        batch = args[-1]
        self.imgs = batch[1]
        self.targets = batch[2]

        if self.targets is None:
            # Set on the instance: assigning through super() raises
            # AttributeError. Presumably the base evaluation loop reads this
            # flag to skip the batch - confirm against ConvnetTrainer.
            self._next_loop = True

    def evaluate_batch(self, *args):
        """Run detection on the current batch and accumulate its statistics.

        Appends the target class labels to ``self.labels`` and the output of
        ``get_batch_statistics`` to ``self.sample_metrics``. Batches whose
        targets are ``None`` are skipped.
        """
        if self.targets is None:
            # Nothing to score against; indexing None below would raise.
            return

        # Extract labels
        self.labels += self.targets[:, 1].tolist()
        # Rescale target: convert boxes with xywh2xyxy, then scale by img_size
        self.targets[:, 2:] = xywh2xyxy(self.targets[:, 2:])
        self.targets[:, 2:] *= self.img_size

        self.imgs = self.imgs.type(self.Tensor)

        with torch.no_grad():
            outputs = self.model(self.imgs)
            outputs = non_max_suppression(
                outputs, conf_thres=self.econf_thres, nms_thres=self.nms_thres)

        self.sample_metrics += get_batch_statistics(
            outputs, self.targets, iou_threshold=self.iou_thres)

    def on_evaluate_epoch_end(self):
        """Compute precision/recall/AP/F1 over the epoch, log and print them."""
        if not self.sample_metrics:
            # With nothing collected the three-way unpack below would raise.
            print("---- No detections over whole validation set ----")
            return

        # Concatenate sample statistics
        true_positives, pred_scores, pred_labels = [
            np.concatenate(x, 0) for x in zip(*self.sample_metrics)]
        precision, recall, AP, f1, ap_class = ap_per_class(
            true_positives, pred_scores, pred_labels, self.labels)

        evaluation_metrics = [
            ("val_precision", precision.mean()),
            ("val_recall", recall.mean()),
            ("val_mAP", AP.mean()),
            ("val_f1", f1.mean()),
        ]
        self.logger.list_of_scalars_summary(
            evaluation_metrics, self._epoch)

        # Print class APs and mAP
        if self._mode == 'train':
            ap_table = [["Index", "Class name", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table.append([c, self.class_names[c], "%.5f" % AP[i]])
            print(AsciiTable(ap_table).table)
            print(f"---- mAP {AP.mean()}")
        elif self._mode == "test":
            print("Average Precisions:")
            for i, c in enumerate(ap_class):
                print(
                    f"+ Class '{c}' ({self.class_names[c]}) - AP: {AP[i]}")
            print(f"mAP: {AP.mean()}")

    def test(self):
        """Run one evaluation epoch to compute mAP."""
        print("Compute mAP...")
        # evaluate_epoch is not defined in this class; presumably inherited.
        self.evaluate_epoch()

    def get_detections(self):
        """Run inference over ``self.valid_loader`` and store the results.

        Fills ``self.imgs`` with image paths and ``self.img_detections`` with
        the matching entries returned by ``non_max_suppression``.
        """
        Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

        self.imgs = []  # Stores image paths
        self.img_detections = []  # Stores detections for each image index

        print("\nPerforming object detection:")
        prev_time = time.time()

        for batch_i, (img_paths, input_imgs) in enumerate(self.valid_loader):
            # Configure input
            input_imgs = input_imgs.type(Tensor)

            # Get detections
            with torch.no_grad():
                detections = self.model(input_imgs)
                detections = non_max_suppression(
                    detections, self.dconf_thres, self.dnms_thres)

            # Log progress
            current_time = time.time()
            inference_time = datetime.timedelta(
                seconds=current_time - prev_time)
            prev_time = current_time
            print("\t+ Batch %d, Inference Time: %s" %
                  (batch_i, inference_time))

            # Save image and detections
            self.imgs.extend(img_paths)
            self.img_detections.extend(detections)

    def plot_bbox(self):
        """Draw stored detections on each image and save it as PNG in ``self.output_dir``.

        Reads ``self.imgs`` and ``self.img_detections`` (as filled by
        ``get_detections``). Images whose detections entry is ``None`` are
        saved without boxes.
        """
        # Bounding-box colors
        cmap = plt.get_cmap("tab20b")
        colors = [cmap(i) for i in np.linspace(0, 1, 20)]

        print("\nSaving images:")
        # Iterate through images and save plot of detections
        for img_i, (path, detections) in enumerate(zip(self.imgs, self.img_detections)):
            print("(%d) Image: '%s'" % (img_i, path))

            # Create plot (a single figure per image, closed explicitly below)
            img = np.array(Image.open(path))
            fig, ax = plt.subplots(1)
            ax.imshow(img)

            # Draw bounding boxes and labels of detections
            if detections is not None:
                # Rescale boxes to original image
                detections = rescale_boxes(
                    detections, self.img_size, img.shape[:2])
                unique_labels = detections[:, -1].cpu().unique().tolist()
                n_cls_preds = len(unique_labels)
                if n_cls_preds <= len(colors):
                    bbox_colors = random.sample(colors, n_cls_preds)
                else:
                    # More classes than palette entries: allow repeated colors
                    # instead of letting random.sample raise.
                    bbox_colors = random.choices(colors, k=n_cls_preds)
                # One color per predicted class, looked up by integer class id
                color_of = {int(label): color
                            for label, color in zip(unique_labels, bbox_colors)}

                # Plain Python floats: matplotlib cannot consume GPU tensors.
                for x1, y1, x2, y2, _conf, cls_conf, cls_pred in detections.cpu().tolist():
                    cls_idx = int(cls_pred)
                    cls_name = self.class_names[cls_idx]
                    print("\t+ Label: %s, Conf: %.5f" % (cls_name, cls_conf))

                    color = color_of[cls_idx]
                    # Create a Rectangle patch
                    bbox = patches.Rectangle(
                        (x1, y1), x2 - x1, y2 - y1,
                        linewidth=2, edgecolor=color, facecolor="none")
                    # Add the bbox to the plot
                    ax.add_patch(bbox)
                    # Add label
                    ax.text(
                        x1,
                        y1,
                        s=cls_name,
                        color="white",
                        verticalalignment="top",
                        bbox={"color": color, "pad": 0},
                    )

            # Save generated image with detections
            ax.axis("off")
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())
            filename = os.path.basename(path).split(".")[0]
            output_path = os.path.join(self.output_dir, f"{filename}.png")
            fig.savefig(output_path, bbox_inches="tight", pad_inches=0.0)
            plt.close(fig)