import sys
sys.path.insert(0, "..")
import webdataset as wds
from torchvision import transforms
from pprint import pprint
import torch
import random
from torchmore import flex, combos, layers
from torch import nn
import pytorch_lightning as pl
def collate4ocr(samples):
    """Collate image+sequence samples into batches.

    This returns a padded image batch and the corresponding text sequences;
    the sequences are packed into CTCLoss form later by pack_for_ctc.
    """
    images, seqs = zip(*samples)
    # Give 2D (grayscale) images an explicit third axis.
    images = [im.unsqueeze(2) if im.ndimension() == 2 else im for im in images]
    # Maximum size along each axis over the whole batch.
    bh, bw, bd = map(max, zip(*[x.shape for x in images]))
    result = torch.zeros((len(images), bh, bw, bd), dtype=torch.float)
    for i, im in enumerate(images):
        if im.dtype == torch.uint8:
            im = im.float() / 255.0
        h, w, d = im.shape
        # Copy each image into the zero-padded canvas at a random offset
        # wherever it is smaller than the batch maximum.
        dy, dx = random.randint(0, bh - h), random.randint(0, bw - w)
        result[i, dy : dy + h, dx : dx + w, :d] = im
    return result, seqs
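# A tiny sanity check of collate4ocr (hypothetical inputs, not part of the dataset):
# two differently sized images are zero-padded into a single batch tensor.
demo_images = [torch.rand(3, 20, 50), torch.rand(3, 24, 40)]
demo_batch, demo_texts = collate4ocr(list(zip(demo_images, ["foo", "bar"])))
print(demo_batch.shape)  # torch.Size([2, 3, 24, 50])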
bucket = "pipe:curl -s -L http://storage.googleapis.com/nvdata-ocropus-words/"
shards_train = bucket + "uw3-word-{000000..000022}.tar"
def make_loader(spec, num_workers=4, batch_size=8, nshuffle=1000):
    dataset = wds.DataPipeline(
        wds.shardspec(spec),
        # Split the shards across DataLoader workers; each worker sees a disjoint subset.
        wds.split_by_worker,
        # Deterministic (epoch-seeded) shuffle of the shard order.
        wds.detshuffle(100),
        # Download shards into a local cache and expand them into samples.
        wds.cached_tarfile_to_samples(verbose=True),
        # In-memory shuffle of individual samples.
        wds.shuffle(nshuffle),
        wds.decode("torchrgb"),
        wds.to_tuple("png;jpg txt"),
        wds.batched(batch_size, collation_fn=collate4ocr),
    )
    # Batching already happens inside the pipeline, so the loader uses batch_size=None.
    loader = wds.WebLoader(dataset, num_workers=num_workers, batch_size=None)
    return loader
dl = make_loader(shards_train)
image, info = next(iter(dl))
print(image.shape)
print(info)
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar
torch.Size([8, 3, 50, 315])
['loading', 'parameters.', 'of', 'generalized', 'alloys.', 'ferromagsets', 'thermodynamical', 'The']
class MaxReduce(nn.Module):
    d: int

    def __init__(self, d: int):
        super().__init__()
        self.d = d

    def forward(self, x):
        return x.max(self.d)[0]
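# MaxReduce(d) collapses axis d by taking the maximum over it, e.g. turning a
# (B, D, H, W) feature map into a (B, D, W) sequence. Illustration with hypothetical shapes:
print(MaxReduce(2)(torch.rand(1, 96, 6, 75)).shape)  # torch.Size([1, 96, 75])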
def make_text_model(noutput=1024, shape=(1, 3, 48, 300)):
    """Text recognition model using 2D LSTM and convolutions."""
    model = nn.Sequential(
        # convolutional feature extraction; mp gives the max-pooling factors
        *combos.conv2d_block(32, 3, mp=(2, 1), repeat=2),
        *combos.conv2d_block(48, 3, mp=(2, 1), repeat=2),
        *combos.conv2d_block(64, 3, mp=2, repeat=2),
        *combos.conv2d_block(96, 3, repeat=2),
        # 2D LSTM over the feature map
        flex.Lstm2(100),
        # layers.Fun("lambda x: x.max(2)[0]"),
        # reduce over the height dimension, yielding a 1D sequence in BDL order
        MaxReduce(2),
        # 1D convolutions followed by a bidirectional LSTM over the sequence
        flex.ConvTranspose1d(400, 1, stride=2),
        flex.Conv1d(100, 3),
        flex.BatchNorm1d(),
        nn.ReLU(),
        layers.Reorder("BDL", "LBD"),
        flex.LSTM(100, bidirectional=True),
        layers.Reorder("LBD", "BDL"),
        # project to per-time-step class scores
        flex.Conv1d(noutput, 1),
    )
    # instantiate the flex layers by running shape inference on a sample input shape
    flex.shape_inference(model, shape)
    return model
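# Illustrative shape check (a sketch, not part of the original run): the model maps an
# image batch (B, 3, H, W) to per-time-step class scores of shape (B, noutput, L); the
# exact sequence length L depends on the pooling and transposed-convolution strides above.
demo_model = make_text_model()
print(demo_model(torch.rand(1, 3, 48, 300)).shape)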
def pack_for_ctc(seqs):
    """Pack a list of sequences for nn.CTCLoss."""
    allseqs = torch.cat(seqs).long()
    alllens = torch.tensor([len(s) for s in seqs]).long()
    return (allseqs, alllens)
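# pack_for_ctc concatenates all target sequences and records their lengths, which is the
# (targets, target_lengths) form expected by nn.CTCLoss. A tiny illustration with
# hypothetical sequences:
print(pack_for_ctc([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]))
# (tensor([1, 2, 3, 4, 5]), tensor([3, 2]))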
class TextModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = make_text_model()
        self.ctc_loss = nn.CTCLoss(zero_infinity=True)

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, outputs, targets):
        targets, tlens = pack_for_ctc(targets)
        b, d, L = outputs.size()
        # every output sequence in the batch has the full length L
        olens = torch.full((b,), L, dtype=torch.long)
        outputs = outputs.log_softmax(1)
        # CTCLoss expects (L, B, D) log-probabilities
        outputs = layers.reorder(outputs, "BDL", "LBD")
        assert tlens.size(0) == b
        assert tlens.sum() == targets.size(0)
        return self.ctc_loss(outputs.cpu(), targets.cpu(), olens.cpu(), tlens.cpu())

    def training_step(self, batch, batch_nb):
        images, texts = batch
        outputs = self.forward(images)
        # encode each target string as a sequence of Unicode code points
        seqs = [torch.tensor([ord(c) for c in s]) for s in texts]
        # compute_loss applies log_softmax itself, so the raw outputs are passed in
        loss = self.compute_loss(outputs, seqs)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-4)
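# Greedy CTC decoding sketch (an addition for illustration, not part of the original
# notebook): take the argmax class per time step, collapse repeats, drop the blank
# (class 0, the nn.CTCLoss default), and map the remaining class indices back to
# characters via chr(), mirroring the ord() encoding used in training_step.
def ctc_greedy_decode(outputs):
    """Decode a (B, C, L) score tensor into a list of strings."""
    texts = []
    for best in outputs.argmax(1):            # (L,) best class per time step
        chars, last = [], 0
        for c in best.tolist():
            if c != 0 and c != last:          # skip blanks and repeated classes
                chars.append(chr(c))
            last = c
        texts.append("".join(chars))
    return texts
# After training, predictions for a batch could be inspected with, e.g.,
# ctc_greedy_decode(text_model(image)).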
text_model = TextModel()
train_loader = make_loader(shards_train)
trainer = pl.Trainer(gpus=[0], max_epochs=3)
trainer.fit(text_model, train_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
----------------------------------------
0 | model | Sequential | 1.2 M
1 | ctc_loss | CTCLoss | 0
----------------------------------------
1.2 M Trainable params
0 Non-trainable params
1.2 M Total params
4.868 Total estimated model params size (MB)
Epoch 0: : 0it [00:00, ?it/s]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar
Epoch 0: : 4it [00:00, 9.77it/s, loss=276, v_num=20]
/home/tmb/proj/webdataset/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:56: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
warning_cache.warn(
Epoch 0: : 9487it [07:28, 21.15it/s, loss=6.29, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000016.tar to ./_cache/nvdata-ocropus-words/uw3-word-000016.tar
Epoch 0: : 9488it [07:28, 21.15it/s, loss=6.29, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000017.tar to ./_cache/nvdata-ocropus-words/uw3-word-000017.tar
Epoch 0: : 9489it [07:28, 21.15it/s, loss=6.3, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000018.tar
Epoch 0: : 9490it [07:28, 21.15it/s, loss=6.31, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000015.tar to ./_cache/nvdata-ocropus-words/uw3-word-000015.tar
Epoch 0: : 9495it [07:28, 21.15it/s, loss=6.34, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000015.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000017.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000016.tar
Epoch 0: : 19487it [15:08, 21.44it/s, loss=4.95, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000012.tar to ./_cache/nvdata-ocropus-words/uw3-word-000012.tar
Epoch 0: : 19488it [15:08, 21.44it/s, loss=4.98, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000013.tar to ./_cache/nvdata-ocropus-words/uw3-word-000013.tar
Epoch 0: : 19489it [15:09, 21.44it/s, loss=4.97, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000014.tar to ./_cache/nvdata-ocropus-words/uw3-word-000014.tar
Epoch 0: : 19490it [15:09, 21.44it/s, loss=4.99, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000019.tar
Epoch 0: : 19495it [15:09, 21.44it/s, loss=5.06, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000013.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000014.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000012.tar
Epoch 0: : 29487it [22:43, 21.63it/s, loss=4.26, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000020.tar to ./_cache/nvdata-ocropus-words/uw3-word-000020.tar
Epoch 0: : 29488it [22:43, 21.63it/s, loss=4.26, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000021.tar to ./_cache/nvdata-ocropus-words/uw3-word-000021.tar
Epoch 0: : 29489it [22:43, 21.63it/s, loss=4.27, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000022.tar to ./_cache/nvdata-ocropus-words/uw3-word-000022.tar
Epoch 0: : 29490it [22:43, 21.63it/s, loss=4.3, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000007.tar to ./_cache/nvdata-ocropus-words/uw3-word-000007.tar
Epoch 0: : 29495it [22:43, 21.63it/s, loss=4.37, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000022.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000020.tar
Epoch 0: : 29496it [22:49, 21.54it/s, loss=4.38, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000007.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000021.tar
Epoch 0: : 38077it [29:00, 21.88it/s, loss=4.31, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000006.tar to ./_cache/nvdata-ocropus-words/uw3-word-000006.tar
Epoch 0: : 38085it [29:00, 21.88it/s, loss=4.36, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000006.tar
Epoch 0: : 39487it [30:01, 21.92it/s, loss=4.19, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000004.tar to ./_cache/nvdata-ocropus-words/uw3-word-000004.tar
Epoch 0: : 39488it [30:01, 21.92it/s, loss=4.18, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000005.tar to ./_cache/nvdata-ocropus-words/uw3-word-000005.tar
Epoch 0: : 39490it [30:01, 21.92it/s, loss=4.1, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000011.tar to ./_cache/nvdata-ocropus-words/uw3-word-000011.tar
Epoch 0: : 39495it [30:02, 21.92it/s, loss=4.15, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000005.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000011.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000004.tar
Epoch 0: : 48077it [36:08, 22.17it/s, loss=4.3, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000010.tar to ./_cache/nvdata-ocropus-words/uw3-word-000010.tar
Epoch 0: : 48085it [36:08, 22.17it/s, loss=4.36, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000010.tar
Epoch 0: : 49487it [37:08, 22.21it/s, loss=4.41, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000008.tar to ./_cache/nvdata-ocropus-words/uw3-word-000008.tar
Epoch 0: : 49488it [37:08, 22.21it/s, loss=4.44, v_num=20]
# downloading http://storage.googleapis.com/nvdata-ocropus-words/uw3-word-000009.tar to ./_cache/nvdata-ocropus-words/uw3-word-000009.tar
Epoch 0: : 49495it [37:08, 22.21it/s, loss=4.33, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000008.tar
Epoch 0: : 49496it [37:11, 22.18it/s, loss=4.29, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000009.tar
Epoch 0: : 56445it [42:10, 22.31it/s, loss=4.27, v_num=20]
/home/tmb/proj/webdataset/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:56: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
warning_cache.warn(
Epoch 1: : 0it [00:00, ?it/s, loss=4, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar
Epoch 1: : 9487it [06:51, 23.06it/s, loss=4.11, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000016.tar
Epoch 1: : 9488it [06:51, 23.06it/s, loss=4.12, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000017.tar
Epoch 1: : 9489it [06:51, 23.06it/s, loss=4.14, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000018.tar
Epoch 1: : 9490it [06:51, 23.06it/s, loss=4.15, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000015.tar
Epoch 1: : 19487it [13:43, 23.66it/s, loss=3.91, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000012.tar
Epoch 1: : 19488it [13:43, 23.66it/s, loss=3.93, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000013.tar
Epoch 1: : 19489it [13:43, 23.66it/s, loss=3.92, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000014.tar
Epoch 1: : 19490it [13:43, 23.66it/s, loss=3.94, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000019.tar
Epoch 1: : 29487it [20:39, 23.79it/s, loss=3.68, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000020.tar
Epoch 1: : 29488it [20:39, 23.79it/s, loss=3.68, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000021.tar
Epoch 1: : 29489it [20:39, 23.79it/s, loss=3.69, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000022.tar
Epoch 1: : 29490it [20:39, 23.79it/s, loss=3.71, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000007.tar
Epoch 1: : 38077it [26:33, 23.89it/s, loss=3.86, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000006.tar
Epoch 1: : 39487it [27:33, 23.89it/s, loss=3.72, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000004.tar
Epoch 1: : 39488it [27:33, 23.89it/s, loss=3.71, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000005.tar
Epoch 1: : 39490it [27:33, 23.89it/s, loss=3.65, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000011.tar
Epoch 1: : 48077it [33:32, 23.89it/s, loss=3.93, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000010.tar
Epoch 1: : 49487it [34:30, 23.90it/s, loss=4.03, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000008.tar
Epoch 1: : 49488it [34:30, 23.90it/s, loss=4.06, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000009.tar
Epoch 2: : 0it [00:00, ?it/s, loss=3.73, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000002.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000001.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000000.tar
# opening ./_cache/nvdata-ocropus-words/uw3-word-000003.tar
Epoch 2: : 9487it [06:50, 23.14it/s, loss=3.83, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000016.tar
Epoch 2: : 9488it [06:50, 23.14it/s, loss=3.83, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000017.tar
Epoch 2: : 9489it [06:50, 23.14it/s, loss=3.85, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000018.tar
Epoch 2: : 9490it [06:50, 23.14it/s, loss=3.87, v_num=20]
# opening ./_cache/nvdata-ocropus-words/uw3-word-000015.tar
Epoch 2: : 11193it [08:01, 23.23it/s, loss=3.76, v_num=20]