---
title: Pruning Schedules
keywords: fastai
sidebar: home_sidebar
summary: "Make your neural network sparse with fastai"
description: "Make your neural network sparse with fastai"
nb_path: "nbs/04a_tutorial.schedules.ipynb"
---

Neural Network Pruning usually follows one of these three schedules:

(Figure: the three common pruning schedules: One-Shot, Iterative, and Automated Gradual Pruning)

In fasterai, all three of these schedules can be applied through the same callback. We'll cover each of them below.

In the SparsifyCallback, there are several parameters that shape our pruning schedule:

  • start_sparsity: the initial sparsity of our model, generally kept at 0 because, after initialization, the weights are generally non-zero.
  • end_sparsity: the target sparsity at the end of the training.
  • start_epoch: we can decide to start pruning right from the beginning or let the model train for a bit before removing weights.
  • sched_func: this is where the general shape of the schedule is specified, as it defines how the sparsity evolves along the training. You can either use a schedule available in fastai or even come up with your own (see the sketch after this list)!

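As a hedged sketch (not a cell from the original notebook), here is how these parameters might fit together, using fastai's cosine annealing function sched_cos as the sched_func:

{% raw %}
# Illustrative only: prune from 0% to 50% sparsity, starting after one warm-up
# epoch, with the sparsity target following a cosine ramp (fastai's `sched_cos`).
# `start_sparsity` is shown explicitly here; the examples below rely on its default.
sp_cb = SparsifyCallback(start_sparsity=0, end_sparsity=50, granularity='weight',
                         method='local', criteria=large_final,
                         sched_func=sched_cos, start_epoch=1)
{% endraw %}
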
{% raw %}
from fastai.vision.all import *
# fasterai provides SparsifyCallback, the pruning criteria (e.g. large_final) and the schedules
from fasterai.sparse.all import *
import torch

# Download the Oxford-IIIT Pets dataset and collect the image files
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Labels come from the filename: cat breeds are capitalized in this dataset
def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)
{% endraw %}

We will first train a network without any pruning, which will serve as a baseline.

{% raw %}
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

learn.fit_one_cycle(6)

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.720675 | 0.688383 | 0.805819 | 00:09 |
| 1 | 0.448731 | 0.249166 | 0.901218 | 00:08 |
| 2 | 0.261298 | 0.229913 | 0.897158 | 00:08 |
| 3 | 0.172998 | 0.222149 | 0.915426 | 00:08 |
| 4 | 0.099862 | 0.192763 | 0.927605 | 00:09 |
| 5 | 0.051141 | 0.175377 | 0.935724 | 00:09 |

{% endraw %}

One-Shot Pruning

The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

  1. You first need to train a network
  2. You then need to remove some weights (depending on your criteria, needs,...)
  3. You fine-tune the remaining weights to recover from the loss of parameters.

With fasterai, this is really easy to do. Let's illustrate it with an example:

{% raw %}
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
{% endraw %}

In this case, your network needs to be trained before pruning. This training can be done independently of the pruning callback, or handled by the start_epoch argument, which delays the start of the pruning process.

You thus only need to create the Callback with the one_shot schedule and set the start_epoch argument, i.e. how many epochs you want to train your network before pruning it.

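Conceptually, a one-shot schedule is just a constant: once pruning starts (i.e. after start_epoch), the target sparsity immediately jumps to its final value. A rough sketch of such a schedule function (written here for illustration; fasterai ships its own one_shot) would be:

{% raw %}
# Illustrative sketch: the requested sparsity is simply `end` for any position.
# The callback keeps the sparsity at 0 until `start_epoch` is reached.
def one_shot_sketch(start, end, pos): return end
{% endraw %}
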
{% raw %}
sp_cb=SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=one_shot, start_epoch=3)
{% endraw %}

Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before.

{% raw %}
learn.fit_one_cycle(6, cbs=sp_cb)
Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.736246 | 1.251360 | 0.769283 | 00:09 |
| 1 | 0.449560 | 0.284590 | 0.890392 | 00:09 |
| 2 | 0.272213 | 0.262885 | 0.893775 | 00:09 |
| 3 | 0.162493 | 0.219732 | 0.923545 | 00:09 |
| 4 | 0.087920 | 0.182203 | 0.933694 | 00:10 |
| 5 | 0.066148 | 0.178582 | 0.938430 | 00:09 |

Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 50.00%
Sparsity at the end of epoch 4: 50.00%
Sparsity at the end of epoch 5: 50.00%
Final Sparsity: 50.00
Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 8: 50.00%
Sparsity in Conv2d 11: 50.00%
Sparsity in Conv2d 14: 50.00%
Sparsity in Conv2d 17: 50.00%
Sparsity in Conv2d 21: 50.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 50.00%
Sparsity in Conv2d 53: 50.00%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 59: 50.00%
Sparsity in Conv2d 62: 50.00%
Sparsity in Conv2d 65: 50.00%
{% endraw %}

Iterative Pruning

Researchers have come up with a better way to do pruning than removing all the weights at once (as in One-Shot Pruning). The idea is to perform several iterations of pruning and fine-tuning; this is called Iterative Pruning.

  1. You first need to train a network
  2. You then need to remove a part of the weights (depending on your criteria, needs,...)
  3. You fine-tune the remaining weights to recover from the loss of parameters.
  4. Back to step 2.
{% raw %}
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
{% endraw %}

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the iterative schedule and set the start_epoch argument, i.e. how many epochs you want to train your network before pruning it.

{% raw %}
def iterative(start, end, pos, n_steps=3):
    "Perform iterative pruning in `n_steps` steps"
    return start + ((end-start)/n_steps)*(np.ceil((pos)*n_steps))
{% endraw %}
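
To see the stepped shape this produces, we can evaluate it at a few positions between 0 and 1 (the values below are approximate):

{% raw %}
# With the default n_steps=3, sparsity ramps from 0 to 50 in three equal jumps
[round(iterative(0, 50, p), 2) for p in np.linspace(0, 1, 7)]
# -> [0.0, 16.67, 16.67, 33.33, 33.33, 50.0, 50.0]
{% endraw %}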

The iterative schedule has an n_steps parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the partial function like this:

iterative = partial(iterative, n_steps=5)
{% raw %}
sp_cb=SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=iterative, start_epoch=3)
{% endraw %}

Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before.

{% raw %}
learn.fit_one_cycle(6, cbs=sp_cb)
Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.676544 | 0.345599 | 0.855886 | 00:08 |
| 1 | 0.402555 | 0.298006 | 0.887010 | 00:08 |
| 2 | 0.236813 | 0.293453 | 0.887010 | 00:08 |
| 3 | 0.133723 | 0.194966 | 0.921516 | 00:09 |
| 4 | 0.090915 | 0.161868 | 0.941137 | 00:09 |
| 5 | 0.056463 | 0.158669 | 0.945873 | 00:09 |

Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 16.67%
Sparsity at the end of epoch 4: 33.33%
Sparsity at the end of epoch 5: 50.00%
Final Sparsity: 50.00
Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 8: 50.00%
Sparsity in Conv2d 11: 50.00%
Sparsity in Conv2d 14: 50.00%
Sparsity in Conv2d 17: 50.00%
Sparsity in Conv2d 21: 50.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 50.00%
Sparsity in Conv2d 53: 50.00%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 59: 50.00%
Sparsity in Conv2d 62: 50.00%
Sparsity in Conv2d 65: 50.00%
{% endraw %}

Gradual Pruning

Here is, for example, how to implement the Automated Gradual Pruning schedule:

{% raw %}
# Automated Gradual Pruning: sparsity ramps up quickly at first, then tapers off,
# following sparsity(pos) = end + (start - end) * (1 - pos)**3
def sched_agp(start, end, pos): return end + (start - end) * (1 - pos)**3
{% endraw %}

{% raw %}
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
{% endraw %}

{% raw %}
sp_cb=SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=sched_agp, start_epoch=3)
{% endraw %}

Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before.

{% raw %}
learn.fit_one_cycle(6, cbs=sp_cb)
Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.707103 | 0.570778 | 0.849120 | 00:08 |
| 1 | 0.457157 | 0.277922 | 0.876861 | 00:08 |
| 2 | 0.264160 | 0.321768 | 0.867388 | 00:08 |
| 3 | 0.187102 | 0.202760 | 0.918133 | 00:09 |
| 4 | 0.096316 | 0.206133 | 0.930311 | 00:09 |
| 5 | 0.058155 | 0.160002 | 0.944520 | 00:09 |

Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 35.19%
Sparsity at the end of epoch 4: 48.15%
Sparsity at the end of epoch 5: 50.00%
Final Sparsity: 50.00
Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 8: 50.00%
Sparsity in Conv2d 11: 50.00%
Sparsity in Conv2d 14: 50.00%
Sparsity in Conv2d 17: 50.00%
Sparsity in Conv2d 21: 50.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 50.00%
Sparsity in Conv2d 53: 50.00%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 59: 50.00%
Sparsity in Conv2d 62: 50.00%
Sparsity in Conv2d 65: 50.00%
{% endraw %}

Even though they are often considered to be different pruning methods, these three schedules can all be captured by the same Callback. Here is how the sparsity of the network evolves for each of them.

Let's take an example: say we want to train our network for 3 epochs without pruning and then for 7 epochs with pruning.

{% raw %}
train = np.zeros(300)            # 3 epochs without pruning: sparsity stays at 0
prune = np.linspace(0, 1, 700)   # 7 epochs of pruning: position goes from 0 to 1
{% endraw %}

Then this is what our different pruning schedules will look like:

{% raw %}
plt.plot(np.concatenate([train, iterative(0, 50, prune)]), label='Iterative Pruning')
plt.plot(np.concatenate([train, np.full_like(prune, 50)]), label='One-Shot Pruning')  # jumps straight to the final sparsity
plt.plot(np.concatenate([train, sched_agp(0, 50, prune)]), label='Automated Gradual Pruning')
plt.legend()
(Plot: how the sparsity evolves over training for the three pruning schedules)
{% endraw %}

You can also come up with your own pruning schedule!
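
Any function taking (start, end, pos), where pos grows from 0 to 1 over the pruning phase and the returned value is the target sparsity, can be passed as sched_func. Here is a hedged sketch with a made-up cubic ramp:

{% raw %}
# Hypothetical custom schedule: sparsity grows slowly at first,
# then accelerates towards `end`.
def sched_cubic(start, end, pos): return start + (end - start) * pos**3

sp_cb = SparsifyCallback(end_sparsity=50, granularity='weight', method='local',
                         criteria=large_final, sched_func=sched_cubic, start_epoch=3)
{% endraw %}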