webdataset.filters
A collection of iterators for data transformations.
These functions are plain iterator functions. You can find curried versions in webdataset.filters, and you can find IterableDataset wrappers in webdataset.processing.
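A quick sketch of the two calling styles (the toy samples here are made up for illustration): _shuffle is the plain iterator function and takes the data stream as its first argument, while shuffle, its curried counterpart produced by pipelinefilter, binds the remaining arguments first and is applied to the stream afterwards.

from webdataset import filters

# Hypothetical toy stream; any iterator of samples works.
samples = [{"__key__": str(i), "x": i} for i in range(10)]

# Plain style: data first, parameters after.
out1 = list(filters._shuffle(iter(samples), bufsize=4, initial=2))

# Curried style: bind the parameters now, supply the data later.
stage = filters.shuffle(bufsize=4, initial=2)
out2 = list(stage(iter(samples)))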
View Source
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""A collection of iterators for data transformations.

These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""

import itertools, os, random, sys, time
from functools import reduce, wraps

import numpy as np

from . import autodecode, utils
from .pytorch import TorchTensor
from .utils import PipelineStage


class FilterFunction(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f, *args, **kw):
        """Create a curried function."""
        self.f = f
        self.args = args
        self.kw = kw

    def __call__(self, data):
        """Call the curried function with the given argument."""
        return self.f(data, *self.args, **self.kw)

    def __str__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"

    def __repr__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"


class RestCurried(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f):
        """Store the function for future currying."""
        self.f = f

    def __call__(self, *args, **kw):
        """Curry with the given arguments."""
        return FilterFunction(self.f, *args, **kw)


def pipelinefilter(f):
    """Turn the decorated function into one that is partially applied for all arguments other than the first."""
    result = RestCurried(f)
    return result


def reraise_exception(exn):
    """Reraises the given exception; used as a handler.

    :param exn: exception
    """
    raise exn


def identity(x):
    """Return the argument."""
    return x


def compose2(f, g):
    """Compose two functions, g(f(x))."""
    return lambda x: g(f(x))


def compose(*args):
    """Compose a sequence of functions (left-to-right)."""
    return reduce(compose2, args)


def pipeline(source, *args):
    """Write an input pipeline; first argument is source, rest are filters."""
    if len(args) == 0:
        return source
    return compose(*args)(source)


def getfirst(a, keys, default=None, missing_is_error=True):
    """Get the first matching key from a dictionary.

    Keys can be specified as a list, or as a string of keys separated by ';'.
    """
    if isinstance(keys, str):
        assert " " not in keys
        keys = keys.split(";")
    for k in keys:
        if k in a:
            return a[k]
    if missing_is_error:
        raise ValueError(f"didn't find {keys} in {list(a.keys())}")
    return default


def parse_field_spec(fields):
    """Parse a specification for a list of fields to be extracted.

    Keys are separated by spaces in the spec. Each key can itself
    be composed of key alternatives separated by ';'.
    """
    if isinstance(fields, str):
        fields = fields.split()
    return [field.split(";") for field in fields]


def transform_with(sample, transformers):
    """Transform a list of values using a list of functions.

    sample: list of values
    transformers: list of functions

    If there are fewer transformers than inputs, or if a transformer
    function is None, then the identity function is used for the
    corresponding sample fields.
    """
    if transformers is None or len(transformers) == 0:
        return sample
    result = list(sample)
    assert len(transformers) <= len(sample)
    for i in range(len(transformers)):  # skipcq: PYL-C0200
        f = transformers[i]
        if f is not None:
            result[i] = f(sample[i])
    return result


###
# Iterators
###


def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
    """Print information about the samples that are passing through.

    :param data: source iterator
    :param fmt: format statement (using sample dict as keyword)
    :param n: when to stop
    :param every: how often to print
    :param width: maximum width
    :param stream: output stream
    :param name: identifier printed before any output
    """
    for i, sample in enumerate(data):
        if i < n or (every > 0 and (i + 1) % every == 0):
            if fmt is None:
                print("---", name, file=stream)
                for k, v in sample.items():
                    print(k, repr(v)[:width], file=stream)
            else:
                print(fmt.format(**sample), file=stream)
        yield sample


info = pipelinefilter(_info)


def pick(buf, rng):
    k = rng.randint(0, len(buf) - 1)
    sample = buf[k]
    buf[k] = buf[-1]
    buf.pop()
    return sample


def _shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
    """Shuffle the data in the stream.

    This uses a buffer of size `bufsize`. Shuffling at startup is less random;
    this is traded off against yielding samples quickly.

    data: iterator
    bufsize: buffer size for shuffling
    returns: iterator
    rng: either random module or random.Random instance
    """
    if rng is None:
        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
    initial = min(initial, bufsize)
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) < bufsize:
            try:
                buf.append(next(data))  # skipcq: PYL-R1708
            except StopIteration:
                pass
        if len(buf) >= initial:
            yield pick(buf, rng)
    while len(buf) > 0:
        yield pick(buf, rng)


shuffle = pipelinefilter(_shuffle)


class detshuffle(PipelineStage):
    def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        self.epoch += 1
        rng = random.Random()
        rng.seed((self.seed, self.epoch))
        return _shuffle(src, self.bufsize, self.initial, rng)


def _select(data, predicate):
    """Select samples based on a predicate.

    :param data: source iterator
    :param predicate: predicate (function)
    """
    for sample in data:
        if predicate(sample):
            yield sample


select = pipelinefilter(_select)


def _log_keys(data, logfile=None):
    import fcntl

    if logfile is None or logfile == "":
        for sample in data:
            yield sample
    else:
        with open(logfile, "a") as stream:
            for i, sample in enumerate(data):
                buf = f"{i}\t{sample.get('__worker__')}\t{sample.get('__rank__')}\t{sample.get('__key__')}\n"
                try:
                    fcntl.flock(stream.fileno(), fcntl.LOCK_EX)
                    stream.write(buf)
                finally:
                    fcntl.flock(stream.fileno(), fcntl.LOCK_UN)
                yield sample


log_keys = pipelinefilter(_log_keys)


def _decode(data, *args, handler=reraise_exception, **kw):
    """Decode data based on the decoding functions given as arguments."""
    decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x
    handlers = [decoder(x) for x in args]
    f = autodecode.Decoder(handlers, **kw)

    for sample in data:
        assert isinstance(sample, dict), sample
        try:
            decoded = f(sample)
        except Exception as exn:  # skipcq: PYL-W0703
            if handler(exn):
                continue
            else:
                break
        yield decoded


decode = pipelinefilter(_decode)


def _map(data, f, handler=reraise_exception):
    """Map samples."""
    for sample in data:
        try:
            result = f(sample)
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        if result is None:
            continue
        if isinstance(sample, dict) and isinstance(result, dict):
            result["__key__"] = sample.get("__key__")
        yield result


map = pipelinefilter(_map)


def _rename(data, handler=reraise_exception, keep=True, **kw):
    """Rename samples based on keyword arguments."""
    for sample in data:
        try:
            if not keep:
                yield {
                    k: getfirst(sample, v, missing_is_error=True)
                    for k, v in kw.items()
                }
            else:

                def listify(v):
                    return v.split(";") if isinstance(v, str) else v

                to_be_replaced = {x for v in kw.values() for x in listify(v)}
                result = {k: v for k, v in sample.items() if k not in to_be_replaced}
                result.update(
                    {
                        k: getfirst(sample, v, missing_is_error=True)
                        for k, v in kw.items()
                    }
                )
                yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break


rename = pipelinefilter(_rename)


def _associate(data, associator, **kw):
    """Associate additional data with samples."""
    for sample in data:
        if callable(associator):
            extra = associator(sample["__key__"])
        else:
            extra = associator.get(sample["__key__"], {})
        sample.update(extra)  # destructive
        yield sample


associate = pipelinefilter(_associate)


def _map_dict(data, handler=reraise_exception, **kw):
    """Map the entries in a dict sample with individual functions."""
    assert len(list(kw.keys())) > 0
    for key, f in kw.items():
        assert callable(f), (key, f)

    for sample in data:
        assert isinstance(sample, dict)
        try:
            for k, f in kw.items():
                sample[k] = f(sample[k])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield sample


map_dict = pipelinefilter(_map_dict)


def _to_tuple(
    data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None
):
    """Convert dict samples to tuples."""
    if none_is_error is None:
        none_is_error = missing_is_error
    if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
        args = args[0].split()

    for sample in data:
        try:
            result = tuple(
                [getfirst(sample, f, missing_is_error=missing_is_error) for f in args]
            )
            if none_is_error and any(x is None for x in result):
                raise ValueError(f"to_tuple {args} got {sample.keys()}")
            yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break


to_tuple = pipelinefilter(_to_tuple)


def _map_tuple(data, *args, handler=reraise_exception):
    """Map the entries of a tuple with individual functions."""
    args = [f if f is not None else utils.identity for f in args]
    for f in args:
        assert callable(f), f
    for sample in data:
        assert isinstance(sample, (list, tuple))
        sample = list(sample)
        n = min(len(args), len(sample))
        try:
            for i in range(n):
                sample[i] = args[i](sample[i])
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break
        yield tuple(sample)


map_tuple = pipelinefilter(_map_tuple)


def default_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a collection of samples (dictionaries) and create a batch.

    If `combine_tensors` is True, `ndarray` objects are combined into
    tensor batches.

    :param list samples: list of samples
    :param bool combine_tensors: whether to turn lists of ndarrays/tensors into a single batched ndarray/tensor
    :param bool combine_scalars: whether to turn lists of scalars into a single ndarray
    :returns: single sample consisting of a batch
    :rtype: list
    """
    assert isinstance(samples[0], (list, tuple)), type(samples[0])
    batched = list(zip(*samples))
    result = []
    for b in batched:
        if isinstance(b[0], (int, float)):
            if combine_scalars:
                b = np.array(list(b))
        elif isinstance(b[0], TorchTensor):
            if combine_tensors:
                import torch

                b = torch.stack(list(b))
        elif isinstance(b[0], np.ndarray):
            if combine_tensors:
                b = np.array(list(b))
        else:
            b = list(b)
        result.append(b)
    return result


def _batched(
    data,
    batchsize=20,
    collation_fn=default_collation_fn,
    partial=True,
):
    """Create batches of the given size.

    :param data: iterator
    :param batchsize: target batch size
    :param collation_fn: collation function applied to each batch
    :param partial: return partial batches
    :returns: iterator
    """
    batch = []
    for sample in data:
        if len(batch) >= batchsize:
            if collation_fn is not None:
                batch = collation_fn(batch)
            yield batch
            batch = []
        batch.append(sample)
    if len(batch) == 0:
        return
    elif len(batch) == batchsize or partial:
        if collation_fn is not None:
            batch = collation_fn(batch)
        yield batch


batched = pipelinefilter(_batched)


def _unlisted(data):
    """Turn batched data back into unbatched data."""
    for batch in data:
        assert isinstance(batch, list), batch
        for sample in batch:
            yield sample


unlisted = pipelinefilter(_unlisted)


def _unbatched(data):
    """Turn batched data back into unbatched data."""
    for sample in data:
        assert isinstance(sample, (tuple, list)), sample
        assert len(sample) > 0
        for i in range(len(sample[0])):
            yield tuple(x[i] for x in sample)


unbatched = pipelinefilter(_unbatched)


def _rsample(data, p=0.5):
    """Randomly subsample a stream of data."""
    assert p >= 0.0 and p <= 1.0
    for sample in data:
        if random.uniform(0.0, 1.0) < p:
            yield sample


rsample = pipelinefilter(_rsample)

slice = pipelinefilter(itertools.islice)
View Source
class FilterFunction(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f, *args, **kw):
        """Create a curried function."""
        self.f = f
        self.args = args
        self.kw = kw

    def __call__(self, data):
        """Call the curried function with the given argument."""
        return self.f(data, *self.args, **self.kw)

    def __str__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"

    def __repr__(self):
        """Compute a string representation."""
        return f"<{self.f.__name__} {self.args} {self.kw}>"
Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
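The payoff of storing the function and its arguments as plain attributes (rather than closing over them in a lambda) is that the curried stage survives pickling, e.g. when pipeline stages are shipped to DataLoader worker processes. A minimal sketch, using the module's own _rsample:

import pickle
from webdataset.filters import FilterFunction, _rsample

stage = FilterFunction(_rsample, p=0.25)
restored = pickle.loads(pickle.dumps(stage))  # round-trips; a lambda closure would not
print(restored)  # <_rsample () {'p': 0.25}>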
View Source
def __init__(self, f, *args, **kw):
    """Create a curried function."""
    self.f = f
    self.args = args
    self.kw = kw
Create a curried function.
View Source
class RestCurried(object):
    """Helper class for currying pipeline stages.

    We use this roundabout construct because it can be pickled.
    """

    def __init__(self, f):
        """Store the function for future currying."""
        self.f = f

    def __call__(self, *args, **kw):
        """Curry with the given arguments."""
        return FilterFunction(self.f, *args, **kw)
Helper class for currying pipeline stages.
We use this roundabout construct because it can be pickled.
View Source
def __init__(self, f):
    """Store the function for future currying."""
    self.f = f
Store the function for future currying.
View Source
def pipelinefilter(f):
    """Turn the decorated function into one that is partially applied for all arguments other than the first."""
    result = RestCurried(f)
    return result
Turn the decorated function into one that is partially applied for all arguments other than the first.
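This is how all the curried filters below (shuffle, select, map, ...) are built from their underscore-prefixed iterator functions. A sketch of defining your own stage the same way (the name _scale_x is hypothetical):

from webdataset.filters import pipelinefilter

def _scale_x(data, factor=2):
    """Multiply the 'x' field of every sample by factor."""
    for sample in data:
        sample["x"] *= factor
        yield sample

scale_x = pipelinefilter(_scale_x)

stream = iter([{"x": 1}, {"x": 2}])
print(list(scale_x(factor=3)(stream)))  # [{'x': 3}, {'x': 6}]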
View Source
def reraise_exception(exn):
    """Reraises the given exception; used as a handler.

    :param exn: exception
    """
    raise exn
Reraises the given exception; used as a handler.
:param exn: exception
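Handlers follow a simple contract in the filters below: they receive the caught exception, and a truthy return value makes the filter skip the offending sample, while a falsy one stops the stream (reraise_exception simply propagates it). A sketch of a skip-and-warn handler (the function name is made up for this example):

import sys

def warn_and_continue(exn):
    """Hypothetical handler: report the error, then skip the sample."""
    print(repr(exn), file=sys.stderr)
    return True

from webdataset.filters import map as wds_map

stage = wds_map(lambda s: {"y": 1 / s["x"]}, handler=warn_and_continue)
result = list(stage(iter([{"x": 2}, {"x": 0}, {"x": 4}])))  # the {"x": 0} sample is skipped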
View Source
def identity(x):
    """Return the argument."""
    return x
Return the argument.
View Source
def compose2(f, g):
    """Compose two functions, g(f(x))."""
    return lambda x: g(f(x))
Compose two functions, g(f(x)).
View Source
def compose(*args):
    """Compose a sequence of functions (left-to-right)."""
    return reduce(compose2, args)
Compose a sequence of functions (left-to-right).
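Note the direction: compose(f, g) applies f first, then g. For example:

from webdataset.filters import compose

h = compose(lambda x: x + 1, lambda x: x * 10)
print(h(3))  # (3 + 1) * 10 == 40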
View Source
def pipeline(source, *args):
    """Write an input pipeline; first argument is source, rest are filters."""
    if len(args) == 0:
        return source
    return compose(*args)(source)
Write an input pipeline; first argument is source, rest are filters.
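For example, chaining two curried filters over a source iterator:

from webdataset.filters import pipeline, select
from webdataset.filters import map as wds_map

stream = pipeline(
    iter(range(10)),
    select(lambda x: x % 2 == 0),
    wds_map(lambda x: x * x),
)
print(list(stream))  # [0, 4, 16, 36, 64]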
View Source
def getfirst(a, keys, default=None, missing_is_error=True):
    """Get the first matching key from a dictionary.

    Keys can be specified as a list, or as a string of keys separated by ';'.
    """
    if isinstance(keys, str):
        assert " " not in keys
        keys = keys.split(";")
    for k in keys:
        if k in a:
            return a[k]
    if missing_is_error:
        raise ValueError(f"didn't find {keys} in {list(a.keys())}")
    return default
Get the first matching key from a dictionary.
Keys can be specified as a list, or as a string of keys separated by ';'.
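For example:

from webdataset.filters import getfirst

sample = {"jpg": b"...", "cls": b"3"}
getfirst(sample, "png;jpg;jpeg")  # returns sample["jpg"], the first alternative present
getfirst(sample, "txt", default=None, missing_is_error=False)  # returns None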
View Source
def parse_field_spec(fields):
    """Parse a specification for a list of fields to be extracted.

    Keys are separated by spaces in the spec. Each key can itself
    be composed of key alternatives separated by ';'.
    """
    if isinstance(fields, str):
        fields = fields.split()
    return [field.split(";") for field in fields]
Parse a specification for a list of fields to be extracted.
Keys are separated by spaces in the spec. Each key can itself be composed of key alternatives separated by ';'.
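For example:

from webdataset.filters import parse_field_spec

parse_field_spec("jpg;png cls")  # [['jpg', 'png'], ['cls']]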
View Source
def transform_with(sample, transformers):
    """Transform a list of values using a list of functions.

    sample: list of values
    transformers: list of functions

    If there are fewer transformers than inputs, or if a transformer
    function is None, then the identity function is used for the
    corresponding sample fields.
    """
    if transformers is None or len(transformers) == 0:
        return sample
    result = list(sample)
    assert len(transformers) <= len(sample)
    for i in range(len(transformers)):  # skipcq: PYL-C0200
        f = transformers[i]
        if f is not None:
            result[i] = f(sample[i])
    return result
Transform a list of values using a list of functions.
sample: list of values
transformers: list of functions
If there are fewer transformers than inputs, or if a transformer function is None, then the identity function is used for the corresponding sample fields.
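For example:

from webdataset.filters import transform_with

transform_with([1, "abc", 3.0], [lambda x: x + 1, None])
# [2, 'abc', 3.0] -- the None transformer and the missing third one act as identity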
View Source
def pick(buf, rng):
    k = rng.randint(0, len(buf) - 1)
    sample = buf[k]
    buf[k] = buf[-1]
    buf.pop()
    return sample
View Source
class detshuffle(PipelineStage):
    def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        self.epoch += 1
        rng = random.Random()
        rng.seed((self.seed, self.epoch))
        return _shuffle(src, self.bufsize, self.initial, rng)
View Source
def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
    self.bufsize = bufsize
    self.initial = initial
    self.seed = seed
    self.epoch = epoch
View Source
def run(self, src):
    self.epoch += 1
    rng = random.Random()
    rng.seed((self.seed, self.epoch))
    return _shuffle(src, self.bufsize, self.initial, rng)
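detshuffle differs from shuffle in that the RNG is seeded with (seed, epoch) and the epoch counter advances on every run, so each epoch gets a different but reproducible order. A sketch:

from webdataset.filters import detshuffle

a = detshuffle(bufsize=4, initial=4, seed=17)
b = detshuffle(bufsize=4, initial=4, seed=17)
assert list(a.run(iter(range(10)))) == list(b.run(iter(range(10))))  # same seed, same epoch, same order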
View Source
def default_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a collection of samples (dictionaries) and create a batch.

    If `combine_tensors` is True, `ndarray` objects are combined into
    tensor batches.

    :param list samples: list of samples
    :param bool combine_tensors: whether to turn lists of ndarrays/tensors into a single batched ndarray/tensor
    :param bool combine_scalars: whether to turn lists of scalars into a single ndarray
    :returns: single sample consisting of a batch
    :rtype: list
    """
    assert isinstance(samples[0], (list, tuple)), type(samples[0])
    batched = list(zip(*samples))
    result = []
    for b in batched:
        if isinstance(b[0], (int, float)):
            if combine_scalars:
                b = np.array(list(b))
        elif isinstance(b[0], TorchTensor):
            if combine_tensors:
                import torch

                b = torch.stack(list(b))
        elif isinstance(b[0], np.ndarray):
            if combine_tensors:
                b = np.array(list(b))
        else:
            b = list(b)
        result.append(b)
    return result
Take a collection of samples (dictionaries) and create a batch.
If combine_tensors is True, ndarray objects are combined into batched ndarrays (and Torch tensors are stacked); if combine_scalars is True, numeric scalars are combined into a single ndarray.
:param list samples: list of samples
:param bool combine_tensors: whether to turn lists of ndarrays/tensors into a single batched ndarray/tensor
:param bool combine_scalars: whether to turn lists of scalars into a single ndarray
:returns: single sample consisting of a batch
:rtype: list
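For example, collating three (image, label) tuples produces one batched entry per tuple position (shapes here are illustrative):

import numpy as np
from webdataset.filters import default_collation_fn

samples = [(np.zeros((2, 2)), 0), (np.ones((2, 2)), 1), (np.zeros((2, 2)), 2)]
images, labels = default_collation_fn(samples)
print(images.shape, labels)  # (3, 2, 2) [0 1 2]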