import sys
sys.path.insert(0, "..")
import webdataset as wds
from torchvision import transforms
from pprint import pprint
import torch
import random
from torchmore import flex, combos, layers
from torch import nn
import pytorch_lightning as pl
def collate4ocr(samples):
    """Collate image+text samples into batches.

    Images are zero-padded to a common size and pasted at random offsets;
    the text sequences are returned unchanged and packed for nn.CTCLoss later.
    """
    images, seqs = zip(*samples)
    # add a channel dimension to any 2D (grayscale) images
    images = [im.unsqueeze(2) if im.ndimension() == 2 else im for im in images]
    # maximum extent of each dimension across the batch
    bh, bw, bd = map(max, zip(*[x.shape for x in images]))
    result = torch.zeros((len(images), bh, bw, bd), dtype=torch.float)
    for i, im in enumerate(images):
        if im.dtype == torch.uint8:
            im = im.float() / 255.0
        h, w, d = im.shape
        # paste each image at a random offset within the padded canvas
        dy, dx = random.randint(0, bh - h), random.randint(0, bw - w)
        result[i, dy : dy + h, dx : dx + w, :d] = im
    return result, seqs
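
To make the collation concrete, here is a small sanity check (not part of the original pipeline) on two fabricated CHW tensors; the shapes and strings are arbitrary and only for illustration.

fake_samples = [
    (torch.rand(3, 40, 200), "word"),
    (torch.rand(3, 50, 300), "another"),
]
batch, texts = collate4ocr(fake_samples)
print(batch.shape)  # torch.Size([2, 3, 50, 300]) -- padded to the largest image
print(texts)        # ('word', 'another')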

bucket = "pipe:curl -s -L http://storage.googleapis.com/nvdata-ocropus-words/"
shards_train = bucket + "uw3-word-{000000..000022}.tar"

def make_loader(spec, num_workers=4, batch_size=8, nshuffle=1000):
    """Build a WebDataset loader: shards are split across workers, cached
    locally, shuffled, decoded to RGB tensors, and batched with collate4ocr."""
    dataset = wds.DataPipeline(
        wds.shardspec(spec),
        wds.split_by_worker,
        wds.detshuffle(100),
        wds.cached_tarfile_to_samples(verbose=True),
        wds.shuffle(nshuffle),
        wds.decode("torchrgb"),
        wds.to_tuple("png;jpg txt"),
        wds.batched(batch_size, collation_fn=collate4ocr),
    )
    # batching already happens inside the pipeline, so the loader uses batch_size=None
    loader = wds.WebLoader(dataset, num_workers=num_workers, batch_size=None)
    return loader

dl = make_loader(shards_train)
image, info = next(iter(dl))
print(image.shape)
print(info)
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar



torch.Size([8, 3, 50, 315])
['loading', 'parameters.', 'of', 'generalized', 'alloys.', 'ferromagsets', 'thermodynamical', 'The']
class MaxReduce(nn.Module):
    """Reduce the input along dimension d by taking the maximum."""

    d: int

    def __init__(self, d: int):
        super().__init__()
        self.d = d

    def forward(self, x):
        return x.max(self.d)[0]
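
MaxReduce(2) collapses one dimension by taking the maximum over it; for a (batch, depth, height, width) feature map it removes the height axis. A minimal illustration on a random tensor (hypothetical shape):

x = torch.rand(2, 96, 6, 75)   # hypothetical feature-map shape
print(MaxReduce(2)(x).shape)   # torch.Size([2, 96, 75])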


def make_text_model(noutput=1024, shape=(1, 3, 48, 300)):
    """Text recognition model using 2D LSTM and convolutions."""
    model = nn.Sequential(
        *combos.conv2d_block(32, 3, mp=(2, 1), repeat=2),
        *combos.conv2d_block(48, 3, mp=(2, 1), repeat=2),
        *combos.conv2d_block(64, 3, mp=2, repeat=2),
        *combos.conv2d_block(96, 3, repeat=2),
        flex.Lstm2(100),
        MaxReduce(2),  # equivalent to layers.Fun("lambda x: x.max(2)[0]")
        flex.ConvTranspose1d(400, 1, stride=2),
        flex.Conv1d(100, 3),
        flex.BatchNorm1d(),
        nn.ReLU(),
        layers.Reorder("BDL", "LBD"),
        flex.LSTM(100, bidirectional=True),
        layers.Reorder("LBD", "BDL"),
        flex.Conv1d(noutput, 1),
    )
    flex.shape_inference(model, shape)
    return model
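
As a quick smoke test (assuming torchmore is installed and the flex layers have been initialized by shape_inference as above), a dummy batch can be pushed through the model to inspect the output layout; the exact sequence length depends on the convolution and transpose layers.

model = make_text_model()
out = model(torch.zeros(1, 3, 48, 300))
print(out.shape)  # (batch, noutput, sequence_length)
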
def pack_for_ctc(seqs):
    """Pack a list of sequences for nn.CTCLoss."""
    allseqs = torch.cat(seqs).long()
    alllens = torch.tensor([len(s) for s in seqs]).long()
    return (allseqs, alllens)
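
For example, packing two short label sequences (hypothetical integer class ids) yields one concatenated tensor plus the per-sequence lengths that nn.CTCLoss expects:

seqs = [torch.tensor([72, 105]), torch.tensor([79, 67, 82])]
targets, tlens = pack_for_ctc(seqs)
print(targets)  # concatenated labels: 72, 105, 79, 67, 82
print(tlens)    # tensor([2, 3])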


class TextModel(pl.LightningModule):
    """LightningModule wrapping the text recognition model with CTC loss."""

    def __init__(self):
        super().__init__()
        self.model = make_text_model()
        self.ctc_loss = nn.CTCLoss(zero_infinity=True)

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, outputs, targets):
        targets, tlens = pack_for_ctc(targets)
        b, d, L = outputs.size()
        olens = torch.full((b,), L, dtype=torch.long)
        outputs = outputs.log_softmax(1)
        outputs = layers.reorder(outputs, "BDL", "LBD")
        assert tlens.size(0) == b
        assert tlens.sum() == targets.size(0)
        return self.ctc_loss(outputs.cpu(), targets.cpu(), olens.cpu(), tlens.cpu())

    def training_step(self, batch, batch_nb):
        images, texts = batch
        outputs = self.forward(images)
        # encode each character as its Unicode code point; compute_loss
        # applies log_softmax itself, so the raw outputs are passed directly
        seqs = [torch.tensor([ord(c) for c in s]) for s in texts]
        loss = self.compute_loss(outputs, seqs)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-4)
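
Because the targets are raw Unicode code points and nn.CTCLoss uses class 0 as the blank, a greedy CTC decoder (argmax per frame, collapse repeats, drop blanks, map code points back with chr) can turn model outputs into strings after training. This helper is a sketch and not part of the original notebook:

def ctc_greedy_decode(outputs):
    """Greedily decode (batch, classes, length) outputs into strings."""
    results = []
    for seq in outputs.argmax(1):          # per-frame class indices, shape (length,)
        chars, prev = [], -1
        for cls in seq.tolist():
            if cls != prev and cls != 0:   # collapse repeats, skip the CTC blank (0)
                chars.append(chr(cls))
            prev = cls
        results.append("".join(chars))
    return results
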
text_model = TextModel()
train_loader = make_loader(shards_train)
trainer = pl.Trainer(gpus=[0], max_epochs=3)
trainer.fit(text_model, train_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 1.2 M 
1 | ctc_loss | CTCLoss    | 0     
----------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.868     Total estimated model params size (MB)


Epoch 0: : 0it [00:00, ?it/s]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar

# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar


Epoch 0: : 4it [00:00,  9.77it/s, loss=276, v_num=20]

/home/tmb/proj/webdataset/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:56: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(


Epoch 0: : 9487it [07:28, 21.15it/s, loss=6.29, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000016.tar to ./_cache/nvdata-ocropus-words/uw3-word-000016.tar


Epoch 0: : 9488it [07:28, 21.15it/s, loss=6.29, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000017.tar to ./_cache/nvdata-ocropus-words/uw3-word-000017.tar


Epoch 0: : 9489it [07:28, 21.15it/s, loss=6.3, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000018.tar


Epoch 0: : 9490it [07:28, 21.15it/s, loss=6.31, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000015.tar to ./_cache/nvdata-ocropus-words/uw3-word-000015.tar


Epoch 0: : 9495it [07:28, 21.15it/s, loss=6.34, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000015.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000017.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000016.tar


Epoch 0: : 19487it [15:08, 21.44it/s, loss=4.95, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000012.tar to ./_cache/nvdata-ocropus-words/uw3-word-000012.tar


Epoch 0: : 19488it [15:08, 21.44it/s, loss=4.98, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000013.tar to ./_cache/nvdata-ocropus-words/uw3-word-000013.tar


Epoch 0: : 19489it [15:09, 21.44it/s, loss=4.97, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000014.tar to ./_cache/nvdata-ocropus-words/uw3-word-000014.tar


Epoch 0: : 19490it [15:09, 21.44it/s, loss=4.99, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000019.tar


Epoch 0: : 19495it [15:09, 21.44it/s, loss=5.06, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000013.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000014.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000012.tar


Epoch 0: : 29487it [22:43, 21.63it/s, loss=4.26, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000020.tar to ./_cache/nvdata-ocropus-words/uw3-word-000020.tar


Epoch 0: : 29488it [22:43, 21.63it/s, loss=4.26, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000021.tar to ./_cache/nvdata-ocropus-words/uw3-word-000021.tar


Epoch 0: : 29489it [22:43, 21.63it/s, loss=4.27, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000022.tar to ./_cache/nvdata-ocropus-words/uw3-word-000022.tar


Epoch 0: : 29490it [22:43, 21.63it/s, loss=4.3, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000007.tar to ./_cache/nvdata-ocropus-words/uw3-word-000007.tar


Epoch 0: : 29495it [22:43, 21.63it/s, loss=4.37, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000022.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000020.tar


Epoch 0: : 29496it [22:49, 21.54it/s, loss=4.38, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000007.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000021.tar



Epoch 0: : 38077it [29:00, 21.88it/s, loss=4.31, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000006.tar to ./_cache/nvdata-ocropus-words/uw3-word-000006.tar


Epoch 0: : 38085it [29:00, 21.88it/s, loss=4.36, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000006.tar


Epoch 0: : 39487it [30:01, 21.92it/s, loss=4.19, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000004.tar to ./_cache/nvdata-ocropus-words/uw3-word-000004.tar


Epoch 0: : 39488it [30:01, 21.92it/s, loss=4.18, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000005.tar to ./_cache/nvdata-ocropus-words/uw3-word-000005.tar


Epoch 0: : 39490it [30:01, 21.92it/s, loss=4.1, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000011.tar to ./_cache/nvdata-ocropus-words/uw3-word-000011.tar


Epoch 0: : 39495it [30:02, 21.92it/s, loss=4.15, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000005.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000011.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000004.tar


Epoch 0: : 48077it [36:08, 22.17it/s, loss=4.3, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000010.tar to ./_cache/nvdata-ocropus-words/uw3-word-000010.tar


Epoch 0: : 48085it [36:08, 22.17it/s, loss=4.36, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000010.tar


Epoch 0: : 49487it [37:08, 22.21it/s, loss=4.41, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000008.tar to ./_cache/nvdata-ocropus-words/uw3-word-000008.tar


Epoch 0: : 49488it [37:08, 22.21it/s, loss=4.44, v_num=20]

# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000009.tar to ./_cache/nvdata-ocropus-words/uw3-word-000009.tar


Epoch 0: : 49495it [37:08, 22.21it/s, loss=4.33, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000008.tar


Epoch 0: : 49496it [37:11, 22.18it/s, loss=4.29, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000009.tar


Epoch 0: : 56445it [42:10, 22.31it/s, loss=4.27, v_num=20]

/home/tmb/proj/webdataset/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:56: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(


Epoch 1: : 0it [00:00, ?it/s, loss=4, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar

# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar


Epoch 1: : 9487it [06:51, 23.06it/s, loss=4.11, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000016.tar


Epoch 1: : 9488it [06:51, 23.06it/s, loss=4.12, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000017.tar


Epoch 1: : 9489it [06:51, 23.06it/s, loss=4.14, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000018.tar


Epoch 1: : 9490it [06:51, 23.06it/s, loss=4.15, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000015.tar


Epoch 1: : 19487it [13:43, 23.66it/s, loss=3.91, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000012.tar


Epoch 1: : 19488it [13:43, 23.66it/s, loss=3.93, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000013.tar


Epoch 1: : 19489it [13:43, 23.66it/s, loss=3.92, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000014.tar


Epoch 1: : 19490it [13:43, 23.66it/s, loss=3.94, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000019.tar


Epoch 1: : 29487it [20:39, 23.79it/s, loss=3.68, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000020.tar


Epoch 1: : 29488it [20:39, 23.79it/s, loss=3.68, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000021.tar


Epoch 1: : 29489it [20:39, 23.79it/s, loss=3.69, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000022.tar


Epoch 1: : 29490it [20:39, 23.79it/s, loss=3.71, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000007.tar


Epoch 1: : 38077it [26:33, 23.89it/s, loss=3.86, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000006.tar


Epoch 1: : 39487it [27:33, 23.89it/s, loss=3.72, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000004.tar


Epoch 1: : 39488it [27:33, 23.89it/s, loss=3.71, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000005.tar


Epoch 1: : 39490it [27:33, 23.89it/s, loss=3.65, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000011.tar


Epoch 1: : 48077it [33:32, 23.89it/s, loss=3.93, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000010.tar


Epoch 1: : 49487it [34:30, 23.90it/s, loss=4.03, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000008.tar


Epoch 1: : 49488it [34:30, 23.90it/s, loss=4.06, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000009.tar


Epoch 2: : 0it [00:00, ?it/s, loss=3.73, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar


# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar


Epoch 2: : 9487it [06:50, 23.14it/s, loss=3.83, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000016.tar


Epoch 2: : 9488it [06:50, 23.14it/s, loss=3.83, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000017.tar


Epoch 2: : 9489it [06:50, 23.14it/s, loss=3.85, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000018.tar


Epoch 2: : 9490it [06:50, 23.14it/s, loss=3.87, v_num=20]

# opening ./_cache/nvdata-ocropus-words/uw3-word-000015.tar


Epoch 2: : 11193it [08:01, 23.23it/s, loss=3.76, v_num=20]