webdataset.compat

import yaml

from . import autodecode, cache, filters, shardlists, tariterators
from .filters import reraise_exception
from .pipeline import DataPipeline
from .pytorch import DataLoader, IterableDataset


class FluidInterface:
    """Mixin adding chainable pipeline-construction methods via compose()."""

    def batched(self, batchsize, collation_fn=filters.default_collation_fn, partial=True):
        return self.compose(filters.batched(batchsize, collation_fn=collation_fn, partial=partial))

    def unbatched(self):
        return self.compose(filters.unbatched())

    def listed(self, batchsize, partial=True):
        return self.compose(filters.batched(batchsize, collation_fn=None, partial=partial))

    def unlisted(self):
        return self.compose(filters.unlisted())

    def log_keys(self, logfile=None):
        return self.compose(filters.log_keys(logfile))

    def shuffle(self, size, **kw):
        if size < 1:
            return self
        else:
            return self.compose(filters.shuffle(size, **kw))

    def map(self, f, handler=reraise_exception):
        return self.compose(filters.map(f, handler=handler))

    def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
        return self.map(decoder, handler=handler)

    def map_dict(self, handler=reraise_exception, **kw):
        return self.compose(filters.map_dict(handler=handler, **kw))

    def select(self, predicate, **kw):
        return self.compose(filters.select(predicate, **kw))

    def to_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.to_tuple(*args, handler=handler))

    def map_tuple(self, *args, handler=reraise_exception):
        return self.compose(filters.map_tuple(*args, handler=handler))

    def slice(self, *args):
        return self.compose(filters.slice(*args))

    def rename(self, **kw):
        return self.compose(filters.rename(**kw))

    def rsample(self, p=0.5):
        return self.compose(filters.rsample(p))

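# Usage sketch: each FluidInterface method composes a filter stage and
# returns the pipeline, so calls chain. The shard pattern and the
# "jpg"/"cls" keys below are hypothetical placeholders:
#
#     dataset = (
#         WebDataset("shards/train-{000..099}.tar")
#         .shuffle(1000)            # 1000-sample shuffle buffer
#         .decode("pil")            # str args become autodecode.ImageHandler
#         .to_tuple("jpg", "cls")   # select fields by extension
#         .batched(16)              # collate into batches of 16
#     )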

class WebDataset(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(
        self,
        urls,
        handler=reraise_exception,
        resampled=False,
        repeat=False,
        shardshuffle=None,
        cache_size=0,
        cache_dir=None,
        detshuffle=False,
        nodesplitter=shardlists.single_node_only,
        verbose=False,
    ):
        super().__init__()
        if isinstance(urls, IterableDataset):
            # an existing IterableDataset serves directly as the sample source
            assert not resampled
            self.append(urls)
        elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
            # a YAML spec file describing multiple shard sources
            with open(urls) as stream:
                spec = yaml.safe_load(stream)
            assert "datasets" in spec
            self.append(shardlists.MultiShardSample(spec))
        elif isinstance(urls, dict):
            # an already-parsed multi-source spec
            assert "datasets" in urls
            self.append(shardlists.MultiShardSample(urls))
        elif resampled:
            # draw shards with replacement instead of iterating a fixed list
            self.append(shardlists.ResampledShards(urls))
        else:
            # a plain shard list: split across nodes and workers,
            # optionally shuffled at the shard level
            self.append(shardlists.SimpleShardList(urls))
            self.append(nodesplitter)
            self.append(shardlists.split_by_worker)
            if shardshuffle is True:
                shardshuffle = 100
            if shardshuffle is not None:
                if detshuffle:
                    self.append(filters.detshuffle(shardshuffle))
                else:
                    self.append(filters.shuffle(shardshuffle))
        self.append(shardlists.split_by_node)
        self.append(shardlists.split_by_worker)
        if cache_size == 0:
            # stream tar files and expand them into samples directly
            self.append(tariterators.tarfile_to_samples(handler=handler))
        else:
            # cache downloaded shards locally; cache_size == -1 means no size limit
            assert cache_size == -1 or cache_size > 0
            self.append(
                cache.cached_tarfile_to_samples(
                    handler=handler,
                    verbose=verbose,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                )
            )

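# Construction sketch (the URL pattern is a placeholder): shardshuffle=True
# becomes a 100-shard shuffle buffer, and cache_size=-1 caches downloaded
# shards locally without a size limit:
#
#     trainset = WebDataset(
#         "https://example.com/train-{000000..000999}.tar",
#         shardshuffle=True,
#         cache_dir="./shard-cache",
#         cache_size=-1,
#     )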

class FluidWrapper(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataPipeline."""

    def __init__(self, initial):
        super().__init__()
        self.append(initial)

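# Sketch: FluidWrapper attaches the fluid interface to an arbitrary source,
# assuming DataPipeline accepts a plain iterable of dict samples as its
# first stage; the keys below are made up:
#
#     samples = [{"__key__": f"{i:04d}", "x": i} for i in range(100)]
#     pipe = (
#         FluidWrapper(samples)
#         .shuffle(10)
#         .map_dict(x=lambda v: v * 2)
#         .batched(8, collation_fn=None)   # batch into lists, no collation
#     )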

class WebLoader(DataPipeline, FluidInterface):
    """Small fluid-interface wrapper for DataLoader."""

    def __init__(self, *args, **kw):
        super().__init__(DataLoader(*args, **kw))
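
# Pairing sketch (paths are placeholders): batch inside the worker
# processes, then unbatch, reshuffle, and rebatch in the loader to mix
# samples across workers; batch_size=None disables DataLoader collation:
#
#     dataset = (
#         WebDataset("shards/train-{000..099}.tar", shardshuffle=True)
#         .decode("torchrgb")
#         .to_tuple("jpg", "cls")
#         .batched(64)
#     )
#     loader = (
#         WebLoader(dataset, batch_size=None, num_workers=4)
#         .unbatched()
#         .shuffle(1000)
#         .batched(64)
#     )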