---
title: Prune Transformers
keywords: fastai
sidebar: home_sidebar
summary: "Prune transformers architecture with fasterai"
description: "Prune transformers architecture with fasterai"
nb_path: "nbs/04c_tutorial.transformers.ipynb"
---
{% include note.html content='This example code is taken from the fastai docs' %}
```python
from fastai.text.all import *
from fasterai.sparse.all import *
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)
path = untar_data(URLs.WIKITEXT_TINY)
```
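The `DataLoaders` (`dls`) and the `DropOutput` callback used below are not shown in this section; they come from the fastai tutorial this example is based on. Here is a minimal sketch along those lines, where the batch size, sequence length, and the `TransformersTokenizer` transform are assumptions to adapt to your setup:

```python
import pandas as pd

# Read the WikiText-2 csv files downloaded by untar_data (assumed layout: train.csv / test.csv)
df_train = pd.read_csv(path/'train.csv', header=None)
df_valid = pd.read_csv(path/'test.csv', header=None)
all_texts = np.concatenate([df_train[0].values, df_valid[0].values])

class TransformersTokenizer(Transform):
    "Tokenize a text with the GPT-2 tokenizer and return a tensor of token ids"
    def __init__(self, tokenizer): self.tokenizer = tokenizer
    def encodes(self, x):
        toks = self.tokenizer.tokenize(x)
        return tensor(self.tokenizer.convert_tokens_to_ids(toks))
    def decodes(self, x): return TitledStr(self.tokenizer.decode(x.cpu().numpy()))

# Train/valid split over the concatenated texts, then language-model dataloaders
splits = [range_of(df_train), list(range(len(df_train), len(all_texts)))]
tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)
dls = tls.dataloaders(bs=8, seq_len=256)

class DropOutput(Callback):
    "Keep only the logits (first element) from the output returned by the HuggingFace model"
    def after_pred(self): self.learn.pred = self.pred[0]
```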
Let's create our fastai `Learner`.
```python
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())
```
And let's try to extend a given prompt with the pretrained model.
```python
prompt = "\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn"
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]
preds = learn.model.generate(inp, max_length=40, num_beams=5, temperature=1.5)
tokenizer.decode(preds[0].cpu().numpy())
```
```python
learn.validate()
learn.fit_one_cycle(1, 1e-4)
```
```python
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]
preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)
tokenizer.decode(preds[0].cpu().numpy())
```
Now let's see what happens if we retrain our model, this time introducing sparsity.
```python
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())
```
Also, when working with text, fastai defines the number of processed batches differently, so we have to adjust our `SparsifyCallback` accordingly (luckily, fastai makes it available as the `n_batches` attribute).
```python
@patch_to(SparsifyCallback)
def before_fit(self):
    print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
    self.end_epoch = self.n_epoch if self.end_epoch is None else self.end_epoch
    assert self.end_epoch <= self.n_epoch, 'Your end_epoch must be smaller than total number of epoch'

    model = self.learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
    self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
    self.total_iters = self.end_epoch * self.dls.n_batches
    self.start_iter = self.start_epoch * self.dls.n_batches
```
Let's define our `SparsifyCallback`. Let's say we want to make our model 30% sparse, by removing the lowest-norm weights in each attention head (the `large_final` criteria keeps the weights with the largest values).
```python
from transformers.pytorch_utils import Conv1D  # on older transformers versions: from transformers.modeling_utils import Conv1D

sp_cb = SparsifyCallback(end_sparsity=30, granularity='weight', method='local', criteria=large_final, sched_func=sched_onecycle, layer_type=Conv1D)
```
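GPT-2 implements its attention and feed-forward projections with HuggingFace's `Conv1D` module (a linear layer with transposed weights) rather than `nn.Linear`, which is why we pass `layer_type=Conv1D`. As a quick sanity check, not part of the original tutorial, you can list the modules the callback will target:

```python
# List the Conv1D modules that the SparsifyCallback will prune.
conv1d_layers = [n for n, m in model.named_modules() if isinstance(m, Conv1D)]
print(f'{len(conv1d_layers)} Conv1D layers, e.g. {conv1d_layers[:2]}')
```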
We now only have to pass our callback to fastai:
```python
learn.fit_one_cycle(1, 1e-4, cbs=sp_cb)
```
And we can check the prediction for the same prompt as before.
```python
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]
preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)
tokenizer.decode(preds[0].cpu().numpy())
```
```python
print_sparsity(learn.model)
```
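To double-check the reported numbers, here is a small manual verification, not part of the tutorial, that computes the fraction of zeroed weights in each pruned layer:

```python
# Fraction of exactly-zero weights per Conv1D layer; with method='local',
# each layer should be close to the 30% target set in the callback.
for name, module in learn.model.named_modules():
    if isinstance(module, Conv1D):
        w = module.weight.data
        print(f'{name}: {100 * (w == 0).float().mean().item():.2f}% zeros')
```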
That's it! You now have a sparse Transformer that performs on par with the dense model. However, this model is currently not any more efficient in terms of speed or storage. To get such a speed-up, I suggest you take a look at the granularity section.