import os
import glob
from pathlib import Path
from xml.etree import ElementTree
from tqdm import tqdm
import numpy as np
from super_gradients.training.utils.utils import download_and_untar_from_url, get_image_size_from_path
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.datasets.datasets_conf import PASCAL_VOC_2012_CLASSES_LIST
logger = get_logger(__name__)
class PascalVOCDetectionDataset(DetectionDataset):
"""Dataset for Pascal VOC object detection"""
def __init__(self, images_sub_directory: str, *args, **kwargs):
"""Dataset for Pascal VOC object detection
:param images_sub_directory: Sub directory of data_dir that includes images.
"""
self.images_sub_directory = images_sub_directory
self.img_and_target_path_list = None
kwargs['all_classes_list'] = PASCAL_VOC_2012_CLASSES_LIST
kwargs['original_target_format'] = DetectionTargetsFormat.XYXY_LABEL
super().__init__(*args, **kwargs)
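    # A minimal usage sketch (illustrative, not executed here): `data_dir` and
    # `input_dim` are read as attributes elsewhere in this class and are assumed
    # to be accepted by the parent DetectionDataset constructor; the path below
    # is hypothetical.
    #
    #   train_set = PascalVOCDetectionDataset(
    #       data_dir="/data/pascal_voc/",
    #       images_sub_directory="images/train2012/",
    #       input_dim=(512, 512),
    #   )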
def _setup_data_source(self):
"""Initialize img_and_target_path_list and warn if label file is missing
:return: List of tuples made of (img_path,target_path)
"""
img_files_folder = self.data_dir + self.images_sub_directory
if not Path(img_files_folder).exists():
raise FileNotFoundError(f"{self.data_dir} does not include {self.images_sub_directory}. "
f"Please make sure that f{self.data_dir} refers to PascalVOC dataset and that "
"it was downloaded using PascalVOCDetectionDataSetV2.download()")
img_files = glob.glob(img_files_folder + "*.jpg")
if len(img_files) == 0:
raise FileNotFoundError(f"No image file found at {img_files_folder}")
target_files = [img_file.replace("images", "labels").replace(".jpg", ".txt") for img_file in img_files]
img_and_target_path_list = [(img_file, target_file)
for img_file, target_file in zip(img_files, target_files)
if os.path.exists(target_file)]
if len(img_and_target_path_list) == 0:
raise FileNotFoundError("No target file associated to the images was found")
num_missing_files = len(img_files) - len(img_and_target_path_list)
if num_missing_files > 0:
            logger.warning(f'{num_missing_files} label files were not loaded out of {len(img_files)} image files')
self.img_and_target_path_list = img_and_target_path_list
return len(self.img_and_target_path_list)
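    # Illustrative on-disk layout consumed by _setup_data_source (and produced by
    # PascalVOCDetectionDataset.download below); the file name is hypothetical:
    #
    #   <data_dir>/
    #       images/train2012/2008_000008.jpg
    #       labels/train2012/2008_000008.txt   # one "xmin ymin xmax ymax class_id" row per object
    #
    # with images_sub_directory="images/train2012/" (note the trailing slash, since
    # the paths above are built by plain string concatenation).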
def _load_annotation(self, sample_id: int) -> dict:
"""Load annotations associated to a specific sample.
:return: Annotation including:
- target in XYXY_LABEL format
- img_path
"""
img_path, target_path = self.img_and_target_path_list[sample_id]
with open(target_path, 'r') as targets_file:
target = np.array([x.split() for x in targets_file.read().splitlines()], dtype=np.float32)
width, height = get_image_size_from_path(img_path)
# We have to rescale the targets because the images will be rescaled.
r = min(self.input_dim[1] / height, self.input_dim[0] / width)
target[:, :4] *= r
initial_img_shape = (width, height)
resized_img_shape = (int(width * r), int(height * r))
return {"img_path": img_path, "target": target,
"initial_img_shape": initial_img_shape, "resized_img_shape": resized_img_shape}
    @staticmethod
def download(data_dir: str):
"""Download Pascal dataset in XYXY_LABEL format.
Data extracted form http://host.robots.ox.ac.uk/pascal/VOC/
"""
def _parse_and_save_labels(path, new_label_path, year, image_id):
"""Parse and save the labels of an image in XYXY_LABEL format."""
with open(f'{path}/VOC{year}/Annotations/{image_id}.xml') as f:
xml_parser = ElementTree.parse(f).getroot()
labels = []
for obj in xml_parser.iter('object'):
cls = obj.find('name').text
                if cls in PASCAL_VOC_2012_CLASSES_LIST and int(obj.find('difficult').text) != 1:
xml_box = obj.find('bndbox')
def get_coord(box_coord):
return xml_box.find(box_coord).text
xmin, ymin, xmax, ymax = get_coord("xmin"), get_coord("ymin"), get_coord("xmax"), get_coord("ymax")
labels.append(" ".join([xmin, ymin, xmax, ymax, str(PASCAL_VOC_2012_CLASSES_LIST.index(cls))]))
with open(new_label_path, 'w') as f:
f.write("\n".join(labels))
urls = ["http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar", # 439M 5011 images
"http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar", # 430M, 4952 images
"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"] # 1.86G, 17125 images
data_dir = Path(data_dir)
download_and_untar_from_url(urls, dir=data_dir / 'images')
# Convert
data_path = data_dir / 'images' / 'VOCdevkit'
for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
dest_imgs_path = data_dir / 'images' / f'{image_set}{year}'
dest_imgs_path.mkdir(exist_ok=True, parents=True)
dest_labels_path = data_dir / 'labels' / f'{image_set}{year}'
dest_labels_path.mkdir(exist_ok=True, parents=True)
with open(data_path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
image_ids = f.read().strip().split()
for id in tqdm(image_ids, desc=f'{image_set}{year}'):
img_path = data_path / f'VOC{year}/JPEGImages/{id}.jpg'
new_img_path = dest_imgs_path / img_path.name
new_label_path = (dest_labels_path / img_path.name).with_suffix('.txt')
img_path.rename(new_img_path) # Move image to dest folder
_parse_and_save_labels(data_path, new_label_path, year, id)
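

# A minimal sketch of preparing the data (a rough guide, not run on import): the
# path is hypothetical and the three archives above total roughly 2.7 GB before
# extraction.
if __name__ == "__main__":
    PascalVOCDetectionDataset.download(data_dir="/data/pascal_voc/")
    # After this, "/data/pascal_voc/images/<set><year>/" and the matching
    # "labels/<set><year>/" folders can be consumed by the constructor above.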