import os
import sys

for i in range(16):
    target = f"imagenet-{i:06d}.tgz"
    if os.path.islink(target):
        continue
    os.symlink("../testdata/imagenet-000000.tgz", target)
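The cell above fakes a sixteen-shard dataset: every name matching the imagenet-{000000..000015}.tgz pattern used below is just a symlink to the same test shard, and the shard contents are never used because the pipeline replaces each sample with random tensors via mockdata. The next cell writes the actual DDP training script.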
%%writefile demo.py
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import webdataset as wds
from functools import partial
from torch.nn.parallel import DistributedDataParallel as DDP


def mockdata(_):
    # Replace every sample with random (input, target) tensors; only the
    # shard handling matters for this demo, not the shard contents.
    return torch.randn(10), torch.randn(5)


# Make webdataset's gopen report every shard it opens, so we can see
# which process/worker reads which shard.
os.environ["GOPEN_VERBOSE"] = "1"


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def make_loader():
    # PytorchShardList splits the shards across DDP ranks and DataLoader
    # workers; epoch_shuffle=True reshuffles the shard order once per epoch.
    shardlist = partial(wds.PytorchShardList, epoch_shuffle=True, verbose=False)
    dataset = wds.WebDataset("imagenet-{000000..000015}.tgz", shardlist=shardlist).map(mockdata)
    loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)
    return loader


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank} of {world_size}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    loader = make_loader()
    for epoch in range(2):
        print("=== epoch", epoch)
        # IterableDataset pipelines have no sampler.set_epoch(); the epoch is
        # communicated to the shard shuffler through an environment variable.
        os.environ["WDS_EPOCH"] = str(epoch)
        for inputs, labels in loader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs.to(rank))
            loss_fn(outputs, labels.to(rank)).backward()
            optimizer.step()

    cleanup()
Overwriting demo.py
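demo.py follows the structure of the standard PyTorch DDP tutorial (setup/cleanup, ToyModel, demo_basic run once per rank); the only WebDataset-specific part is make_loader, which mocks out the sample contents. With real ImageNet-style shards you would decode actual samples instead of mapping everything to random tensors. A minimal sketch of that variant, assuming the usual decode/to_tuple stages and that each sample stores a "jpg" image and a "cls" label:

dataset = (
    wds.WebDataset("imagenet-{000000..000015}.tgz", shardlist=shardlist)
    .decode("torchrgb")        # decode images to CHW float tensors
    .to_tuple("jpg", "cls")    # yield (image, label) pairs
)
loader = wds.WebLoader(dataset, num_workers=4, batch_size=20)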
from importlib import reload
import demo
reload(demo)
<module 'demo' from '/home1/tmb/proj/webdataset/notebooks/demo.py'>
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
def run_demo(demo_fn, world_size):
    # mp.spawn calls demo_fn(rank, *args) once in each of nprocs processes.
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
run_demo(demo.demo_basic, 2)
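Because demo.py sets GOPEN_VERBOSE=1, each spawned rank (and each of its DataLoader workers) should log the shards it opens, which makes it easy to check that the two processes read different subsets of the sixteen imagenet-*.tgz shards rather than duplicating work.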