---
title: Knowledge Distillation
keywords: fastai
sidebar: home_sidebar
summary: "Train a network in a teacher-student fashion"
description: "Train a network in a teacher-student fashion"
nb_path: "nbs/05_knowledge_distillation.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
 
{% endraw %}

Knowledge Distillation, sometimes called teacher-student training, is a compression method in which a small model (the student) is trained to mimic the behaviour of a larger model (the teacher).

The main goal is to reveal what is called the Dark Knowledge hidden in the teacher model.

Let's walk through the same example provided by Geoffrey Hinton et al. to see where this dark knowledge comes from.

The main problem with classification is that the output activation function (softmax) will, by design, push one value very high and squash all the others:

$$ p_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)} $$

where $p_i$ is the probability of class $i$, computed from the logits $z$.

Here is an example to illustrate this phenomenon:

Let's say that we have trained a model to discriminate between the following 5 classes: [cow, dog, plane, cat, car]

And here is the output of the final layer (the logits) when the model is fed a new input image:

{% raw %}
import torch
logits = torch.tensor([1.3, 3.1, 0.2, 1.9, -0.3])  # raw outputs of the final layer
{% endraw %}

Judging by these logits, the model seems confident that the input image is a dog, fairly sure that it is neither a plane nor a car, and gives moderately high scores to cow and cat.

So the model has not only learned to recognize a dog in the image, but also that a dog is very different from a car or a plane and shares similarities with cats and cows. This information is what is called dark knowledge!

When passing those logits through a softmax, we get:

{% raw %}
import torch.nn.functional as F
predictions = F.softmax(logits, dim=-1); predictions
tensor([0.1063, 0.6431, 0.0354, 0.1937, 0.0215])
{% endraw %}

This accentuates the differences we saw in the logits, discarding some of the dark knowledge acquired during training. The way to keep this knowledge is to "soften" the softmax outputs by adding a temperature parameter $\tau$:

$$ p_{i}^{\tau}=\frac{\exp \left(z_{i} / \tau\right)}{\sum_{j} \exp \left(z_{j} / \tau\right)} $$

The higher the temperature, the softer the predictions.

{% raw %}
soft_predictions = F.softmax(logits/3, dim=-1); soft_predictions
tensor([0.1879, 0.3423, 0.1302, 0.2294, 0.1102])
{% endraw %}

{% include note.html content='If the temperature is equal to 1, we recover the regular softmax.' %}
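We can quickly check this with the tensors defined above: with a temperature of 1, the division leaves the logits untouched, so the result matches the regular predictions computed earlier.

{% raw %}
# Sanity check: a temperature of 1 reproduces the regular softmax
assert torch.allclose(F.softmax(logits/1, dim=-1), predictions)
{% endraw %}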

When applying Knowledge Distillation, we want to keep the Dark Knowledge that the teacher model has acquired during its training but not rely entirely on it. So we combine two losses:

  • The Teacher loss between the softened predictions of the teacher and the softened predictions of the student
  • The Classification loss, which is the regular loss between hard labels and hard predictions

These two losses are combined with a weighting parameter α, as:

$$ L_{KD}=\alpha * \text{CrossEntropy}\left(p_{S}^{\tau}, p_{T}^{\tau}\right)+(1-\alpha) * \text{CrossEntropy}\left(p_{S}, y_{\text{true}}\right) $$

where $p_{S}^{\tau}$ and $p_{T}^{\tau}$ are the temperature-softened predictions of the student and the teacher, respectively.
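As a minimal sketch of this formula (not the exact fasterai implementation; `kd_loss`, `T`, and `alpha` are illustrative names), the two terms could be written with plain PyTorch as:

{% raw %}
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, T=3, alpha=0.5):
    # Distillation term: cross-entropy between the softened student and teacher distributions
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_soft_student = F.log_softmax(student_logits / T, dim=-1)
    distill = -(soft_teacher * log_soft_student).sum(dim=-1).mean()
    # Classification term: regular cross-entropy against the hard labels
    classif = F.cross_entropy(student_logits, targets)
    return alpha * distill + (1 - alpha) * classif
{% endraw %}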

{% include note.html content='In practice, the distillation loss will be a bit different in the implementation' %}


{% raw %}
{% endraw %}

This can be done with fastai, using the Callback system!

{% raw %}

class KnowledgeDistillationCallback[source]

KnowledgeDistillationCallback(teacher, loss, activations_student=None, activations_teacher=None, weight=0.5) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

{% endraw %} {% raw %}
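As an illustrative sketch (assuming `dls` is an existing `DataLoaders`, `teacher_model` is an already-trained network, and `student_model` is the smaller network to train), the callback could be attached to a `Learner` with the `SoftTarget` loss presented further down:

from fastai.vision.all import *

# Sketch only: dls, student_model and teacher_model are assumed to exist already.
# The callback adds the distillation term to the student's training loss.
student_learn = Learner(dls, student_model, metrics=accuracy,
                        cbs=KnowledgeDistillationCallback(teacher_model, SoftTarget))
student_learn.fit_one_cycle(5)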
{% endraw %} {% raw %}

get_model_layers[source]

get_model_layers(model, getLayerRepr=False)

{% endraw %} {% raw %}

get_module_by_name[source]

get_module_by_name(module:Union[Tensor, Module], access_string:str)

{% endraw %} {% raw %}
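These two helpers appear to be used to locate the intermediate layers whose activations are compared. As a hedged illustration (assuming `access_string` is a dotted path to a submodule, which its name suggests), usage might look like:

from torchvision.models import resnet18

model = resnet18()
# Hypothetical example: fetch a submodule by its dotted name
layer = get_module_by_name(model, 'layer1.0.conv1')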
{% endraw %}

The loss function to use depends on the use case. For classification, we usually use the one presented above, named SoftTarget in fasterai. For regression, we may instead want to regress directly on the logits.


# Classic distillation loss: KL divergence between temperature-softened student and teacher predictions, scaled by T*T.
def SoftTarget(pred, teacher_pred, T=5, **kwargs): return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(pred/T, dim=1), F.softmax(teacher_pred/T, dim=1)) * (T*T)

# Regress the student logits directly onto the teacher logits.
def Logits(pred, teacher_pred, **kwargs): return F.mse_loss(pred, teacher_pred)

# KL divergence between the unsoftened predictions of student and teacher.
def Mutual(pred, teacher_pred, **kwargs): return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(pred, dim=1), F.softmax(teacher_pred, dim=1))

# Match the normalized spatial attention maps computed from student and teacher feature maps.
def Attention(pred, teacher_pred, fm_s, fm_t, p=2, **kwargs): return sum([F.mse_loss(F.normalize(fm_s[name_st].pow(p).mean(1), dim=(1,2)), F.normalize(fm_t[name_t].pow(p).mean(1), dim=(1,2))) for name_st, name_t in zip(fm_s, fm_t)])

# Penalize student activations whose sign disagrees with the teacher's, with margin m.
def ActivationBoundaries(pred, teacher_pred, fm_s, fm_t, m=2, **kwargs): return sum([((fm_s[name_st] + m).pow(2) * ((fm_s[name_st] > -m) & (fm_t[name_t] <= 0)).float() + (fm_s[name_st] - m).pow(2) * ((fm_s[name_st] <= m) & (fm_t[name_t] > 0)).float()).mean() for name_st, name_t in zip(fm_s, fm_t)])

# Match intermediate feature maps of student and teacher directly.
def FitNet(pred, teacher_pred, fm_s, fm_t, **kwargs): return sum([F.mse_loss(fm_s[name_st], fm_t[name_t]) for name_st, name_t in zip(fm_s, fm_t)])

# Match the pairwise (batch-wise) similarity matrices computed from student and teacher feature maps.
def Similarity(fm_s, fm_t, p=2, **kwargs): return sum([F.mse_loss(F.normalize(fm_s[name_st].view(fm_s[name_st].size(0), -1) @ fm_s[name_st].view(fm_s[name_st].size(0), -1).t(), p=p, dim=1), F.normalize(fm_t[name_t].view(fm_t[name_t].size(0), -1) @ fm_t[name_t].view(fm_t[name_t].size(0), -1).t(), p=p, dim=1)) for name_st, name_t in zip(fm_s, fm_t)])
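As a quick illustration (the student logits below are made up for the example), SoftTarget can be called on the toy logits from earlier, with a batch dimension added since the losses expect class scores along dim=1:

{% raw %}
teacher_logits = logits.unsqueeze(0)                        # shape (1, 5): the teacher's output
student_logits = torch.tensor([[1.0, 2.5, 0.5, 1.5, 0.0]])  # hypothetical student output
SoftTarget(student_logits, teacher_logits, T=3)             # returns a scalar distillation loss
{% endraw %}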

{% raw %}

SoftTarget[source]

SoftTarget(pred, teacher_pred, T=5, **kwargs)

{% endraw %} {% raw %}

Logits[source]

Logits(pred, teacher_pred, **kwargs)

{% endraw %} {% raw %}

Mutual[source]

Mutual(pred, teacher_pred, **kwargs)

{% endraw %} {% raw %}

Attention[source]

Attention(fm_s, fm_t, p=2, **kwargs)

{% endraw %} {% raw %}

ActivationBoundaries[source]

ActivationBoundaries(fm_s, fm_t, m=2, **kwargs)

{% endraw %} {% raw %}

FitNet[source]

FitNet(fm_s, fm_t, **kwargs)

{% endraw %} {% raw %}

Similarity[source]

Similarity(fm_s, fm_t, pred, p=2, **kwargs)

{% endraw %} {% raw %}
{% endraw %}