---
title: Pruner
keywords: fastai
sidebar: home_sidebar
summary: "Remove useless filters to recreate a dense network"
description: "Remove useless filters to recreate a dense network"
nb_path: "nbs/02a_pruner.ipynb"
---
{% include important.html content='The Pruner method currently works on fully feedforward ConvNets, e.g. VGG16. Support for residual connections, e.g. ResNets, is under development.' %}
When our network has filters that contain only zero values, there is an additional step we can take: those zero filters can be physically removed from the network, giving us a new, dense architecture.
This is done by re-expressing each layer with a reduced number of filters, to match its number of non-zero filters. However, removing a filter from a layer also removes the activation map that this filter would have produced, and that map is consumed by every filter in the next layer. So we must not only remove the filter itself, but also its corresponding kernel in each filter of the next layer (see Fig. below).
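To make this concrete, below is a minimal sketch in plain PyTorch of what pruning one pair of consecutive convolutions looks like. This is not fasterai's implementation: `prune_conv_pair` is a hypothetical helper, and BatchNorm and pooling layers are ignored for simplicity.

```python
import torch
import torch.nn as nn

def prune_conv_pair(conv, next_conv):
    "Drop all-zero filters from `conv` and the matching kernels in `next_conv`."
    # Indices of the filters whose weights are not all zero
    keep = torch.nonzero(conv.weight.data.abs().sum(dim=(1, 2, 3))).squeeze(1)
    # Rebuild `conv` with only the surviving filters (slice along dim 0)
    new_conv = nn.Conv2d(conv.in_channels, len(keep), conv.kernel_size,
                         conv.stride, conv.padding, bias=conv.bias is not None)
    new_conv.weight.data = conv.weight.data[keep].clone()
    if conv.bias is not None: new_conv.bias.data = conv.bias.data[keep].clone()
    # Rebuild `next_conv`, dropping the kernels that consumed the removed
    # activation maps (slice along dim 1, the input-channel dimension)
    new_next = nn.Conv2d(len(keep), next_conv.out_channels, next_conv.kernel_size,
                         next_conv.stride, next_conv.padding,
                         bias=next_conv.bias is not None)
    new_next.weight.data = next_conv.weight.data[:, keep].clone()
    if next_conv.bias is not None: new_next.bias.data = next_conv.bias.data.clone()
    return new_conv, new_next

c1, c2 = nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 16, 3, padding=1)
c1.weight.data[::2] = 0.                 # zero out every other filter
c1, c2 = prune_conv_pair(c1, c2)
print(c1.weight.shape, c2.weight.shape)  # (4, 3, 3, 3) and (16, 4, 3, 3)
```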
Let's illustrate this with an example:
from fastai.vision.all import *
from fasterai.sparse.all import *  # module path may differ across fasterai versions
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
learn = Learner(dls, vgg16_bn(num_classes=2), metrics=accuracy)
count_parameters(learn.model)
Our initial model, a VGG16, has more than 134 million parameters. Let's see what happens when we make it sparse at the filter level.
# Remove 50% of the filters in each layer ('local'), keeping those with the
# largest final magnitude and following a one-cycle sparsification schedule.
sp_cb = SparsifyCallback(end_sparsity=50, granularity='filter', method='local', criteria=large_final, sched_func=sched_onecycle)
learn.fit_one_cycle(3, 3e-4, cbs=sp_cb)
count_parameters(learn.model)
The total number of parameters hasn't changed! This is because we have only replaced the weights with zeroes: the model is sparse, but the parameters are still physically there.
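We can verify this by counting the convolutional filters whose weights are entirely zero. Here is a quick sanity check written in plain PyTorch, not a fasterai API:

```python
import torch.nn as nn

# Count the convolutional filters zeroed out by the callback
zero, total = 0, 0
for m in learn.model.modules():
    if isinstance(m, nn.Conv2d):
        # A filter counts as zero when all of its weights are exactly 0
        per_filter = m.weight.data.abs().sum(dim=(1, 2, 3))
        zero += (per_filter == 0).sum().item()
        total += m.out_channels
print(f'{zero}/{total} filters are zero ({100 * zero / total:.1f}%)')
```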
The `Pruner` will take care of removing those useless filters.
pruner = Pruner()
pruned_model = pruner.prune_model(learn.model)
Done! Let's see if the performance is still the same.
pruned_learn = Learner(dls, pruned_model.cuda(), metrics=accuracy)
pruned_learn.validate()
count_parameters(pruned_learn.model)
We now have 71 million parameters, approximately 50% of the initial count, as we asked!