import os
import glob
import unittest
import subprocess
import numpy as np
import pandas as pd
from zipfile import ZipFile
from PIL import Image
from typing import Optional, Callable
import matplotlib.pyplot as plt
from skimage import io, transform
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import draw_segmentation_masks
__all__ = ['ChestXrayDataset', 'DSB18Dataset', 'HistocancerDataset',
'RANZCRDataset', 'RetinopathyDataset']
kaggle_biodatasets = [
"aptos2019-blindness-detection",
"chest-xray-pneumonia",
"data-science-bowl-2018",
"histopathologic-cancer-detection",
"intel-mobileodt-cervical-cancer-screening",
"ranzcr-clip-catheter-line-classification",
"skin-cancer-mnist"
]
def download_datasets(tag, path="."):
"""Helper function to download datasets
Parameters
----------
tag : str
tag for dataset
.. note::
available tags:
kaggle_biodatasets = [
"aptos2019-blindness-detection",
"chest-xray-pneumonia",
"data-science-bowl-2018",
"histopathologic-cancer-detection",
"intel-mobileodt-cervical-cancer-screening",
"ranzcr-clip-catheter-line-classification",
"skin-cancer-mnist"
]
    path : str, optional
        Path where the dataset will be saved, by default "."
Examples
----------
>>> download_datasets(tag="skin-cancer-mnist", path=".")
"""
if tag == "chest-xray-pneumonia":
bash_c_tag = ["kaggle", "datasets", "download",
"-d", "paultimothymooney/chest-xray-pneumonia"]
elif tag == "skin-cancer-mnist":
bash_c_tag = ["kaggle", "datasets", "download",
"-d", "kmader/skin-cancer-mnist-ham10000"]
    else:
        bash_c_tag = ["kaggle", "competitions", "download", "-c", tag]
prev_cwd = os.getcwd()
os.chdir(path)
process = subprocess.Popen(bash_c_tag, stdout=subprocess.PIPE)
output, error = process.communicate()
print(output)
os.chdir(prev_cwd)
def extract_zip(fzip, fnew=None):
    """Extract the zip archive ``fzip`` into ``fnew`` (current directory if None)."""
    with ZipFile(fzip, 'r') as zf:
        print('Extracting all the files now...')
        zf.extractall(fnew)
        print('Done!')
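# A minimal end-to-end sketch of how the two helpers above are meant to be combined,
# assuming the Kaggle CLI is installed and ~/.kaggle/kaggle.json is configured; the
# "./data" directory is only illustrative:
#
#     download_datasets(tag="chest-xray-pneumonia", path="./data")
#     extract_zip(os.path.join("./data", "chest-xray-pneumonia.zip"),
#                 os.path.join("./data", "chest-xray-pneumonia"))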
class ChestXrayDataset(ImageFolder):
    r"""PyTorch-friendly ChestXrayDataset class.
    The dataset is loaded using the Kaggle API.
    For further information on the raw dataset and pneumonia detection, please refer to [1]_.
Examples
----------
>>> valid_dataset = ChestXrayDataset(root=_path, download=True, mode="val", show=True)
.. image:: ../imgs/ChestXrayDataset.png
:width: 600
References
---------------
.. [1] https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia
"""
def __init__(self, root: str = ".", download: bool = False, mode: str = "train", shape: int = 256, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, show: bool = True):
tag = "chest-xray-pneumonia"
modes = ["train", "val", "test"]
assert mode in modes, "Available options for mode: train, val, test"
self.shape = shape
self.mode = mode
if download:
download_datasets(tag, path=root)
extract_zip(os.path.join(root, tag+".zip"),
os.path.join(root, tag))
if transform is None:
self.transform = self.default_transform(mode)
else:
self.transform = transform
        if download:
            dataset_path = os.path.join(root, tag, "chest_xray", mode)
        else:
            dataset_path = os.path.join(root, mode)
        super(ChestXrayDataset, self).__init__(
            root=dataset_path, transform=self.transform,
            target_transform=target_transform)
if show:
self.visualize_batch()
def __getitem__(self, index):
path, target = self.samples[index]
fname = path.split("/")[-1]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target, fname
def default_transform(self, mode="train"):
if mode == "train":
transform = transforms.Compose([
transforms.Resize((self.shape, self.shape)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
elif mode == 'val' or mode == 'test':
transform = transforms.Compose([
transforms.Resize((self.shape, self.shape)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
return transform
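    # Note that the default train and val/test pipelines above are identical
    # (resize + normalize, with no augmentation).  Augmentation can instead be
    # supplied through the ``transform`` argument; a minimal sketch, where the
    # flip probability is an arbitrary illustrative choice:
    #
    #     train_tf = transforms.Compose([
    #         transforms.Resize((256, 256)),
    #         transforms.RandomHorizontalFlip(p=0.5),
    #         transforms.ToTensor(),
    #         transforms.Normalize([0.485, 0.456, 0.406],
    #                              [0.229, 0.224, 0.225]),
    #     ])
    #     train_dataset = ChestXrayDataset(root=".", download=True,
    #                                      mode="train", transform=train_tf)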
    def visualize_batch(self):
loader = DataLoader(self, batch_size=4, shuffle=True)
inputs, labels, fnames = next(iter(loader))
list_imgs = [inputs[i] for i in range(len(inputs))]
self.show(list_imgs, labels, fnames)
def show(self, imgs, labels, fnames):
if not isinstance(imgs, list):
imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * img + mean
inp = np.clip(inp, 0, 1)
axs[0, i].imshow(np.asarray(inp))
axs[0, i].set(xticks=[], yticks=[])
axs[0, i].text(0, -0.2, str(int(labels[i])) + ": " +
self.classes[labels[i]], transform=axs[0, i].transAxes)
axs[0, i].set_title("..."+fnames[i][-12:-5])
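# ``ChestXrayDataset.__getitem__`` returns an (image, target, filename) triple, so a
# training loop unpacks three values per batch.  A minimal sketch, assuming
# ``train_dataset`` is a ChestXrayDataset instance and ``model``, ``criterion`` and
# ``optimizer`` are placeholders defined elsewhere:
#
#     train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
#     for imgs, targets, fnames in train_loader:
#         logits = model(imgs)
#         loss = criterion(logits, targets)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()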
class DSB18Dataset(Dataset):
    r"""PyTorch-friendly DSB18Dataset class.
    The dataset is loaded using the Kaggle API.
    For further information on the raw dataset and nuclei segmentation, please refer to [1]_.
Examples
----------
>>> train_dataset = DSB18Dataset(_path, transform=None, download=False, show=True)
.. image:: ../imgs/DSB18Dataset.png
:width: 600
References
---------------
.. [1] https://www.kaggle.com/c/data-science-bowl-2018/overview
"""
def __init__(self, root: str = ".", download: bool = False, mode: str = "train", shape: int = 512, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, show: bool = True):
tag = "data-science-bowl-2018"
modes = ["train", "val", "test"]
        assert mode in modes, "Available options for mode: train, val, test"
if mode == "train" or mode == "val":
stage = "stage1_train"
else:
stage = "stage1_test"
self.mode = mode
path = os.path.join(root, tag, stage)
if download:
download_datasets(tag, path=root)
extract_zip(os.path.join(root, tag+".zip"),
os.path.join(root, tag))
extract_zip(os.path.join(root, tag, stage + ".zip"), path)
else:
path = os.path.join(root, stage)
self.path = path
self.shape = shape
if self.mode != "test":
seed = 42
train_list = os.listdir(self.path)
train_list, valid_list = train_test_split(
train_list,
test_size=0.2,
random_state=seed
)
if self.mode == "train":
self.folders = train_list
elif self.mode == "val":
self.folders = valid_list
else:
self.folders = os.listdir(self.path)
if transform is None:
self.transform = self.default_transform()
else:
self.transform = transform
if target_transform is None:
self.target_transform = self.default_target_transform()
else:
self.target_transform = target_transform
if show:
self.visualize_batch()
def __len__(self):
return len(self.folders)
def __getitem__(self, idx):
image_folder = os.path.join(self.path, self.folders[idx], 'images/')
fname = os.listdir(image_folder)[0]
image_path = os.path.join(image_folder, fname)
img = Image.open(image_path).convert('RGB')
img = self.transform(img)
if self.mode != "test":
mask_folder = os.path.join(self.path, self.folders[idx], 'masks/')
mask = self.get_mask(mask_folder)
mask = self.target_transform(mask)
sample = (img, mask, fname)
else:
sample = (img, fname)
return sample
def get_mask(self, mask_folder):
mask = np.zeros((self.shape, self.shape, 1), dtype=bool)
for mask_ in os.listdir(mask_folder):
mask_ = io.imread(os.path.join(mask_folder, mask_))
mask_ = transform.resize(mask_, (self.shape, self.shape))
mask_ = np.expand_dims(mask_, axis=-1)
mask = np.maximum(mask, mask_)
return mask
def default_transform(self):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((self.shape, self.shape))
])
return transform
def default_target_transform(self):
target_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((self.shape, self.shape))
])
return target_transform
    def visualize_batch(self):
loader = DataLoader(self, shuffle=True, batch_size=4)
if self.mode != "test":
imgs, masks, fnames = next(iter(loader))
else:
imgs, fnames = next(iter(loader))
batch_inputs = F.convert_image_dtype(imgs, dtype=torch.uint8)
if self.mode != "test":
batch_outputs = F.convert_image_dtype(masks, dtype=torch.bool)
list_imgs = [
draw_segmentation_masks(
img, masks=mask, alpha=0.6, colors=(102, 255, 178))
for img, mask in zip(batch_inputs, batch_outputs)
]
else:
list_imgs = [imgs[i] for i in range(len(imgs))]
self.show(list_imgs, fnames)
def show(self, imgs, fnames):
if not isinstance(imgs, list):
imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.detach()
img = F.to_pil_image(img)
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
axs[0, i].set_title("..."+fnames[i][-10:-4])
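# In train/val mode ``DSB18Dataset.__getitem__`` yields (image, mask, filename), where
# the default target transform returns the mask as a floating-point tensor of shape
# (1, shape, shape).  A rough sketch of feeding it to a binary segmentation objective,
# with ``seg_model`` as a placeholder network that outputs (N, 1, H, W) logits:
#
#     train_ds = DSB18Dataset(root=".", download=True, mode="train", show=False)
#     loader = DataLoader(train_ds, batch_size=4, shuffle=True)
#     criterion = torch.nn.BCEWithLogitsLoss()
#     for imgs, masks, fnames in loader:
#         loss = criterion(seg_model(imgs), masks.float())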
class HistocancerDataset(Dataset):
    r"""PyTorch-friendly HistocancerDataset class.
    The dataset is loaded using the Kaggle API.
    For further information on the raw dataset and tumor classification, please refer to [1]_.
Examples
----------
>>> train_dataset = HistocancerDataset(root=".", download=False, mode="train")
.. image:: ../imgs/HistocancerDataset.png
:width: 600
References
---------------
    .. [1] https://www.kaggle.com/c/histopathologic-cancer-detection/data
"""
def __init__(self, root: str = ".", mode: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, show: bool = True):
tag = "histopathologic-cancer-detection"
modes = ["train", "val", "test"]
assert mode in modes, "Available options for mode: train, val, test"
self.mode = mode
if download:
download_datasets(tag, path=root)
extract_zip(os.path.join(root, tag+".zip"),
os.path.join(root, tag))
self.path = os.path.join(root, tag)
else:
self.path = os.path.join(root)
if self.mode != "test":
self.csv_path = os.path.join(self.path, "train_labels.csv")
self.img_path = os.path.join(self.path, "train")
self.labels = pd.read_csv(self.csv_path)
train_data, val_data = train_test_split(
self.labels, stratify=self.labels.label, test_size=0.1)
if self.mode == "train":
data = train_data
elif self.mode == "val":
data = val_data
self.data = data.values
else:
self.img_path = os.path.join(self.path, "test")
self.data = os.listdir(self.img_path)
if transform is None:
self.transform = self.default_transform(mode)
else:
self.transform = transform
self.target_transform = target_transform
if show:
self.visualize_batch()
def __len__(self):
return len(self.data)
def __getitem__(self, index):
if self.mode != "test":
fname, label = self.data[index]
img_path = os.path.join(self.img_path, fname+'.tif')
else:
fname = self.data[index]
img_path = os.path.join(self.img_path, fname)
img = Image.open(img_path).convert("RGB")
if self.transform is not None:
img = self.transform(img)
        if self.mode != "test":
            if self.target_transform is not None:
                label = self.target_transform(label)
            sample = (img, label, fname)
        else:
            sample = (img, fname)
return sample
def default_transform(self, mode):
if mode == "train":
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
elif mode == "val" or mode == "test":
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
return transform
    def visualize_batch(self):
loader = DataLoader(self, batch_size=4, shuffle=True)
if self.mode != "test":
imgs, labels, fnames = next(iter(loader))
else:
imgs, fnames = next(iter(loader))
labels = None
list_imgs = [imgs[i] for i in range(len(imgs))]
self.show(list_imgs, fnames, labels)
def show(self, imgs, fnames, labels=None):
if not isinstance(imgs, list):
imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * img + mean
inp = np.clip(inp, 0, 1)
axs[0, i].imshow(np.asarray(inp))
axs[0, i].set(xticklabels=[], yticklabels=[],
xticks=[], yticks=[])
if self.mode != "test":
if labels[i] == 0:
lab = "non-tumor"
else:
lab = "tumor"
axs[0, i].set_title("..."+fnames[i][-6:])
axs[0, i].text(0, -0.2, str(int(labels[i])) + ": " +
lab, transform=axs[0, i].transAxes)
else:
axs[0, i].set_title("..."+fnames[i][-11:-4])
class RANZCRDataset(Dataset):
    r"""PyTorch-friendly RANZCRDataset class.
    The dataset is loaded using the Kaggle API.
    For further information on the raw dataset and catheter line classification, please refer to [1]_.
Examples
----------
>>> train_dataset = RANZCRDataset(_path_ranzcr, show=True, shape=512)
.. image:: ../imgs/RANZCRDataset.png
:width: 600
References
---------------
.. [1] https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification/data
"""
def __init__(self, root: str = ".", mode: str = "train", shape: int = 256, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, show: bool = True):
tag = "ranzcr-clip-catheter-line-classification"
modes = ["train", "val", "test"]
assert mode in modes, "Available options for mode: train, val, test"
self.mode = mode
if download:
download_datasets(tag, path=root)
extract_zip(os.path.join(root, tag+".zip"),
os.path.join(root, tag))
path = os.path.join(root, tag)
else:
path = root
train_path = os.path.join(path, "train")
test_path = os.path.join(path, "test")
csv_path = os.path.join(path, "train_annotations.csv")
self.data = pd.read_csv(csv_path)
self.labels, self.encoded_labels = self.get_labels()
self.train_list, self.valid_list = self.get_train_valid(train_path)
self.shape = shape
if transform is None:
self.transform = self.default_transform()
else:
self.transform = transform
if target_transform is not None:
self.target_transform = target_transform
if self.mode != "test":
if self.mode == "train":
self.file_list = self.train_list
else:
self.file_list = self.valid_list
else:
self.file_list = glob.glob(test_path+"/*")
if show:
self.visualize_batch()
def __len__(self):
self.filelength = len(self.file_list)
return self.filelength
def __getitem__(self, idx):
if self.mode != "test":
img_path = self.file_list[idx][0]
fname = img_path.split("/")[-1]
else:
img_path = self.file_list[idx]
fname = img_path.split("/")[-1]
img = Image.open(img_path).convert("RGB")
img = self.transform(img)
if self.mode != "test":
label = self.file_list[idx][1]
sample = (img, label, fname)
else:
sample = (img, fname)
return sample
def default_transform(self):
transform = transforms.Compose([
transforms.Resize((self.shape, self.shape)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
return transform
    def visualize_batch(self):
loader = DataLoader(self, batch_size=4, shuffle=True)
if self.mode != "test":
imgs, labels, fnames = next(iter(loader))
else:
imgs, fnames = next(iter(loader))
labels = None
list_imgs = [imgs[i] for i in range(len(imgs))]
self.show(list_imgs, fnames, labels)
def show(self, imgs, fnames, labels=None):
if not isinstance(imgs, list):
imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * img + mean
inp = np.clip(inp, 0, 1)
axs[0, i].imshow(np.asarray(inp))
axs[0, i].set(xticklabels=[], yticklabels=[],
xticks=[], yticks=[])
axs[0, i].set_title("..."+fnames[i][-11:-4])
if self.mode != "test":
lab = self.unique_labels[labels[i]]
axs[0, i].text(0, -0.2, str(int(labels[i])) +
": " + lab, transform=axs[0, i].transAxes)
def get_labels(self):
self.data = self.data.drop(["data"], axis=1)
data_org = self.data['label']
labels = data_org.to_list()
used = set()
self.unique_labels = [
x for x in labels if x not in used and (used.add(x) or True)]
ord_enc = OrdinalEncoder()
self.data[['label']] = ord_enc.fit_transform(self.data[['label']])
self.data.label = self.data.label.astype("int")
label = self.data["label"]
label = label.to_list()
encoded_labels = label
return labels, encoded_labels
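    # ``get_labels`` relies on sklearn's OrdinalEncoder to turn the string label
    # column into integer codes.  A tiny standalone illustration of that mapping
    # (the label strings below are only illustrative):
    #
    #     df = pd.DataFrame({"label": ["CVC - Normal", "ETT - Abnormal", "CVC - Normal"]})
    #     enc = OrdinalEncoder()
    #     enc.fit_transform(df[["label"]])   # e.g. [[0.], [1.], [0.]]
    #     enc.categories_                    # categories in sorted order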
def get_train_valid(self, train_path):
seed = 42
train_list = []
for i in self.data.index:
a = self.data["StudyInstanceUID"].loc[i]
b = train_path + "/" + a + ".jpg"
train_list.append((b, self.data['label'].loc[i]))
train_list, valid_list = train_test_split(train_list,
test_size=0.2,
random_state=seed)
return train_list, valid_list
class RetinopathyDataset(Dataset):
    r"""PyTorch-friendly RetinopathyDataset class.
    The dataset is loaded using the Kaggle API.
    For further information on the raw dataset and blindness detection, please refer to [1]_.
Examples
----------
>>> train_dataset = RetinopathyDataset(".", mode="train", show=True)
.. image:: ../imgs/RetinopathyDataset.png
:width: 600
References
---------------
    .. [1] https://www.kaggle.com/c/aptos2019-blindness-detection/data
"""
def __init__(self, root: str = ".", mode: str = "train", shape: int = 256, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, show: bool = True):
tag = "aptos2019-blindness-detection"
if download:
download_datasets(tag, path=root)
extract_zip(os.path.join(root, tag+".zip"),
os.path.join(root, tag))
path = os.path.join(root, tag)
else:
path = root
self.mode = mode
if mode != "test":
self.csv_path = os.path.join(path, "train.csv")
self.img_path = os.path.join(path, "train_images")
data = pd.read_csv(self.csv_path)
            train_idx, val_idx = train_test_split(
                range(len(data)), test_size=0.1)
train_data = data.iloc[train_idx]
val_data = data.iloc[val_idx]
if self.mode == "train":
self.data = train_data
self.data.reset_index(drop=True, inplace=True)
else:
self.data = val_data
self.data.reset_index(drop=True, inplace=True)
else:
self.img_path = os.path.join(path, "test_images")
self.data = os.listdir(self.img_path)
self.shape = shape
if transform is None:
self.transform = self.default_transform()
else:
self.transform = transform
if target_transform is not None:
self.target_transform = target_transform
if show:
self.visualize_batch()
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
if self.mode != "test":
fname = self.data.loc[idx, 'id_code'] + ".png"
else:
fname = self.data[idx]
img_name = os.path.join(self.img_path, fname)
img = Image.open(img_name).convert("RGB")
img = self.transform(img)
if self.mode != "test":
label = torch.tensor(self.data.loc[idx, 'diagnosis'])
sample = (img, label, fname)
else:
sample = (img, fname)
return sample
def default_transform(self):
transform = transforms.Compose([
transforms.Resize((self.shape, self.shape)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
return transform
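    # The default pipeline above applies RandomHorizontalFlip regardless of mode, so
    # for deterministic validation or inference a caller may prefer to pass an explicit
    # transform without the random flip; a minimal sketch:
    #
    #     eval_tf = transforms.Compose([
    #         transforms.Resize((256, 256)),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    #     ])
    #     val_dataset = RetinopathyDataset(".", mode="val", transform=eval_tf, show=False)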
    def visualize_batch(self):
loader = DataLoader(self, batch_size=4, shuffle=True)
if self.mode != "test":
imgs, labels, fnames = next(iter(loader))
else:
imgs, fnames = next(iter(loader))
labels = None
list_imgs = [imgs[i] for i in range(len(imgs))]
self.show(list_imgs, fnames, labels)
def show(self, imgs, fnames, labels=None):
if not isinstance(imgs, list):
imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * img + mean
inp = np.clip(inp, 0, 1)
axs[0, i].imshow(np.asarray(inp))
axs[0, i].set(xticklabels=[], yticklabels=[],
xticks=[], yticks=[])
axs[0, i].set_title("..."+fnames[i][-10:-4])
if self.mode != "test":
axs[0, i].text(0, -0.2, "Severity: " +
str(int(labels[i])), transform=axs[0, i].transAxes)
class TestBiodatasets(unittest.TestCase):
def testChestXrayDataset(self):
_path = "/home/data/02_SSD4TB/suzy/datasets/public/chest-xray"
valid_dataset = ChestXrayDataset(
root=_path, download=False, mode="val", show=False)
print(valid_dataset)
def testDSB18Dataset(self):
_path = "/home/data/02_SSD4TB/suzy/datasets/public/data-science-bowl-2018"
train_dataset = DSB18Dataset(
root=_path, transform=None, mode="train", download=False, show=False)
print(train_dataset)
def testHistocancerDataset(self):
_path = "/home/data/02_SSD4TB/suzy/datasets/public/histopathologic-cancer-detection"
train_dataset = HistocancerDataset(
root=_path, download=False, mode="train", show=False)
print(train_dataset)
def testRANZCRDataset(self):
_path = "/home/data/02_SSD4TB/suzy/datasets/public/ranzcr-clip-catheter-line-classification"
train_dataset = RANZCRDataset(
root=_path, show=False, shape=512, mode="train", download=False)
print(train_dataset)
def testRetinopathyDataset(self):
_path = "/home/data/02_SSD4TB/suzy/datasets/public/aptos2019-blindness-detection"
train_dataset = RetinopathyDataset(
root=_path, mode="train", show=False, download=False)
print(train_dataset)
# unittest.main()
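# The smoke tests above point at machine-specific dataset paths; with those adjusted,
# they can be run by uncommenting unittest.main() above or, assuming this file is
# importable as a module, with something like:
#
#     python -m unittest <module_name>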