---
title: Knowledge Distillation
keywords: fastai
sidebar: home_sidebar
summary: "How to apply knowledge distillation with fasterai"
description: "How to apply knowledge distillation with fasterai"
nb_path: "nbs/04b_tutorial.knowledge_distillation.ipynb"
---

We'll illustrate how to use Knowledge Distillation to transfer the knowledge of a Resnet34 (the teacher) into a Resnet18 (the student).

Let's grab some data.

{% raw %}
from fastai.vision.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Label is True for cats (uppercase first letter in the filename) and False for dogs
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
{% endraw %}

The first step is to train the teacher model. We'll start from a pretrained model to make sure we get good results on our dataset.

{% raw %}
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.738216 | 0.514556 | 0.847091 | 00:12 |
| 1 | 0.491984 | 0.307432 | 0.863329 | 00:11 |
| 2 | 0.447923 | 1.710455 | 0.701624 | 00:11 |
| 3 | 0.452576 | 0.340987 | 0.844384 | 00:11 |
| 4 | 0.287033 | 0.239666 | 0.893099 | 00:11 |
| 5 | 0.238854 | 0.230530 | 0.909337 | 00:11 |
| 6 | 0.144284 | 0.175067 | 0.927605 | 00:11 |
| 7 | 0.098495 | 0.161163 | 0.930988 | 00:11 |
| 8 | 0.050567 | 0.149135 | 0.942490 | 00:11 |
| 9 | 0.032685 | 0.151816 | 0.941137 | 00:11 |
{% endraw %}

Without KD

We'll now train a Resnet18 from scratch, without any help from the teacher model, to serve as a baseline.

{% raw %}
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.592990 | 0.880996 | 0.677943 | 00:11 |
| 1 | 0.568779 | 0.633168 | 0.636671 | 00:11 |
| 2 | 0.533047 | 0.738055 | 0.539242 | 00:10 |
| 3 | 0.486673 | 0.749287 | 0.717185 | 00:11 |
| 4 | 0.438561 | 0.451227 | 0.784844 | 00:16 |
| 5 | 0.395264 | 0.384985 | 0.828146 | 00:16 |
| 6 | 0.339680 | 0.372352 | 0.845061 | 00:11 |
| 7 | 0.264073 | 0.353506 | 0.842355 | 00:11 |
| 8 | 0.194599 | 0.351618 | 0.857239 | 00:11 |
| 9 | 0.158134 | 0.342614 | 0.853857 | 00:11 |
{% endraw %}

With KD

And now we train the same model, but with the help of the teacher. The chosen loss is a combination of the regular classification loss (cross-entropy) and a distillation loss that pushes the student to mimic the teacher's predictions.
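
To give an intuition of what such a combined objective computes, here is a minimal, self-contained sketch of a soft-target distillation loss in the spirit of Hinton et al. The `temperature` and `alpha` names are illustrative only, not fasterai's actual parameters.

{% raw %}
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    # Regular classification loss against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, labels)
    # KL divergence between the temperature-softened teacher and student distributions;
    # the T**2 factor keeps gradient magnitudes comparable across temperatures
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * temperature ** 2
    # Weighted sum: alpha balances the distillation term against the hard-label term
    return alpha * soft_loss + (1 - alpha) * hard_loss
{% endraw %}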

{% raw %}
from fasterai.distill.all import *  # fasterai's distillation callback and losses (import assumed; adjust to your fasterai version)

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
# Train the student with an extra SoftTarget loss against the teacher's predictions
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget)
student.fit_one_cycle(10, 1e-3, cbs=kd)

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 2.296719 | 1.884460 | 0.711773 | 00:12 |
| 1 | 2.191021 | 1.805818 | 0.729364 | 00:12 |
| 2 | 1.973543 | 1.675716 | 0.752368 | 00:12 |
| 3 | 1.770022 | 1.516037 | 0.769283 | 00:12 |
| 4 | 1.556814 | 1.472682 | 0.790257 | 00:12 |
| 5 | 1.290988 | 1.036582 | 0.841001 | 00:12 |
| 6 | 1.036041 | 0.882328 | 0.857239 | 00:12 |
| 7 | 0.809941 | 0.858502 | 0.852503 | 00:13 |
| 8 | 0.608483 | 0.804021 | 0.863329 | 00:12 |
| 9 | 0.481355 | 0.793109 | 0.864005 | 00:12 |
{% endraw %}

When helped by the teacher, the student model performs better, reaching 86.4% accuracy versus 85.4% for the baseline!

There exist more sophisticated KD losses, such as the one from Paying More Attention to Attention, where the student tries to replicate the teacher's attention maps at intermediate layers.
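
Concretely, an activation-based attention map can be obtained by summing the squared feature map over its channels, and the student is penalized for the distance between its normalized map and the teacher's. Below is a rough sketch of that idea, assuming matching spatial sizes; the helper names are illustrative, not fasterai's implementation.

{% raw %}
import torch.nn.functional as F

def attention_map(feat):
    # Sum of squared activations over the channel dimension: (B, C, H, W) -> (B, H, W)
    return feat.pow(2).sum(dim=1)

def attention_loss(student_feat, teacher_feat):
    # Flatten and L2-normalize each map, then measure how far apart they are
    s = F.normalize(attention_map(student_feat).flatten(1), dim=1)
    t = F.normalize(attention_map(teacher_feat).flatten(1), dim=1)
    return (s - t).pow(2).mean()
{% endraw %}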

Using such a loss requires specifying the layers whose attention maps should be replicated. We refer to those layers by their string names, which can be obtained with the get_model_layers function.
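
These are the same names that PyTorch reports through named_modules, so a quick way to list them for the student is:

{% raw %}
# Inspect the string names of the student's sub-modules; for a torchvision resnet18
# the residual stages appear as 'layer1' to 'layer4'
for name, _ in student.model.named_modules():
    print(name)
{% endraw %}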

For example, here we apply the loss after each residual stage of our models:

{% raw %}
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
# Apply the Attention loss between the listed intermediate layers of the student and the teacher
kd = KnowledgeDistillationCallback(teacher.model, Attention, ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)

| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.091224 | 0.094667 | 0.686739 | 00:13 |
| 1 | 0.083294 | 0.079744 | 0.701624 | 00:13 |
| 2 | 0.069417 | 0.064017 | 0.788904 | 00:12 |
| 3 | 0.061072 | 0.062791 | 0.787551 | 00:12 |
| 4 | 0.054276 | 0.053314 | 0.843708 | 00:13 |
| 5 | 0.048175 | 0.050120 | 0.826793 | 00:13 |
| 6 | 0.040943 | 0.051322 | 0.849120 | 00:13 |
| 7 | 0.031932 | 0.048541 | 0.855210 | 00:13 |
| 8 | 0.024598 | 0.043699 | 0.881597 | 00:13 |
| 9 | 0.021455 | 0.043395 | 0.879567 | 00:12 |
{% endraw %}