Transformer model distillation
Overview
Transformer models that were pre-trained on large corpora, such as BERT/XLNet/XLM, have been shown to improve the accuracy of many NLP tasks. However, such models have two distinct disadvantages: (1) model size and (2) speed, since these large models are computationally heavy.
One possible approach to overcoming these drawbacks is Knowledge Distillation (KD). In this approach, a large model is trained on the data set and then used to teach a much smaller and more efficient network. This is often referred to as Student-Teacher training, where a teacher network adds its error to the student's loss function, thus helping the student network converge to a better solution.
Knowledge Distillation
One approach is similar to the method of Hinton et al. 2015 [1]. The loss function is modified to include a measure of the divergence between the student and teacher distributions, which can be computed using KL divergence or MSE between the logits of the student and the teacher networks.
\(loss = w_s \cdot loss_{student} + w_d \cdot KL(logits_{student} / T || logits_{teacher} / T)\)
where \(T\) is the temperature used to soften the logits prior to applying softmax, \(loss_{student}\) is the original loss of the student network obtained during regular training, and \(w_s\) and \(w_d\) weight the student and distillation loss terms.
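To make the formula concrete, the following is a minimal PyTorch sketch of the weighted distillation loss; the function and argument names are illustrative and not part of the nlp_architect API.

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, student_loss,
            temperature=1.0, loss_w=0.5, kd_w=0.5):
    # KL divergence between the temperature-softened student and teacher
    # distributions; kl_div expects log-probabilities as its first argument
    divergence = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    )
    # weighted sum of the original student loss and the distillation term
    return loss_w * student_loss + kd_w * divergence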
TeacherStudentDistill
This class adds distillation support to a model. To support distillation, the student model must handle training with the TeacherStudentDistill class; see nlp_architect.procedures.token_tagging.do_kd_training for an example of how to train a neural tagger with a transformer model using distillation.
class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, kd_w: float = 0.5, loss_w: float = 0.5)
Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.
Parameters:
- teacher_model (TrainableModel) – teacher model
- temperature (float, optional) – KD temperature. Defaults to 1.0.
- kd_w (float, optional) – teacher loss weight. Defaults to 0.5.
- loss_w (float, optional) – student loss weight. Defaults to 0.5.
static add_args(parser: argparse.ArgumentParser)
Add KD arguments to the parser.
Parameters: parser (argparse.ArgumentParser) – parser
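A minimal usage sketch of the documented API: add_args exposes the KD hyper-parameters on a command-line parser, and the constructor wraps an already-trained teacher. The make_distiller helper and its teacher argument are illustrative, not part of nlp_architect.

import argparse

from nlp_architect.nn.torch.distillation import TeacherStudentDistill

# add the KD command-line arguments to an existing parser
parser = argparse.ArgumentParser()
TeacherStudentDistill.add_args(parser)

def make_distiller(teacher):
    # 'teacher' is assumed to be an already-trained TrainableModel,
    # e.g. a transformer token classifier
    return TeacherStudentDistill(
        teacher_model=teacher,  # trained teacher model
        temperature=2.0,        # T in the loss formula above
        kd_w=0.5,               # weight of the distillation (teacher) loss
        loss_w=0.5,             # weight of the original student loss
    )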
Supported models
NeuralTagger
Useful for training taggers from transformer models. NeuralTagger models that use LSTM- and CNN-based embedders are ~3M parameters in size (~30-100x smaller than BERT models) and ~10x faster on average.
Usage:
- Train a transformer tagger using TransformerTokenClassifier or the nlp_architect train transformer_token command.
- Train a neural tagger (NeuralTagger) using the trained transformer model and a TeacherStudentDistill object configured with that transformer model. This can be done using NeuralTagger's train loop or the nlp_architect train tagger_kd command; see the sketch after this list.
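A schematic sketch of the second step, assuming the transformer teacher has already been trained. The import path for NeuralTagger, its constructor arguments, and the keyword used to pass the distiller into its train loop are assumptions here; consult nlp_architect.procedures.token_tagging.do_kd_training (used by the tagger_kd command) for the exact calls.

from nlp_architect.models.tagging import NeuralTagger
from nlp_architect.nn.torch.distillation import TeacherStudentDistill

def train_student_with_kd(teacher, embedder, train_set, dev_set):
    # 'teacher' is the trained transformer tagger, 'embedder' is the LSTM/CNN
    # word embedder for the student; both are supplied by the caller
    distiller = TeacherStudentDistill(teacher_model=teacher, temperature=2.0)
    student = NeuralTagger(embedder)  # constructor arguments are illustrative
    # passing the distiller to the student's train loop adds the teacher's
    # soft targets to the student loss (keyword name is an assumption)
    student.train(train_set, dev_set, distiller=distiller)
    return student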
Note
More models supporting distillation will be added in future releases.
[1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean. Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531