---
title: Lottery Ticket Hypothesis
keywords: fastai
sidebar: home_sidebar
summary: "How to find winning tickets with fastai"
description: "How to find winning tickets with fastai"
nb_path: "nbs/03_tutorial.lottery_ticket.ipynb"
---
The Lottery Ticket Hypothesis is an intriguing discovery made in 2019 by Frankle & Carbin. It states that:

> A randomly-initialized, dense neural network contains a subnetwork that is initialized such that, when trained in isolation, it can match the test accuracy of the original network after training for at most the same number of iterations.

This means that, once we have found that subnetwork, every other parameter in the network becomes useless.
The way the authors propose to find those subnetworks is as follows:

1. Initialize the neural network
2. Train it to convergence
3. Prune the weights with the smallest magnitude
4. Reset the remaining weights to their initial values, i.e. from step 1)

Steps 2) to 4) are then repeated until the desired sparsity is reached (sketched in code below).
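To make the procedure concrete, here is a minimal, self-contained PyTorch sketch of one such round on a toy layer (for illustration only; fasterai automates all of this below):

```python
import torch
import torch.nn as nn
from copy import deepcopy

torch.manual_seed(0)

# 1) Randomly initialize a (toy) network and keep a copy of its initial weights
layer = nn.Linear(10, 10)
initial = deepcopy(layer.state_dict())

# 2) Train it; a random perturbation stands in for training here
with torch.no_grad():
    layer.weight.add_(0.1 * torch.randn_like(layer.weight))

# 3) Prune the smallest-magnitude weights by building a binary mask
sparsity = 0.5
threshold = layer.weight.abs().flatten().quantile(sparsity)
mask = (layer.weight.abs() > threshold).float()

# 4) Reset the surviving weights to their initial values
with torch.no_grad():
    layer.weight.copy_(initial['weight'] * mask)
```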
```python
from fastai.vision.all import *
from fasterai.sparse.all import *
```
```python
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Pets dataset, cat breeds are capitalized, so the first letter of the
# filename tells cats from dogs
def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)
```
What we are trying to show is that, in a neural network $A$, there exists a subnetwork $B$ able to reach an accuracy $a_B > a_A$ in a training time $t_B < t_A$.

Let's get the baseline for network $A$:
```python
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
```
Let's save the original weights:
```python
initial_weights = deepcopy(learn.model.state_dict())
```
```python
learn.fit(5, 1e-3)
```
We now have our accuracy $a_A$ of $74\%$ and our training time $t_A$ of $5$ epochs.
To find the lottery ticket, we will perform iterative pruning, but at each pruning step we will re-initialize the remaining weights to their original values (i.e. their values before training).

We will restart from the same initialization to make sure we are not just getting lucky.
```python
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
```
We can pass the parameter `lth=True` to make the weights of the network reset to their original values after each pruning step, i.e. step 4) of the LTH. To empirically validate the LTH, we need to retrain the found "lottery ticket" after the pruning phase. Lottery tickets are usually found following an iterative pruning schedule. We set the `start_epoch` parameter to $5$ to begin the pruning process after $5$ epochs.
```python
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=5, lth=True)
```
As our `iterative` schedule performs $3$ pruning steps by default, we have to train our network for `start_epoch` $+\ 3 \times t_B = 5 + 3 \times 5 = 20$ epochs in order to get our lottery ticket. After each step, the remaining weights will be reinitialized to their original values.
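For intuition, a 3-step schedule raises the sparsity in equal increments during the pruning phase. A minimal sketch of what such a step schedule could look like (an illustration, not fasterai's exact implementation):

```python
import math

def iterative_schedule(start, end, pos, n_steps=3):
    "Step from `start` to `end` sparsity in `n_steps` equal increments as `pos` goes from 0 to 1."
    return start + (end - start) * math.ceil(pos * n_steps) / n_steps

# Towards 50% sparsity: 0% -> 16.7% -> 33.3% -> 50%
for pos in [0.0, 0.2, 0.4, 0.7, 1.0]:
    print(f"pos={pos:.1f} -> sparsity={iterative_schedule(0, 50, pos):.1f}%")
```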
```python
learn.fit(20, 1e-3, cbs=sp_cb)
```
We indeed end up with a network $B$ whose accuracy $a_B > a_A$, in the same training time.
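To check that the callback actually pruned the network, we can count the zeroed weights in the convolution layers with plain PyTorch (a quick sanity check, not a fasterai API):

```python
import torch.nn as nn

def conv_sparsity(model):
    "Percentage of zero-valued weights in the model's Conv2d layers."
    zeros, total = 0, 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            zeros += (m.weight == 0).sum().item()
            total += m.weight.numel()
    return 100 * zeros / total

print(f'Final sparsity: {conv_sparsity(learn.model):.1f}%')
```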
In some cases, the LTH fails for deeper networks. The authors then propose a solution: rewinding the weights to a later training iteration instead of to their value at initialization.
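Compared to the earlier sketch, the only change is where the reset values come from: the snapshot is taken after a few iterations of training rather than at initialization. A minimal toy illustration:

```python
import torch
import torch.nn as nn
from copy import deepcopy

torch.manual_seed(0)
layer = nn.Linear(10, 10)

# Train for a few iterations first (a perturbation stands in for training)...
with torch.no_grad():
    layer.weight.add_(0.1 * torch.randn_like(layer.weight))
# ...then snapshot: these are the "rewind" weights instead of the initial ones
rewind = deepcopy(layer.state_dict())

# Keep training, prune, and reset the survivors to the snapshot
with torch.no_grad():
    layer.weight.add_(0.1 * torch.randn_like(layer.weight))
mask = (layer.weight.abs() > layer.weight.abs().median()).float()
with torch.no_grad():
    layer.weight.copy_(rewind['weight'] * mask)
```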
```python
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
```
This can be done in fasterai by passing the `rewind_epoch` parameter, which will save the weights at that epoch and then reset the weights to those values at each pruning step.
```python
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=5, lth=True, rewind_epoch=1)

learn.fit(20, 1e-3, cbs=sp_cb)
```
Researchers from Uber AI investigated the LTH and discovered the existence of what they call "Super-Masks", i.e. masks that, when applied to an untrained neural network, allow it to reach better-than-random accuracy.
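Remarkably, such a mask improves the network without any further training. A minimal sketch of applying a binary mask to an untrained toy layer (here a random mask stands in for one found with the LTH procedure):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(10, 2)  # untrained layer

# A Super-Mask is just a binary tensor with the same shape as the weights
mask = (torch.rand_like(layer.weight) > 0.5).float()

with torch.no_grad():
    layer.weight.mul_(mask)  # apply the mask; no training involved
```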
```python
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
```
To find Super-Masks, the authors perform the LTH method, then apply the resulting mask to the original, untrained network. In fasterai, you can pass the parameter `reset_end=True`, which will reset the weights to their original values at the end of the training, while keeping the pruned weights (i.e. the mask) unchanged.
```python
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=5, lth=True, reset_end=True)

learn.fit(20, 1e-3, cbs=sp_cb)

learn.validate()
```
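If the reset worked as described, every surviving (non-zero) convolution weight should now be equal to its initial value. We can verify this with plain PyTorch (a sanity check, not a fasterai API):

```python
for name, p in learn.model.state_dict().items():
    if name.endswith('weight') and p.dim() == 4:  # Conv2d weights
        init = initial_weights[name].to(p.device)
        kept = p != 0
        assert torch.allclose(p[kept], init[kept]), f'{name} was not reset'
print('All surviving weights match their initial values')
```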