---
title: Pruning Schedules
keywords: fastai
sidebar: home_sidebar
summary: "Make your neural network sparse with fastai"
description: "Make your neural network sparse with fastai"
nb_path: "nbs/04a_tutorial.schedules.ipynb"
---
Neural Network Pruning usually follows one of these three schedules:

- One-Shot Pruning
- Iterative Pruning
- Automated Gradual Pruning

In fasterai, all three schedules can be applied with the same callback. We'll cover each of them below.
In the `SparsifyCallback`, there are several parameters to 'shape' our pruning schedule:

`start_sparsity`
: the initial sparsity of our model, generally kept at 0, since after initialization our weights are generally non-zero.

`end_sparsity`
: the target sparsity at the end of the training.

`start_epoch`
: we can decide to start pruning right from the beginning or let the model train a bit before removing weights.

`sched_func`
: this is where the general shape of the schedule is specified, as it determines how the sparsity evolves along the training. You can either use a schedule available in fastai or even come up with your own!

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)
We will first train a network without any pruning, which will serve as a baseline.
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(6)
The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

1. Train the network.
2. Prune it by removing the least important weights.
3. Fine-tune the pruned network.
With fasterai, this is really easy to do. Let's illustrate it by an example:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
In this case, your network needs to be trained before pruning. This training can be done independently from the pruning callback, or simulated by the `start_epoch` argument, which will delay the pruning process.

You thus only need to create the Callback with the `one_shot` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.
sp_cb=SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=one_shot, start_epoch=3)
Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before.
learn.fit_one_cycle(6, cbs=sp_cb)
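To make sure the callback reached the target sparsity, here is a small sketch using plain PyTorch (the `count_sparsity` helper below is just for illustration, it is not part of fasterai) that counts the fraction of zeroed weights in the convolutional layers:

```python
import torch

def count_sparsity(model):
    "Return the percentage of weights that are exactly zero in all Conv2d layers."
    total, zeros = 0, 0
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            total += m.weight.numel()
            zeros += (m.weight == 0).sum().item()
    return 100 * zeros / total

print(f'Conv weight sparsity: {count_sparsity(learn.model):.2f}%')
```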
Researchers have come up with a better way to do pruning than removing all the weights at once (as in One-Shot Pruning): the idea is to perform several iterations of pruning and fine-tuning, which is thus called Iterative Pruning.
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
In this case, your network needs to be trained before pruning.
You only need to create the Callback with the `iterative` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.
def iterative(start, end, pos, n_steps=3):
    "Perform iterative pruning in `n_steps` steps"
    return start + ((end-start)/n_steps)*(np.ceil(pos*n_steps))
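To get a feel for the step-like behaviour of this function, we can evaluate it by hand at a few points of the pruning phase, for the 0% to 50% sparsity range used in this tutorial (the printed values simply follow from the formula above, with the default `n_steps=3`):

```python
for pos in [0.1, 0.4, 0.7, 1.0]:
    print(f'pos={pos}: sparsity={iterative(0, 50, pos):.1f}%')
# pos=0.1: sparsity=16.7%   (first step)
# pos=0.4: sparsity=33.3%   (second step)
# pos=0.7: sparsity=50.0%   (third and last step)
# pos=1.0: sparsity=50.0%
```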
The `iterative` schedule has an `n_steps` parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the `partial` function like this:
iterative = partial(iterative, n_steps=5)
sp_cb=SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=iterative, start_epoch=3)
Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before.
learn.fit_one_cycle(6, cbs=sp_cb)
Here is, for example, how to implement the Automated Gradual Pruning schedule:
def sched_agp(start, end, pos): return end + (start - end) * (1 - pos)**3
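As a quick sanity check of the formula, again for a 0% to 50% range: the sparsity starts at `start`, grows quickly at the beginning of the pruning phase and flattens out towards `end` (values computed directly from the function above):

```python
for pos in [0.0, 0.25, 0.5, 0.75, 1.0]:
    print(f'pos={pos}: sparsity={sched_agp(0, 50, pos):.2f}%')
# pos=0.0: sparsity=0.00%
# pos=0.25: sparsity=28.91%
# pos=0.5: sparsity=43.75%
# pos=0.75: sparsity=49.22%
# pos=1.0: sparsity=50.00%
```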
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
sp_cb=SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=sched_agp, start_epoch=3)
Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before.
learn.fit_one_cycle(6, cbs=sp_cb)
Even though they are often considered as different pruning methods, these three schedules can all be expressed with the same Callback. Here is how the sparsity in the network evolves for each of them.

Let's take an example: say that we want to train our network for 3 epochs without pruning and then for 7 epochs with pruning.
train = np.zeros(300)
prune = np.linspace(0,1, 700)
Then this is what our different pruning schedules will look like:
os = np.full_like(prune, 50)  # one-shot pruning reaches the full 50% sparsity right away
plt.plot(np.concatenate([train, iterative(0, 50, prune)]), label='Iterative Pruning')
plt.plot(np.concatenate([train, os]), label='One-Shot Pruning')
plt.plot(np.concatenate([train, sched_agp(0, 50, prune)]), label='Automated Gradual Pruning')
plt.legend()
You can also come up with your own pruning schedule!
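Based on the examples above, a custom `sched_func` is just a function taking `(start, end, pos)` and returning the desired sparsity at training progress `pos`. As a rough sketch (the `sched_cosine` name and shape below are only an illustration, not something provided by fasterai), a cosine-shaped schedule could look like this:

```python
import numpy as np

def sched_cosine(start, end, pos):
    "Anneal sparsity from `start` to `end` following half a cosine curve."
    return start + (end - start) * (1 - np.cos(np.pi * pos)) / 2

# It can then be passed to the callback exactly like the schedules above:
sp_cb = SparsifyCallback(end_sparsity=50, granularity='weight', method='local',
                         criteria=large_final, sched_func=sched_cosine, start_epoch=3)
```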