import csv
import numpy as np
import os
import os.path
from typing import Callable
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
# Default file extensions accepted as sample files (compared against lower-cased file names).
IMG_EXTENSIONS = (
    '.jpg', '.jpeg', '.png', '.ppm', '.bmp',
    '.pgm', '.tif', '.tiff', '.webp',
)
class BaseSgVisionDataset(VisionDataset):
    """
    BaseSgVisionDataset - Base class for (sample, target) vision data sets.

    Sub-classes must implement ``_generate_samples_and_targets`` (called from the ctor) to fill
    ``self.samples_targets_tuples_list``, and ``__getitem__`` to load a single item.
    """

    def __init__(self, root: str, sample_loader: Callable = default_loader, target_loader: Callable = None,
                 collate_fn: Callable = None, valid_sample_extensions: tuple = IMG_EXTENSIONS,
                 sample_transform: Callable = None, target_transform: Callable = None):
        """
        Ctor
        :param root:                    root directory of the data set (forwarded to VisionDataset)
        :param sample_loader:           func that loads a sample given its path
        :param target_loader:           func that loads a target given its path
        :param collate_fn:              optional collate_fn; stored as ``self.collate_fn`` only when provided
        :param valid_sample_extensions: file extensions accepted by ``_validate_file``
        :param sample_transform:        func applied to samples (passed to VisionDataset as ``transform``)
        :param target_transform:        func applied to targets (passed to VisionDataset as ``target_transform``)
        """
        super().__init__(root=root, transform=sample_transform, target_transform=target_transform)
        self.samples_targets_tuples_list: list = []
        self.classes = []
        self.valid_sample_extensions = valid_sample_extensions
        self.sample_loader = sample_loader
        self.target_loader = target_loader
        # Sub-classes fill self.samples_targets_tuples_list here, so every member above must already exist
        self._generate_samples_and_targets()

        # Only set when provided, so a collate_fn implemented by a sub-class is not shadowed by None
        if collate_fn is not None:
            self.collate_fn = collate_fn

    def __getitem__(self, item):
        """
        Abstract - sub-classes return the (sample, target) pair at index ``item``
        :param item: index into self.samples_targets_tuples_list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def __len__(self):
        """
        :return: number of (sample, target) pairs collected by _generate_samples_and_targets
        """
        return len(self.samples_targets_tuples_list)

    def _generate_samples_and_targets(self):
        """
        _generate_samples_and_targets - An abstract method that fills the samples and targets members of the class
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def _validate_file(self, filename: str) -> bool:
        """
        _validate_file - Checks whether ``filename`` ends with one of ``self.valid_sample_extensions``.
        The comparison is case-insensitive on both the file name and the extensions.
        :param filename: file name or path to check
        :return: True if the file has an accepted extension, False otherwise
        """
        extensions = self.valid_sample_extensions
        # A bare string would otherwise be iterated char-by-char ('.png' -> '.', 'p', 'n', 'g')
        if isinstance(extensions, str):
            extensions = (extensions,)
        return filename.lower().endswith(tuple(extension.lower() for extension in extensions))

    @staticmethod
    def numpy_loader_func(path):
        """
        numpy_loader_func - Loads a file with ``np.load``
        :param path: path of the file to load
        :return: whatever ``np.load`` returns for the file
        """
        return np.load(path)

    @staticmethod
    def text_file_loader_func(text_file_path: str, inline_splitter: str = ' ') -> list:
        """
        text_file_loader_func - Uses a line by line based code to get vectorized data from a text-based file
        :param text_file_path:  Input text file
        :param inline_splitter: The char to use in order to separate between different VALUES of the SAME vector
                                please notice that DIFFERENT VECTORS SHOULD BE IN SEPARATE LINES ('\n') SEPARATED
        :return: a list of tuples, where each tuple is a vector of target values (as floats).
                 Blank lines are skipped.
        :raises ValueError: if ``text_file_path`` is not an existing file
        """
        if not os.path.isfile(text_file_path):
            raise ValueError(" Error in text file path")
        with open(text_file_path, "r", encoding="utf-8") as text_file:
            # Strip surrounding whitespace / line terminators and drop blank lines (e.g. a trailing empty line),
            # which would otherwise make float('') raise
            stripped_lines = (line.strip() for line in text_file)
            targets_list = [tuple(map(float, stripped_line.split(inline_splitter)))
                            for stripped_line in stripped_lines if stripped_line]
        return targets_list
class DirectoryDataSet(BaseSgVisionDataset):
    """
    DirectoryDataSet - A PyTorch Vision Data Set extension that receives a root Dir and two separate sub directories:
        - Sub-Directory for Samples
        - Sub-Directory for Targets
    A sample ``<samples_dir>/name.ext`` is paired with the target ``<targets_dir>/name<target_extension>``;
    only the LAST extension of the sample file name is replaced.
    """

    def __init__(self, root: str,
                 samples_sub_directory: str, targets_sub_directory: str, target_extension: str,
                 sample_loader: Callable = default_loader, target_loader: Callable = None, collate_fn: Callable = None,
                 sample_extensions: tuple = IMG_EXTENSIONS, sample_transform: Callable = None,
                 target_transform: Callable = None):
        """
        CTOR
        :param root:                  root directory that contains all of the Data Set
        :param samples_sub_directory: name of the samples sub-directory
        :param targets_sub_directory: name of the targets sub-directory
        :param target_extension:      file extension of the targets (including the dot, e.g. '.txt')
        :param sample_loader:         Func to load samples
        :param target_loader:         Func to load targets (defaults to text_file_loader_func)
        :param collate_fn:            collate_fn func to process batches for the Data Loader
        :param sample_extensions:     file extensions for samples
        :param sample_transform:      Func to pre-process samples for data loading
        :param target_transform:      Func to pre-process targets for data loading
        """
        # INITIALIZING THE TARGETS LOADER TO USE THE TEXT FILE LOADER FUNC
        if target_loader is None:
            target_loader = self.text_file_loader_func

        # These members must be set BEFORE super().__init__, which calls _generate_samples_and_targets()
        self.target_extension = target_extension
        self.samples_dir_suffix = samples_sub_directory
        self.targets_dir_suffix = targets_sub_directory
        super().__init__(root=root, sample_loader=sample_loader, target_loader=target_loader,
                         collate_fn=collate_fn, valid_sample_extensions=sample_extensions,
                         sample_transform=sample_transform, target_transform=target_transform)

    def __getitem__(self, item):
        """
        getter method for iteration
        :param item: index into self.samples_targets_tuples_list
        :return: (sample, target) after the optional transforms
        """
        sample_path, target_path = self.samples_targets_tuples_list[item]
        sample = self.sample_loader(sample_path)
        target = self.target_loader(target_path)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _generate_samples_and_targets(self):
        """
        _generate_samples_and_targets - Uses class built in members to generate the list of (SAMPLE, TARGET/S)
                                        that is saved in self.samples_targets_tuples_list
        :raises ValueError: if the samples or targets sub-directory is not an existing directory
        """
        missing_sample_files, missing_target_files = 0, 0

        # VALIDATE DATA PATH
        samples_dir_path = self.root + os.path.sep + self.samples_dir_suffix
        targets_dir_path = self.root + os.path.sep + self.targets_dir_suffix
        if not os.path.isdir(samples_dir_path) or not os.path.isdir(targets_dir_path):
            raise ValueError(" Error in data path")

        # ITERATE OVER SAMPLES AND MAKE SURE THERE ARE MATCHING LABELS
        # os.listdir order is arbitrary - sort it so the data set order is reproducible
        for sample_file_name in sorted(os.listdir(samples_dir_path)):
            sample_file_path = os.path.join(samples_dir_path, sample_file_name)
            if not (os.path.isfile(sample_file_path) and self._validate_file(sample_file_path)):
                # Non-files and files with an unsupported extension are counted as missing samples
                missing_sample_files += 1
                continue

            # TRY TO GET THE MATCHING LABEL - splitext drops only the last extension, so multi-dot
            # names such as 'img.v2.png' map to 'img.v2<target_extension>' (not 'img<target_extension>')
            sample_file_prefix = os.path.splitext(sample_file_name)[0]
            target_file_path = os.path.join(targets_dir_path, sample_file_prefix + self.target_extension)
            if os.path.isfile(target_file_path):
                self.samples_targets_tuples_list.append((sample_file_path, target_file_path))
            else:
                missing_target_files += 1

        for counter_name, missing_files_counter in (('samples', missing_sample_files),
                                                    ('targets', missing_target_files)):
            if missing_files_counter > 0:
                print(__name__ + ' There are ' + str(missing_files_counter) + ' missing ' + counter_name)
class ListDataset(BaseSgVisionDataset):
    """
    ListDataset - A PyTorch Vision Data Set extension that receives a file listing the path of each of the samples
                  (each listed path is appended to ``root``).
                  Then, the assumption is that for every sample, there is a * matching target * in the same
                  path but with a different extension, i.e:
                      for the samples paths:  (That appear in the list file)
                          /root/dataset/class_x/sample1.png
                          /root/dataset/class_y/sample123.png

                      the matching labels paths:  (That DO NOT appear in the list file)
                          /root/dataset/class_x/sample1.ext
                          /root/dataset/class_y/sample123.ext
    """

    def __init__(self, root, file, sample_loader: Callable = default_loader, target_loader: Callable = None,
                 collate_fn: Callable = None, sample_extensions: tuple = IMG_EXTENSIONS,
                 sample_transform: Callable = None, target_transform: Callable = None, target_extension='.npy'):
        """
        CTOR
        :param root:              root directory that contains all of the Data Set
        :param file:              path (relative to root) of the CSV file with the samples list;
                                  the first column of each row is a sample path relative to root
        :param sample_loader:     Func to load samples
        :param target_loader:     Func to load targets (defaults to numpy_loader_func)
        :param collate_fn:        collate_fn func to process batches for the Data Loader
        :param sample_extensions: file extension for samples
        :param sample_transform:  Func to pre-process samples for data loading
        :param target_transform:  Func to pre-process targets for data loading
        :param target_extension:  file extension of the targets (including the dot)
        """
        if target_loader is None:
            target_loader = self.numpy_loader_func

        # These members must be set BEFORE super().__init__, which calls _generate_samples_and_targets()
        self.list_file_path = file
        self.loader = sample_loader
        self.target_loader = target_loader
        self.extensions = sample_extensions
        self.target_extension = target_extension
        super().__init__(root, sample_loader=sample_loader, target_loader=target_loader,
                         collate_fn=collate_fn, sample_transform=sample_transform,
                         valid_sample_extensions=sample_extensions,
                         target_transform=target_transform)

    def __getitem__(self, item):
        """
        Args:
            item (int): Index
        Returns:
            tuple: (sample, target) where target is the first element ([0]) of what target_loader returns for
                   the matching target file, after the optional transforms.
        """
        sample_path, target_path = self.samples_targets_tuples_list[item]
        sample = self.loader(sample_path)
        target = self.target_loader(target_path)[0]

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _generate_samples_and_targets(self):
        """
        _generate_samples_and_targets - Reads the list file and keeps every listed sample that has an accepted
                                        extension AND an existing matching target file.
        """
        list_file_full_path = self.root + os.path.sep + self.list_file_path
        # 'with' guarantees the list file is closed; newline='' is what the csv module expects
        with open(list_file_full_path, "r", encoding="utf-8", newline="") as list_file:
            # Skip empty rows (e.g. blank lines), which would otherwise raise IndexError on row[0]
            listed_samples = [row[0] for row in csv.reader(list_file) if row]

        for listed_sample in listed_samples:
            sample_path = self.root + os.path.sep + listed_sample
            # Replace only the last extension - a fixed-width slice breaks for '.jpeg', '.tiff', '.webp'
            target_path = os.path.splitext(sample_path)[0] + self.target_extension
            if self._validate_file(sample_path) and os.path.exists(target_path):
                self.samples_targets_tuples_list.append((sample_path, target_path))