# Create 16 symlinks to a single test shard so the demo has a
# multi-shard dataset to read from.
import os
for i in range(16):
    target = f"imagenet-{i:06d}.tgz"
    if os.path.islink(target):
        continue
    os.symlink("../testdata/imagenet-000000.tgz", target)
%%writefile demo.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import webdataset as wds
from functools import partial

from torch.nn.parallel import DistributedDataParallel as DDP

# Ignore the real sample and return a random (input, label) pair so the
# demo needs no actual decoding.
def mockdata(_):
    return torch.randn(10), torch.randn(5)

os.environ["GOPEN_VERBOSE"] = "1"

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
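    # gloo rendezvous over localhost; nccl is the usual backend when each
    # rank drives its own GPU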
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def make_loader():
    shardlist = partial(wds.PytorchShardList, epoch_shuffle=True, verbose=False)
    dataset = wds.WebDataset("imagenet-{000000..000015}.tgz", shardlist=shardlist).map(mockdata)
    loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)
    return loader
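
# Hypothetical variant (not used by the demo): instead of mapping every
# sample to random tensors, decode the stored files. Assumes each sample
# in the shards has "jpg" and "cls" entries; adjust the keys to your data.
def make_decoding_loader():
    shardlist = partial(wds.PytorchShardList, epoch_shuffle=True, verbose=False)
    dataset = (
        wds.WebDataset("imagenet-{000000..000015}.tgz", shardlist=shardlist)
        .decode("torchrgb")      # decode images to CHW float tensors
        .to_tuple("jpg", "cls")  # yield (image, label) pairs
    )
    return wds.WebLoader(dataset, num_workers=4, batch_size=20)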

def cleanup():
    dist.destroy_process_group()

# A tiny model whose 10-d input and 5-d output match the tensors
# produced by mockdata.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank} of {world_size}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    loader = make_loader()
    for epoch in range(2):
        print("=== epoch", epoch)
        os.environ["WDS_EPOCH"] = str(epoch)
        for inputs, labels in loader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs.to(rank))
            loss_fn(outputs, labels.to(rank)).backward()
            optimizer.step()

    cleanup()
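
# Optional entry point (an addition, not part of the original demo): lets
# demo.py be launched directly with `python demo.py`; world_size=2 is an
# arbitrary choice here.
if __name__ == "__main__":
    mp.spawn(demo_basic, args=(2,), nprocs=2, join=True)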
Overwriting demo.py
from importlib import reload
import demo
reload(demo)
<module 'demo' from '/home1/tmb/proj/webdataset/notebooks/demo.py'>
import torch.multiprocessing as mp

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

run_demo(demo.demo_basic, 2)
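Once the demo has run, the dummy shards created at the top of the notebook can be removed again; a small cleanup sketch (not part of the original demo):
import os

# Delete the 16 symlinks that point at the single test shard.
for i in range(16):
    target = f"imagenet-{i:06d}.tgz"
    if os.path.islink(target):
        os.remove(target)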