import sys
sys.path.insert(0, "..")
import webdataset as wds
from torchvision import transforms
from pprint import pprint

# Shards are read through webdataset's "pipe:" scheme, which runs the command
# and reads the shard from its output. curl flags:
#   -s  no progress meter        -S  still print errors when -s is set
#   -f  exit non-zero on HTTP errors instead of emitting the server's error
#       page as if it were tar data
#   -L  follow redirects
# HTTPS is used so the downloaded shards cannot be tampered with in transit.
bucket = "pipe:curl -sSfL https://storage.googleapis.com/nvdata-openimages/"
# The full training split is openimages-train-{000000..000554}.tar (555 shards);
# this example only uses the first 10.
shards_train = bucket + "openimages-train-{000000..000009}.tar"

# Training-time augmentation applied to each decoded PIL image, in order.
# TODO: append transforms.Normalize(DATA_MEANS, DATA_STD) once per-channel
# statistics exist -- neither DATA_MEANS nor DATA_STD is defined in this script.
_train_augmentations = [
    # Random left/right mirror.
    transforms.RandomHorizontalFlip(),
    # Random crop covering 80-100% of the image area with aspect ratio in
    # [0.9, 1.1], resized to 32x32.
    transforms.RandomResizedCrop(size=(32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    # PIL image -> tensor.
    transforms.ToTensor(),
]
train_transform = transforms.Compose(_train_augmentations)

# Streaming training pipeline. Stages run in list order, each consuming the
# iterator produced by the stage before it.
_train_stages = [
    # Expand the brace pattern in shards_train into one URL per shard.
    wds.SimpleShardList(shards_train),
    # Deterministically shuffle shard URLs (buffer size 100); this runs before
    # tar expansion, so it reorders whole shards, not individual samples.
    wds.detshuffle(100),
    # Download each shard into the local cache (verbose=True logs the
    # download/open steps), then split the tar stream into samples.
    wds.cached_tarfile_to_samples(verbose=True),
    # Decode image entries to PIL images; JSON entries are decoded as well.
    wds.decode("pil"),
    # Select (image, annotations): the png or jpg entry plus the json entry.
    wds.to_tuple("png;jpg json"),
    # Augment the image only; None leaves the annotations untouched.
    wds.map_tuple(train_transform, None),
]
train_ds = wds.DataPipeline(*_train_stages)

# Smoke test: pull a single (image, annotations) sample through the pipeline.
image, info = next(iter(train_ds))
print(image.shape)  # tensor produced by ToTensor in train_transform
print(info)  # the sample's JSON annotations, passed through unchanged

# Output recorded from an earlier run. It is kept as comments because, as bare
# code, `torch.Size(...)` would raise NameError: `torch` is not imported here.
#
# downloading http://storage.googleapis.com/nvdata-openimages/openimages-train-000000.tar to ./_cache/nvdata-openimages/openimages-train-000000.tar
#
# torch.Size([3, 32, 32])
# [{'ImageID': 'e39871fd9fd74f55', 'Source': 'xclick', 'LabelName': '/m/01g317', 'Confidence': '1', 'XMin': '0.389323', 'XMax': '0.661458', 'YMin': '0.094727', 'YMax': '0.571289', 'IsOccluded': '1', 'IsTruncated': '0', 'IsGroupOf': '0', 'IsDepiction': '1', 'IsInside': '0'},
#  {'ImageID': 'e39871fd9fd74f55', 'Source': 'xclick', 'LabelName': '/m/05y5lj', 'Confidence': '1', 'XMin': '0.402344', 'XMax': '0.427083', 'YMin': '0.468750', 'YMax': '0.485352', 'IsOccluded': '1', 'IsTruncated': '0', 'IsGroupOf': '0', 'IsDepiction': '1', 'IsInside': '0'},
#  {'ImageID': 'e39871fd9fd74f55', 'Source': 'xclick', 'LabelName': '/m/05y5lj', 'Confidence': '1', 'XMin': '0.516927', 'XMax': '0.597656', 'YMin': '0.093750', 'YMax': '0.153320', 'IsOccluded': '0', 'IsTruncated': '0', 'IsGroupOf': '0', 'IsDepiction': '1', 'IsInside': '0'}]
#
# opening ./_cache/nvdata-openimages/openimages-train-000000.tar