---
title: Prune Transformers
keywords: fastai
sidebar: home_sidebar
summary: "Prune transformers architecture with fasterai"
description: "Prune transformers architecture with fasterai"
nb_path: "nbs/04c_tutorial.transformers.ipynb"
---

{% include note.html content='This example code is taken from the fastai docs' %}

{% raw %}
from fastai.text.all import *
from fasterai.sparse.all import *
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)
{% endraw %} {% raw %}
path = untar_data(URLs.WIKITEXT_TINY)
{% endraw %}
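
The `DataLoaders` `dls` used below is built from the WikiText files as in the fastai transformers tutorial; here is a condensed sketch of that pipeline (the batch size and sequence length are illustrative):

{% raw %}
df_train = pd.read_csv(path/'train.csv', header=None)
df_valid = pd.read_csv(path/'test.csv', header=None)
all_texts = np.concatenate([df_train[0].values, df_valid[0].values])

# Wrap the HuggingFace tokenizer in a fastai Transform
class TransformersTokenizer(Transform):
    def __init__(self, tokenizer): self.tokenizer = tokenizer
    def encodes(self, x): return tensor(self.tokenizer.encode(x))
    def decodes(self, x): return TitledStr(self.tokenizer.decode(x.cpu().numpy()))

splits = [range_of(df_train), list(range(len(df_train), len(all_texts)))]
tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)
dls = tls.dataloaders(bs=8, seq_len=256)
{% endraw %}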

Let's create our fastai Learner.
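
The `DropOutput` callback we pass to it comes from the same tutorial: GPT-2 returns a tuple (the logits plus activations kept for caching), and the callback keeps only the logits so the loss function receives a plain tensor. A minimal version:

{% raw %}
class DropOutput(Callback):
    def after_pred(self): self.learn.pred = self.pred[0]
{% endraw %}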

{% raw %}
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())
{% endraw %}

And let's try to extend a given prompt with the pretrained model.

{% raw %}
prompt = "\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn"
{% endraw %} {% raw %}
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp, max_length=40, num_beams=5, temperature=1.5)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
{% endraw %} {% raw %}
tokenizer.decode(preds[0].cpu().numpy())
'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head.\n\nA unicorn is a magical creature with a rainbow tail and a horn'
{% endraw %} {% raw %}
learn.validate()
(#2) [3.695716381072998,40.2744140625]
{% endraw %} {% raw %}
learn.fit_one_cycle(1, 1e-4)
| epoch | train_loss | valid_loss | perplexity | time |
|-------|------------|------------|------------|------|
| 0 | 3.139103 | 2.843017 | 17.167484 | 07:58 |
{% endraw %} {% raw %}
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head. It is a member of the <unk> <unk> <unk>'
{% endraw %}

Make it sparse!

Let's now retrain our model, this time introducing sparsity.

{% raw %}
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())
{% endraw %}

Also, when working with text, fastai defines the number of processed batches differently, so we have to adjust our SparsifyCallback accordingly (luckily, fastai makes this count available as the n_batches attribute of the DataLoaders).

{% raw %}
@patch_to(SparsifyCallback)
def before_fit(self):
    print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
    self.end_epoch = self.n_epoch if self.end_epoch is None else self.end_epoch
    assert self.end_epoch <= self.n_epoch, 'Your end_epoch must be smaller than the total number of epochs'

    model = self.learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
    self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
    self.total_iters = self.end_epoch * self.dls.n_batches # use n_batches instead of the length of the training DataLoader
    self.start_iter = self.start_epoch * self.dls.n_batches
{% endraw %}

Let's define our SparsifyCallback. Let's say we want to make our model 30% sparse by removing the lowest-magnitude weights in each Conv1D layer (the large_final criterion keeps the weights with the largest final value).

{% raw %}
from transformers.modeling_utils import Conv1D # moved to transformers.pytorch_utils in newer versions

sp_cb = SparsifyCallback(end_sparsity=30, granularity='weight', method='local', criteria=large_final, sched_func=sched_onecycle, layer_type=Conv1D)
{% endraw %}
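
Here, `sched_onecycle` makes the sparsity ramp up smoothly and settle at the final value. If you want to inspect a pruning schedule before training, you can plot it; this sketch assumes fasterai's schedule functions follow fastai's `(start, end, pos)` annealer signature:

{% raw %}
import numpy as np
import matplotlib.pyplot as plt

pos = np.linspace(0, 1, 100)                          # fraction of the pruning phase completed
sparsities = [sched_onecycle(0, 30, p) for p in pos]  # sparsity targeted at each point, from 0% to 30%

plt.plot(pos, sparsities)
plt.xlabel('training progress')
plt.ylabel('sparsity (%)')
plt.show()
{% endraw %}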

We now only have to pass our callback to fastai:

{% raw %}
learn.fit_one_cycle(1, 1e-4, cbs=sp_cb)
Pruning of weight until a sparsity of [30]%
Saving Weights at epoch 0
| epoch | train_loss | valid_loss | perplexity | time |
|-------|------------|------------|------------|------|
| 0 | 3.004998 | 2.860594 | 17.471899 | 12:16 |
Sparsity at the end of epoch 0: [30.0]%
Final Sparsity: [30.0]%
Sparsity in Conv1D 9: 30.00%
Sparsity in Conv1D 10: 30.00%
Sparsity in Conv1D 15: 30.00%
Sparsity in Conv1D 16: 30.00%
Sparsity in Conv1D 21: 30.00%
Sparsity in Conv1D 22: 30.00%
Sparsity in Conv1D 27: 30.00%
Sparsity in Conv1D 28: 30.00%
Sparsity in Conv1D 33: 30.00%
Sparsity in Conv1D 34: 30.00%
Sparsity in Conv1D 39: 30.00%
Sparsity in Conv1D 40: 30.00%
Sparsity in Conv1D 45: 30.00%
Sparsity in Conv1D 46: 30.00%
Sparsity in Conv1D 51: 30.00%
Sparsity in Conv1D 52: 30.00%
Sparsity in Conv1D 57: 30.00%
Sparsity in Conv1D 58: 30.00%
Sparsity in Conv1D 63: 30.00%
Sparsity in Conv1D 64: 30.00%
Sparsity in Conv1D 69: 30.00%
Sparsity in Conv1D 70: 30.00%
Sparsity in Conv1D 75: 30.00%
Sparsity in Conv1D 76: 30.00%
Sparsity in Conv1D 81: 30.00%
Sparsity in Conv1D 82: 30.00%
Sparsity in Conv1D 87: 30.00%
Sparsity in Conv1D 88: 30.00%
Sparsity in Conv1D 93: 30.00%
Sparsity in Conv1D 94: 30.00%
Sparsity in Conv1D 99: 30.00%
Sparsity in Conv1D 100: 30.00%
Sparsity in Conv1D 105: 30.00%
Sparsity in Conv1D 106: 30.00%
Sparsity in Conv1D 111: 30.00%
Sparsity in Conv1D 112: 30.00%
Sparsity in Conv1D 117: 30.00%
Sparsity in Conv1D 118: 30.00%
Sparsity in Conv1D 123: 30.00%
Sparsity in Conv1D 124: 30.00%
Sparsity in Conv1D 129: 30.00%
Sparsity in Conv1D 130: 30.00%
Sparsity in Conv1D 135: 30.00%
Sparsity in Conv1D 136: 30.00%
Sparsity in Conv1D 141: 30.00%
Sparsity in Conv1D 142: 30.00%
Sparsity in Conv1D 147: 30.00%
Sparsity in Conv1D 148: 30.00%
{% endraw %}

And we can check the prediction for the same prompt as before.

{% raw %}
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn @-@ like head. The unicorn is a member of the <unk> <unk>'
{% endraw %} {% raw %}
print_sparsity(learn.model)
Sparsity in Conv1D 9: 30.00%
Sparsity in Conv1D 10: 30.00%
Sparsity in Conv1D 15: 30.00%
Sparsity in Conv1D 16: 30.00%
Sparsity in Conv1D 21: 30.00%
Sparsity in Conv1D 22: 30.00%
Sparsity in Conv1D 27: 30.00%
Sparsity in Conv1D 28: 30.00%
Sparsity in Conv1D 33: 30.00%
Sparsity in Conv1D 34: 30.00%
Sparsity in Conv1D 39: 30.00%
Sparsity in Conv1D 40: 30.00%
Sparsity in Conv1D 45: 30.00%
Sparsity in Conv1D 46: 30.00%
Sparsity in Conv1D 51: 30.00%
Sparsity in Conv1D 52: 30.00%
Sparsity in Conv1D 57: 30.00%
Sparsity in Conv1D 58: 30.00%
Sparsity in Conv1D 63: 30.00%
Sparsity in Conv1D 64: 30.00%
Sparsity in Conv1D 69: 30.00%
Sparsity in Conv1D 70: 30.00%
Sparsity in Conv1D 75: 30.00%
Sparsity in Conv1D 76: 30.00%
Sparsity in Conv1D 81: 30.00%
Sparsity in Conv1D 82: 30.00%
Sparsity in Conv1D 87: 30.00%
Sparsity in Conv1D 88: 30.00%
Sparsity in Conv1D 93: 30.00%
Sparsity in Conv1D 94: 30.00%
Sparsity in Conv1D 99: 30.00%
Sparsity in Conv1D 100: 30.00%
Sparsity in Conv1D 105: 30.00%
Sparsity in Conv1D 106: 30.00%
Sparsity in Conv1D 111: 30.00%
Sparsity in Conv1D 112: 30.00%
Sparsity in Conv1D 117: 30.00%
Sparsity in Conv1D 118: 30.00%
Sparsity in Conv1D 123: 30.00%
Sparsity in Conv1D 124: 30.00%
Sparsity in Conv1D 129: 30.00%
Sparsity in Conv1D 130: 30.00%
Sparsity in Conv1D 135: 30.00%
Sparsity in Conv1D 136: 30.00%
Sparsity in Conv1D 141: 30.00%
Sparsity in Conv1D 142: 30.00%
Sparsity in Conv1D 147: 30.00%
Sparsity in Conv1D 148: 30.00%
{% endraw %}

That's it! You now have a sparse Transformer that performs nearly as well as the dense model. However, this model is currently no more efficient in terms of speed or storage. To obtain such speed-ups, take a look at the granularity section.
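
For instance, a structured variant of the callback above would prune whole rows of the Conv1D weight matrices rather than individual weights, which is the kind of sparsity that can actually be turned into speed and storage gains. A hypothetical sketch (the 'row' granularity name is an assumption; check the granularity section for the values fasterai actually supports):

{% raw %}
# Hypothetical: 'row' is an assumed granularity name, see the granularity section
sp_cb = SparsifyCallback(end_sparsity=30, granularity='row', method='local',
                         criteria=large_final, sched_func=sched_onecycle, layer_type=Conv1D)
learn.fit_one_cycle(1, 1e-4, cbs=sp_cb)
{% endraw %}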