---
title: Lottery Ticket Hypothesis
keywords: fastai
sidebar: home_sidebar
summary: "How to find winning tickets with fastai"
description: "How to find winning tickets with fastai"
nb_path: "nbs/03_tutorial.lottery_ticket.ipynb"
---

The Lottery Ticket Hypothesis

The Lottery Ticket Hypothesis is a really intriguing discovery made in 2019 by Frankle & Carbin. It states that:

A randomly-initialized, dense neural network contains a subnetwork that is initialised such that — when trained in isolation — it can match the test accuracy of the original network after training for at most the same number of iterations.

Meaning that, once we find that subnetwork, every other parameter in the network becomes useless.

The way the authors propose to find such a subnetwork is as follows:

  1. Initialize the neural network
  2. Train it to convergence
  3. Prune the smallest-magnitude weights by creating a mask $m$
  4. Reinitialize the remaining weights to their original values, i.e. their values at iteration $0$
  5. Repeat from step 2 until reaching the desired level of sparsity.
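Before turning to fasterai's implementation, here is a minimal, self-contained sketch of that loop in plain PyTorch. It is only an illustration of the procedure above: `find_ticket` and its `train` argument are hypothetical placeholders, not part of fasterai.

```python
import copy
import torch

def find_ticket(model, train, rounds=3, final_sparsity=0.5):
    "Toy sketch of iterative magnitude pruning with weight rewinding."
    initial_state = copy.deepcopy(model.state_dict())          # step 1: remember the random init
    masks = {n: torch.ones_like(p) for n, p in model.named_parameters() if n.endswith('weight')}
    for r in range(1, rounds + 1):
        train(model, masks)                                     # step 2: train to convergence (placeholder)
        sparsity = final_sparsity * r / rounds                  # e.g. 16.7% -> 33.3% -> 50%
        for n, p in model.named_parameters():
            if n not in masks: continue
            k = int(sparsity * p.numel())
            if k == 0: continue
            threshold = p.detach().abs().flatten().kthvalue(k).values
            masks[n] = (p.detach().abs() > threshold).float()   # step 3: drop the smallest-magnitude weights
        model.load_state_dict(initial_state)                    # step 4: rewind weights to iteration 0
        with torch.no_grad():
            for n, p in model.named_parameters():               # step 5: next round trains the masked subnetwork
                if n in masks: p.mul_(masks[n])
    return model, masks
```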
{% raw %}
from fastai.vision.all import *   # data, training and model utilities used below
from fasterai.sparse.all import *
from copy import deepcopy
{% endraw %} {% raw %}
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Pets dataset, cat images have filenames starting with an uppercase letter, so True means "cat"
def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)
{% endraw %}

What we are trying to show is that, inside a neural network $A$, there exists a subnetwork $B$ able to reach an accuracy $a_B > a_A$ in a training time $t_B < t_A$.

Let's get the baseline for network A:

{% raw %}
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
{% endraw %}

Let's save the original weights:

{% raw %}
initial_weights = deepcopy(learn.model.state_dict())
{% endraw %} {% raw %}
learn.fit(5, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.589757 0.574913 0.703654 00:07
1 0.560946 0.546202 0.702977 00:07
2 0.531937 0.555280 0.709066 00:07
3 0.482919 0.549042 0.727334 00:07
4 0.443224 0.511626 0.756428 00:07
{% endraw %}

We now have our accuracy $a_A$ of about $76\%$ and our training time $t_A$ of $5$ epochs.

To find the lottery ticket, we will perform iterative pruning, but at each pruning step we will reinitialize the remaining weights to their original values (i.e. their values before training).

We will restart from the same initialization to make sure we are not just getting lucky.

{% raw %}
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>
{% endraw %}

We can pass the parameter lth=True to make the weights of the network reset to their original values after each pruning step, i.e. step 4 of the LTH. To empirically validate the LTH, we need to retrain the found "lottery ticket" after the pruning phase. Lottery tickets are usually found following an iterative pruning schedule. We set the start_epoch parameter to $5$ to begin the pruning process after $5$ epochs.

{% raw %}
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=5, lth=True)
{% endraw %}

As our iterative schedule performs $3$ pruning steps by default, we have to train our network for start_epoch $+ 3 \cdot t_B = 5 + 3 \cdot 5 = 20$ epochs in order to get our LTH. After each step, the remaining weights will be reinitialized to their original values.
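As a quick sanity check on that arithmetic, here is a purely illustrative snippet (it does not call fasterai) reproducing the epoch count and the sparsity levels we expect to see in the log below:

```python
start_epoch, pruning_steps, epochs_per_step, final_sparsity = 5, 3, 5, 50

total_epochs = start_epoch + pruning_steps * epochs_per_step                 # 5 + 3 * 5 = 20
sparsities = [round(final_sparsity * s / pruning_steps, 2) for s in range(1, pruning_steps + 1)]

print(total_epochs)  # 20
print(sparsities)    # [16.67, 33.33, 50.0]
```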

{% raw %}
learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.598814 0.686804 0.649526 00:07
1 0.560133 0.529173 0.732747 00:07
2 0.531159 0.511585 0.753045 00:07
3 0.498013 0.651968 0.600135 00:07
4 0.463376 0.491074 0.751015 00:07
5 0.555985 0.702323 0.669147 00:08
6 0.527828 0.485827 0.765223 00:08
7 0.468368 0.487057 0.782815 00:08
8 0.421725 0.445869 0.781461 00:08
9 0.391137 0.459242 0.799053 00:08
10 0.471564 0.622142 0.631935 00:08
11 0.430352 0.461739 0.776049 00:08
12 0.384778 0.531771 0.762517 00:08
13 0.352641 0.488232 0.794993 00:08
14 0.320290 0.488196 0.739513 00:08
15 0.386944 0.459122 0.812585 00:08
16 0.349582 0.569863 0.721922 00:08
17 0.312838 0.402464 0.815291 00:08
18 0.261409 0.542125 0.735453 00:08
19 0.245360 0.414822 0.809202 00:08
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 1: 50.00%
Sparsity in Conv2d 7: 50.00%
Sparsity in Conv2d 10: 50.00%
Sparsity in Conv2d 13: 50.00%
Sparsity in Conv2d 16: 50.00%
Sparsity in Conv2d 20: 50.00%
Sparsity in Conv2d 23: 50.00%
Sparsity in Conv2d 26: 50.00%
Sparsity in Conv2d 29: 50.00%
Sparsity in Conv2d 32: 50.00%
Sparsity in Conv2d 36: 50.00%
Sparsity in Conv2d 39: 50.00%
Sparsity in Conv2d 42: 50.00%
Sparsity in Conv2d 45: 50.00%
Sparsity in Conv2d 48: 50.00%
Sparsity in Conv2d 52: 50.00%
Sparsity in Conv2d 55: 50.00%
Sparsity in Conv2d 58: 50.00%
Sparsity in Conv2d 61: 50.00%
Sparsity in Conv2d 64: 50.00%
{% endraw %}

We indeed have a network $B$ whose accuracy $a_B > a_A$ in the same training time.

Lottery Ticket Hypothesis with Rewinding

In some cases, the LTH fails for deeper networks. The authors then proposed a solution: rewinding the weights to a more advanced training iteration instead of their initialization values.

{% raw %}
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>
{% endraw %}

This can be done in fasterai by passing the rewind_epoch parameter, which will save the weights at that epoch and then reset the weights to that snapshot after each pruning step.
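Conceptually, the only change compared to the plain LTH loop is which snapshot gets restored after each pruning step. Here is a rough sketch of that idea, where train_one_epoch and prune_smallest_weights are empty placeholders standing in for the real training and pruning logic (this is not fasterai's internal code):

```python
import copy
import torch.nn as nn

def train_one_epoch(model): pass          # placeholder: one epoch of normal training
def prune_smallest_weights(model): pass   # placeholder: raise the mask sparsity one notch

model = nn.Linear(10, 2)                  # stand-in for the real network
rewind_epoch, pruning_epochs, total_epochs = 1, [5, 10, 15], 20
rewind_state = copy.deepcopy(model.state_dict())

for epoch in range(total_epochs):
    if epoch in pruning_epochs:
        prune_smallest_weights(model)
        model.load_state_dict(rewind_state)               # rewind to the epoch-1 snapshot, not epoch 0
    train_one_epoch(model)
    if epoch == rewind_epoch:
        rewind_state = copy.deepcopy(model.state_dict())  # saved after `rewind_epoch` epochs of training
```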

{% raw %}
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=5, lth=True, rewind_epoch=1)
{% endraw %} {% raw %}
learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
epoch train_loss valid_loss accuracy time
0 0.599985 0.566728 0.700271 00:07
1 0.550926 0.539680 0.711773 00:07
2 0.524204 0.647951 0.638701 00:07
3 0.499537 0.605123 0.705683 00:07
4 0.441816 0.577226 0.708390 00:07
5 0.514446 0.545353 0.740866 00:08
6 0.475099 0.517007 0.726658 00:08
7 0.429452 0.438081 0.810555 00:08
8 0.377079 0.444002 0.780785 00:08
9 0.348466 0.384129 0.837618 00:08
10 0.436006 0.475868 0.771313 00:08
11 0.390055 0.398510 0.822733 00:08
12 0.359748 0.468656 0.771313 00:08
13 0.323625 0.505410 0.780785 00:08
14 0.299261 0.387048 0.828146 00:08
15 0.355692 0.437216 0.801759 00:08
16 0.313067 0.469669 0.823410 00:08
17 0.271151 0.433562 0.832882 00:08
18 0.242624 0.531901 0.776049 00:08
19 0.225511 0.370019 0.845737 00:08
Sparsity at the end of epoch 0: [0.0]%
Saving Weights at epoch 1
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 1: 50.00%
Sparsity in Conv2d 7: 50.00%
Sparsity in Conv2d 10: 50.00%
Sparsity in Conv2d 13: 50.00%
Sparsity in Conv2d 16: 50.00%
Sparsity in Conv2d 20: 50.00%
Sparsity in Conv2d 23: 50.00%
Sparsity in Conv2d 26: 50.00%
Sparsity in Conv2d 29: 50.00%
Sparsity in Conv2d 32: 50.00%
Sparsity in Conv2d 36: 50.00%
Sparsity in Conv2d 39: 50.00%
Sparsity in Conv2d 42: 50.00%
Sparsity in Conv2d 45: 50.00%
Sparsity in Conv2d 48: 50.00%
Sparsity in Conv2d 52: 50.00%
Sparsity in Conv2d 55: 50.00%
Sparsity in Conv2d 58: 50.00%
Sparsity in Conv2d 61: 50.00%
Sparsity in Conv2d 64: 50.00%
{% endraw %}

Super-Masks

Researchers from Uber AI investigated the LTH and found the existence of what they call "Supermasks", i.e. masks that, when applied to an untrained neural network, allow it to reach better-than-random accuracy.

{% raw %}
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
learn.model.load_state_dict(initial_weights)
<All keys matched successfully>
{% endraw %}

To find Supermasks, the authors perform the LTH method and then apply the resulting mask to the original, untrained network. In fasterai, you can pass the parameter reset_end=True, which will reset the weights to their original values at the end of the training while keeping the pruned weights (i.e. the mask) unchanged.
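To make this concrete, here is a minimal sketch of what "applying the mask to the untrained network" amounts to. apply_supermask and the masks dictionary are illustrative stand-ins, not fasterai's API; the callback handles this internally when reset_end=True.

```python
import torch

def apply_supermask(model, initial_state, masks):
    "Restore the untrained weights, then zero out the weights pruned by LTH."
    model.load_state_dict(initial_state)           # back to the epoch-0 (untrained) values
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])            # keep only the weights selected by the mask
    return model
```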

{% raw %}
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=5, lth=True, reset_end=True)
{% endraw %} {% raw %}
learn.fit(20, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
epoch train_loss valid_loss accuracy time
0 0.588215 0.603662 0.649526 00:07
1 0.559910 0.551536 0.697564 00:07
2 0.517046 0.473498 0.765900 00:07
3 0.460481 0.544400 0.706360 00:07
4 0.440268 0.441409 0.796346 00:07
5 0.559214 0.614429 0.648850 00:08
6 0.508002 0.699705 0.690798 00:08
7 0.472389 0.575537 0.740189 00:08
8 0.427371 0.578768 0.696211 00:08
9 0.377523 0.527881 0.735453 00:08
10 0.472205 0.708431 0.699594 00:08
11 0.441892 0.465152 0.776049 00:08
12 0.386189 0.442156 0.786874 00:08
13 0.347737 0.584790 0.778755 00:08
14 0.321213 0.891845 0.690798 00:08
15 0.410824 0.511506 0.772666 00:08
16 0.351705 0.389204 0.821380 00:08
17 0.308388 0.363523 0.845061 00:08
18 0.278367 0.380290 0.828146 00:08
19 0.265100 0.408591 0.845061 00:08
Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [0.0]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: [16.67]%
Sparsity at the end of epoch 6: [16.67]%
Sparsity at the end of epoch 7: [16.67]%
Sparsity at the end of epoch 8: [16.67]%
Sparsity at the end of epoch 9: [16.67]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 10: [33.33]%
Sparsity at the end of epoch 11: [33.33]%
Sparsity at the end of epoch 12: [33.33]%
Sparsity at the end of epoch 13: [33.33]%
Sparsity at the end of epoch 14: [33.33]%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 15: [50.0]%
Sparsity at the end of epoch 16: [50.0]%
Sparsity at the end of epoch 17: [50.0]%
Sparsity at the end of epoch 18: [50.0]%
Sparsity at the end of epoch 19: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 1: 50.00%
Sparsity in Conv2d 7: 50.00%
Sparsity in Conv2d 10: 50.00%
Sparsity in Conv2d 13: 50.00%
Sparsity in Conv2d 16: 50.00%
Sparsity in Conv2d 20: 50.00%
Sparsity in Conv2d 23: 50.00%
Sparsity in Conv2d 26: 50.00%
Sparsity in Conv2d 29: 50.00%
Sparsity in Conv2d 32: 50.00%
Sparsity in Conv2d 36: 50.00%
Sparsity in Conv2d 39: 50.00%
Sparsity in Conv2d 42: 50.00%
Sparsity in Conv2d 45: 50.00%
Sparsity in Conv2d 48: 50.00%
Sparsity in Conv2d 52: 50.00%
Sparsity in Conv2d 55: 50.00%
Sparsity in Conv2d 58: 50.00%
Sparsity in Conv2d 61: 50.00%
Sparsity in Conv2d 64: 50.00%
{% endraw %} {% raw %}
learn.validate()
(#2) [2.3516082763671875,0.35520973801612854]
{% endraw %}
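The two numbers returned by learn.validate() are the validation loss and the accuracy of the masked but otherwise untrained network, i.e. how the Supermask alone performs on this classification task.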