webdataset.mix
Classes for mixing samples from multiple sources.
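For orientation, here is a minimal usage sketch (not taken from the library docs): the shard URLs are placeholders, and the exact `wds.WebDataset` keyword arguments may differ between WebDataset versions.

    import webdataset as wds
    from webdataset.mix import RandomMix, RoundRobin

    # Placeholder shard URLs -- substitute real ones.
    ds_a = wds.WebDataset("shards-a-{000..009}.tar")
    ds_b = wds.WebDataset("shards-b-{000..009}.tar")

    # Draw samples from ds_a and ds_b at random, roughly 4:1.
    mixed = RandomMix([ds_a, ds_b], probs=[4, 1])

    # Or alternate strictly between the two sources.
    alternating = RoundRobin([ds_a, ds_b])

    for sample in mixed:
        ...  # each sample is a dict of the files belonging to one training example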
View Source
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Classes for mixing samples from multiple sources."""

import itertools, os, random, time, sys
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .pytorch import TorchTensor
from torch.utils.data import IterableDataset
from .utils import PipelineStage


def round_robin_shortest(*sources):
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
            yield sample
        except StopIteration:
            break
        i += 1


def round_robin_longest(*sources):
    i = 0
    while len(sources) > 0:
        try:
            sample = next(sources[i])
            i += 1
            yield sample
        except StopIteration:
            del sources[i]


class RoundRobin(IterableDataset):
    def __init__(self, datasets, longest=False):
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)


def random_samples(sources, probs=None, longest=False):
    if probs is None:
        probs = [1] * len(sources)
    else:
        probs = list(probs)
    while len(sources) > 0:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = np.searchsorted(cum, r)
        try:
            yield next(sources[i])
        except StopIteration:
            if longest:
                del sources[i]
                del probs[i]
            else:
                break


class RandomMix(IterableDataset):
    def __init__(self, datasets, probs=None, longest=False):
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
View Source
def round_robin_shortest(*sources):
    i = 0
    while True:
        try:
            sample = next(sources[i % len(sources)])
            yield sample
        except StopIteration:
            break
        i += 1
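A small illustration of the behavior (not part of the library source): the generator alternates between the sources and stops as soon as one of them is exhausted.

    >>> list(round_robin_shortest(iter([1, 2, 3]), iter(["a", "b"])))
    [1, 'a', 2, 'b', 3]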
View Source
def round_robin_longest(*sources):
    i = 0
    while len(sources) > 0:
        try:
            sample = next(sources[i])
            i += 1
            yield sample
        except StopIteration:
            del sources[i]
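As displayed, `del sources[i]` operates on the tuple produced by `*sources` and the index `i` is never wrapped around the remaining sources, so this generator is best read as a statement of intent. A hedged sketch of that intent (keep cycling over the sources, dropping each one as it is exhausted, until none remain) might look like the following; it is not the library's implementation.

    def round_robin_longest_sketch(*sources):
        active = list(sources)      # mutable copy so exhausted sources can be removed
        i = 0
        while active:
            i %= len(active)        # wrap the index around the remaining sources
            try:
                yield next(active[i])
                i += 1
            except StopIteration:
                del active[i]       # drop the exhausted source, keep cycling over the rest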
View Source
class RoundRobin(IterableDataset):
    def __init__(self, datasets, longest=False):
        self.datasets = datasets
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        if self.longest:
            return round_robin_longest(*sources)
        else:
            return round_robin_shortest(*sources)
An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream.
All subclasses should overwrite :meth:`__iter__`, which would return an iterator of samples in this dataset.

When a subclass is used with :class:`~torch.utils.data.DataLoader`, each item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator. When :attr:`num_workers > 0`, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker process, returns information about the worker. It can be used in either the dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader`'s :attr:`worker_init_fn` option to modify each copy's behavior.

Example 1: splitting workload across all workers in :meth:`__iter__`::
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end > start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]
>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
... worker_info = torch.utils.data.get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
View Source
def __init__(self, datasets, longest=False):
    self.datasets = datasets
    self.longest = longest
Inherited Members
- torch.utils.data.dataset.IterableDataset
- functions
- reduce_ex_hook
- register_function
- register_datapipe_as_function
- set_reduce_ex_hook
- type
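A toy illustration (plain Python lists standing in for datasets, not from the library docs): with the default `longest=False`, iteration alternates between the sources and stops once any of them runs out.

    >>> rr = RoundRobin([[1, 2, 3], ["a", "b"]])
    >>> list(rr)
    [1, 'a', 2, 'b', 3]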
View Source
def random_samples(sources, probs=None, longest=False):
    if probs is None:
        probs = [1] * len(sources)
    else:
        probs = list(probs)
    while len(sources) > 0:
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        r = random.random()
        i = np.searchsorted(cum, r)
        try:
            yield next(sources[i])
        except StopIteration:
            if longest:
                del sources[i]
                del probs[i]
            else:
                break
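To make the selection rule concrete, here is a small illustration (not from the library docs) of how the cumulative-probability lookup in `random_samples` maps a uniform draw `r` to a source index, assuming `probs=[3, 1]`.

    >>> import numpy as np
    >>> probs = [3, 1]
    >>> cum = (np.array(probs) / np.sum(probs)).cumsum()
    >>> cum
    array([0.75, 1.  ])
    >>> int(np.searchsorted(cum, 0.6))   # r = 0.6 -> first source (~75% of draws)
    0
    >>> int(np.searchsorted(cum, 0.9))   # r = 0.9 -> second source (~25% of draws)
    1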
View Source
class RandomMix(IterableDataset):
    def __init__(self, datasets, probs=None, longest=False):
        self.datasets = datasets
        self.probs = probs
        self.longest = longest

    def __iter__(self):
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        return random_samples(sources, self.probs, longest=self.longest)
An iterable Dataset. (RandomMix inherits the same :class:`~torch.utils.data.IterableDataset` docstring that is reproduced in full under RoundRobin above, including the worker-splitting examples.)
View Source
def __init__(self, datasets, probs=None, longest=False):
    self.datasets = datasets
    self.probs = probs
    self.longest = longest
Inherited Members
- torch.utils.data.dataset.IterableDataset
- functions
- reduce_ex_hook
- register_function
- register_datapipe_as_function
- set_reduce_ex_hook
- type
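A final toy sketch (not from the library docs) contrasting the `longest` flag: with the default `longest=False`, iteration stops as soon as one source is exhausted; with `longest=True`, exhausted sources are dropped and sampling continues from the remaining ones until all are empty.

    from webdataset.mix import RandomMix

    # The order of samples is random, so only the totals are predictable here.
    full = list(RandomMix([range(100), "ab"], probs=[1, 1], longest=True))
    assert len(full) == 102       # every sample from both sources is eventually yielded

    short = list(RandomMix([range(100), "ab"], probs=[1, 1], longest=False))
    assert len(short) <= 102      # usually far fewer: stops once either source runs out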