---
title: modeling.token_classification
keywords: fastai
sidebar: home_sidebar
summary: "This module contains custom models, loss functions, custom splitters, etc... for token classification tasks like named entity recognition."
description: "This module contains custom models, loss functions, custom splitters, etc... for token classification tasks like named entity recognition."
nb_path: "nbs/02a_modeling-token-classification.ipynb"
---
{% raw %}
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %}

Token classification

The objective of token classification is to predict the correct label for each token provided in the input. In the computer vision world, this is akin to segmentation tasks, where we attempt to predict the class/label for each pixel in an image. Named entity recognition (NER) is an example of token classification in the NLP space.
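To make "one label per token" concrete, here is a minimal sketch in plain Python that groups BIO tags into entity spans. The helper `bio_to_spans` is hypothetical (it is not part of blurr), and the tokens are the first few from one row of the GermEval sample shown further down.

```python
def bio_to_spans(tokens, labels):
    """Group BIO-tagged tokens into (entity_type, text) spans."""
    spans, current_type, current_toks = [], None, []
    for tok, lbl in zip(tokens, labels):
        if lbl.startswith("B-"):
            # a new entity begins; close any entity that is still open
            if current_toks:
                spans.append((current_type, " ".join(current_toks)))
            current_type, current_toks = lbl[2:], [tok]
        elif lbl.startswith("I-") and current_type == lbl[2:]:
            # continuation of the open entity
            current_toks.append(tok)
        else:
            # "O" (or an I- tag that does not continue the open entity) closes it
            if current_toks:
                spans.append((current_type, " ".join(current_toks)))
            current_type, current_toks = None, []
    if current_toks:
        spans.append((current_type, " ".join(current_toks)))
    return spans

tokens = ["Firmengründer", "Wolf", "Peter", "Bree", "arbeitete"]
labels = ["O", "B-PER", "I-PER", "I-PER", "O"]
print(bio_to_spans(tokens, labels))  # [('PER', 'Wolf Peter Bree')]
```

Every token carries a label, but only the `B-`/`I-` runs form entities; everything tagged `O` is background.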

{% raw %}
df_converters = {'tokens': ast.literal_eval, 'labels': ast.literal_eval, 'nested-labels': ast.literal_eval}

# full nlp dataset
# germ_eval_df = pd.read_csv('./data/task-token-classification/germeval2014ner_cleaned.csv', converters=df_converters)

# demo nlp dataset
germ_eval_df = pd.read_csv('./germeval2014_sample.csv', converters=df_converters)

print(len(germ_eval_df))
germ_eval_df.head()
1000
id source tokens labels nested-labels ds_type
0 0 n-tv.de vom 26.02.2005 [2005-02-26] [Schartau, sagte, dem, ", Tagesspiegel, ", vom, Freitag, ,, Fischer, sei, ", in, einer, Weise, aufgetreten, ,, die, alles, andere, als, überzeugend, war, ", .] [B-PER, O, O, O, B-ORG, O, O, O, O, B-PER, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] train
1 1 welt.de vom 29.10.2005 [2005-10-29] [Firmengründer, Wolf, Peter, Bree, arbeitete, Anfang, der, siebziger, Jahre, als, Möbelvertreter, ,, als, er, einen, fliegenden, Händler, aus, dem, Libanon, traf, .] [O, B-PER, I-PER, I-PER, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, B-LOC, O, O] [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] train
2 2 http://www.stern.de/sport/fussball/krawalle-in-der-fussball-bundesliga-dfb-setzt-auf-falsche-konzepte-1553657.html#utm_source=standard&utm_medium=rss-feed&utm_campaign=sport [2010-03-25] [Ob, sie, dabei, nach, dem, Runden, Tisch, am, 23., April, in, Berlin, durch, ein, pädagogisches, Konzept, unterstützt, wird, ,, ist, allerdings, zu, bezweifeln, .] [O, O, O, O, O, O, O, O, O, O, O, B-LOC, O, O, O, O, O, O, O, O, O, O, O, O] [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] train
3 3 stern.de vom 21.03.2006 [2006-03-21] [Bayern, München, ist, wieder, alleiniger, Top-, Favorit, auf, den, Gewinn, der, deutschen, Fußball-Meisterschaft, .] [B-ORG, I-ORG, O, O, O, O, O, O, O, O, O, B-LOCderiv, O, O] [B-LOC, B-LOC, O, O, O, O, O, O, O, O, O, O, O, O] train
4 4 http://www.fr-online.de/in_und_ausland/sport/aktuell/1618625_Frings-schaut-finster-in-die-Zukunft.html [2008-10-24] [Dabei, hätte, der, tapfere, Schlussmann, allen, Grund, gehabt, ,, sich, viel, früher, aufzuregen, .] [O, O, O, O, O, O, O, O, O, O, O, O, O, O] [O, O, O, O, O, O, O, O, O, O, O, O, O, O] train
{% endraw %}

We are only going to be working with a small sample from the GermEval 2014 data set ... so the results might not be all that great :).

{% raw %}
labels = sorted({lbl for sublist in germ_eval_df.labels for lbl in sublist})
print(labels)
['B-LOC', 'B-LOCderiv', 'B-LOCpart', 'B-ORG', 'B-ORGpart', 'B-OTH', 'B-OTHderiv', 'B-OTHpart', 'B-PER', 'B-PERderiv', 'B-PERpart', 'I-LOC', 'I-LOCderiv', 'I-ORG', 'I-ORGpart', 'I-OTH', 'I-PER', 'O']
{% endraw %} {% raw %}
task = HF_TASKS_AUTO.TokenClassification
pretrained_model_name = "bert-base-multilingual-cased"
config = AutoConfig.from_pretrained(pretrained_model_name)

config.num_labels = len(labels)
{% endraw %}

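Notice above how I set the config.num_labels attribute to the number of labels we want our model to be able to predict. The model builds its last (classification) layer with that many outputs. As the loading message below shows, that layer is newly initialized while the rest of the network keeps its pretrained weights, which is essentially transfer learning.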
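As a rough illustration of why the label count matters: the classification layer emits one score per label for every token, and the index of the highest score is looked up in the label list. The sketch below uses plain Python with a shortened label list and made-up scores (both are assumptions for illustration, not model output).

```python
# shortened label list for illustration; the notebook uses all 18 GermEval labels
labels = ['B-LOC', 'B-ORG', 'B-PER', 'I-LOC', 'I-ORG', 'I-PER', 'O']
num_labels = len(labels)  # this is the value assigned to config.num_labels

# hypothetical scores for a single token: one value per label
scores = [0.1, 0.2, 3.5, 0.0, 0.1, 0.4, 1.2]
assert len(scores) == num_labels

# the predicted label is the one with the highest score
pred_id = max(range(num_labels), key=scores.__getitem__)
print(pred_id, labels[pred_id])  # 2 B-PER
```

If `num_labels` did not match the label list, predicted ids could not be mapped back to label names.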

{% raw %}
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               task=task, 
                                                                               config=config)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
('bert',
 transformers.configuration_bert.BertConfig,
 transformers.tokenization_bert.BertTokenizer,
 transformers.modeling_bert.BertForTokenClassification)
{% endraw %} {% raw %}
test_eq(hf_config.num_labels, len(labels))
{% endraw %} {% raw %}
hf_batch_tfm = HF_TokenClassBatchTransform(hf_arch, hf_tokenizer)

blocks = (
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, max_length=128, is_pretokenized=True,
                 tok_kwargs={ 'return_special_tokens_mask': True }), 
    HF_TokenCategoryBlock(vocab=labels)
)

def get_y(inp):
    return [ (label, len(hf_tokenizer.tokenize(str(entity)))) for entity, label in zip(inp.tokens, inp.labels) ]

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('tokens'),
                   get_y=get_y,
                   splitter=RandomSplitter())
{% endraw %}

We have to define a get_y that creates the same number of labels as there are subtokens for a particular token. For example, my name "Wayde" gets split up into two subtokens, "Way" and "##de". The label for "Wayde" is "B-PER" and we just repeat it for the subtokens. This all gets cleaned up when we show results and get predictions.
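The pairing of labels with subtoken counts can be sketched without downloading a tokenizer. In the snippet below, `toy_tokenize` is a hypothetical stand-in for `hf_tokenizer.tokenize` (a real WordPiece vocabulary decides the splits), and `expand` is an assumption about how the `(label, count)` pairs are consumed, based only on the description above.

```python
# Hypothetical stand-in for hf_tokenizer.tokenize; only "Wayde" is split here.
TOY_SPLITS = {"Wayde": ["Way", "##de"]}

def toy_tokenize(word):
    return TOY_SPLITS.get(word, [word])

def label_counts(tokens, labels):
    """Same shape as get_y above: pair each label with its token's subtoken count."""
    return [(lbl, len(toy_tokenize(tok))) for tok, lbl in zip(tokens, labels)]

def expand(pairs):
    """Repeat each label once per subtoken."""
    return [lbl for lbl, n in pairs for _ in range(n)]

tokens = ["My", "name", "is", "Wayde"]
labels = ["O", "O", "O", "B-PER"]
pairs = label_counts(tokens, labels)
print(pairs)          # [('O', 1), ('O', 1), ('O', 1), ('B-PER', 2)]
print(expand(pairs))  # ['O', 'O', 'O', 'B-PER', 'B-PER']
```

Four input tokens become five subtoken labels, so the targets line up with the tokenizer's output.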

{% raw %}
dls = dblock.dataloaders(germ_eval_df, bs=2)
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2)
token / target label
0 [('In', 'O'), ('den', 'O'), ('Jahren', 'O'), ('1991', 'O'), ('bis', 'O'), ('1996', 'O'), ('bekleidete', 'O'), ('er', 'O'), ('die', 'O'), ('Funktion', 'O'), ('des', 'O'), ('Direktors', 'O'), ('des', 'O'), ('International', 'B-ORG'), ('Film', 'I-ORG'), ('Festival', 'I-ORG'), ('Rotterdam', 'I-ORG'), ('und', 'O'), ('des', 'O'), ('Hubert', 'B-ORG'), ('Bals', 'I-ORG'), ('Fonds', 'I-ORG'), (',', 'O'), ('mit', 'O'), ('dem', 'O'), ('Filmproduktionen', 'O'), ('in', 'O'), ('Entwicklungsländern', 'O'), ('gefördert', 'O'), ('werden', 'O'), ('.', 'O')]
1 [('Die', 'O'), ('nächste', 'O'), ('AchemAsia', 'B-ORG'), ('eröffnet', 'O'), ('im', 'O'), ('Frühjahr', 'O'), ('2013', 'O'), ('in', 'O'), ('Beijing', 'B-LOC'), ('.', 'O')]
{% endraw %}

Metrics

In this section, we'll add helpful metrics for token classification tasks.
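For intuition, entity-level precision, recall, and F1 can be sketched in plain Python. This is a simplified illustration, not the seqeval implementation used by the callback documented below (seqeval handles more tagging schemes and edge cases), and the `"IGN"` placeholder is a hypothetical ignore label.

```python
def entities(labels):
    """Return the set of (type, start, end) spans in a BIO label sequence."""
    spans, start, etype = set(), None, None
    for i, lbl in enumerate(labels + ["O"]):  # sentinel closes a trailing entity
        continues = lbl.startswith("I-") and etype == lbl[2:]
        if etype is not None and not continues:
            spans.add((etype, start, i))
            start, etype = None, None
        if lbl.startswith("B-"):
            start, etype = i, lbl[2:]
    return spans

def prf(pred, targ, ignore="IGN"):
    """Entity-level precision, recall, F1, skipping positions whose target is `ignore`."""
    kept = [(p, t) for p, t in zip(pred, targ) if t != ignore]
    p_ents = entities([p for p, _ in kept])
    t_ents = entities([t for _, t in kept])
    tp = len(p_ents & t_ents)  # an entity counts only if type and boundaries both match
    precision = tp / len(p_ents) if p_ents else 0.0
    recall = tp / len(t_ents) if t_ents else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1

targ = ["B-PER", "I-PER", "O", "B-LOC", "IGN"]
pred = ["B-PER", "I-PER", "O", "B-ORG", "B-LOC"]
print(prf(pred, targ))  # (0.5, 0.5, 0.5)
```

The prediction at the ignored position is dropped before scoring, and the mislabeled `B-ORG` counts as both a false positive and a missed `LOC` entity.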

{% raw %}

calculate_token_class_metrics[source]

calculate_token_class_metrics(pred_toks, targ_toks, metric_key)

{% endraw %}

Training

{% raw %}

class HF_TokenClassCallback[source]

HF_TokenClassCallback(tok_metrics=['accuracy', 'precision', 'recall', 'f1'], **kwargs) :: HF_BaseModelCallback

A fastai-friendly callback that includes accuracy, precision, recall, and f1 metrics using the seqeval library. Additionally, it knows not to include your 'ignore_token' in its calculations.

See here for more information on seqeval.

{% endraw %} {% raw %}
model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls, 
                model,
                opt_func=partial(Adam),
                cbs=[HF_TokenClassCallback],
                splitter=hf_splitter)


learn.create_opt()             # -> will create your layer groups based on your "splitter" function
learn.freeze()
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds[0].shape
(1, torch.Size([2, 40, 18]))
{% endraw %} {% raw %}
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 40]), 2, torch.Size([2, 40]))
{% endraw %} {% raw %}
print(preds[0].view(-1, preds[0].shape[-1]).shape, b[1].view(-1).shape)
test_eq(preds[0].view(-1, preds[0].shape[-1]).shape[0], b[1].view(-1).shape[0])
torch.Size([80, 18]) torch.Size([80])
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
4
{% endraw %} {% raw %}
learn.unfreeze()
learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=0.0007585775572806596, lr_steep=3.0199516913853586e-05)
{% endraw %} {% raw %}
learn.fit_one_cycle(3, lr_max= 3e-5, moms=(0.8,0.7,0.8))
epoch train_loss valid_loss accuracy precision recall f1 time
0 0.259039 0.139523 0.962448 0.636735 0.655462 0.645963 00:36
1 0.108007 0.130293 0.962726 0.702041 0.625455 0.661538 00:36
2 0.053934 0.109036 0.970793 0.759184 0.720930 0.739563 00:36
{% endraw %} {% raw %}
print(learn.token_classification_report)
           precision    recall  f1-score   support

      PER       0.96      0.97      0.96        69
      ORG       0.75      0.66      0.70        61
      LOC       0.78      0.68      0.73        66
 LOCderiv       0.87      0.67      0.75        30
  ORGpart       0.14      1.00      0.25         1
      OTH       0.52      0.40      0.45        30
  LOCpart       0.12      1.00      0.22         1

micro avg       0.76      0.72      0.74       258
macro avg       0.80      0.72      0.75       258

{% endraw %}

Showing results

Below we'll add in additional functionality to more intuitively show the results of our model.

{% raw %}
learn.show_results(learner=learn, max_n=2)
token / target label / predicted label
0 [('Neben', 'O', 'O'), ('einem', 'O', 'O'), ('4', 'O', 'O'), ('-', 'O', 'O'), ('in', 'O', 'O'), ('-', 'O', 'O'), ('1', 'O', 'O'), ('Kartenleser', 'O', 'O'), ('und', 'B-ORG', 'B-ORG'), ('Bluetooth', 'O', 'O'), ('2', 'O', 'O'), ('.', 'O', 'O'), ('0', 'O', 'O'), ('hat', 'O', 'O'), ('Medion', 'O', 'O'), ('einen', 'O', 'O'), ('8', 'O', 'O'), ('-', 'O', 'O'), ('fach', 'O', 'O'), ('DVD', 'O', 'O'), ('-', 'O', 'O')]
1 [('Das', 'O', 'O'), ('ist', 'O', 'O'), ('die', 'O', 'O'), ('Geschäftspolitik', 'O', 'O'), ('meines', 'O', 'O'), ('Vorgängers', 'O', 'O'), (',', 'O', 'O'), ('die', 'O', 'O'), ('ich', 'O', 'O'), ('sehr', 'O', 'O'), ('schätze', 'O', 'O'), (':', 'O', 'O'), ('Jedes', 'O', 'O'), ('Jahr', 'O', 'O'), ('gibt', 'O', 'O'), ('es', 'O', 'O'), ('ein', 'O', 'O'), ('bisschen', 'O', 'O'), ('mehr', 'O', 'O'), (',', 'O', 'O'), ('ergänzte', 'O', 'O'), ('Lutz', 'B-PER', 'B-PER'), (',', 'O', 'O'), ('der', 'O', 'O'), ('seit', 'O', 'O'), ('zwei', 'O', 'O'), ('Jahren', 'O', 'O'), ('Vorstandsvorsitzender', 'O', 'O'), ('der', 'O', 'O'), ('BayWa', 'B-ORG', 'B-ORG'), ('AG', 'I-ORG', 'I-ORG'), ('ist', 'O', 'O'), ('.', 'O', 'O')]
{% endraw %} {% raw %}
res = learn.blurr_predict('My name is Wayde and I live in San Diego'.split())
print(res[0])
['O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O']
{% endraw %}

The default Learner.predict method returns a prediction per subtoken, including the special tokens for each architecture's tokenizer. That is why the output above contains 13 labels for a 10-word input.
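One way to get back to a single label per input word is to keep only the prediction for each word's first subtoken and skip the leading special token. The sketch below applies that idea to the predictions printed above. It is an assumption about the general technique, not the actual blurr implementation; `first_subtoken_labels` is a hypothetical helper, and the subtoken counts assume that only "Wayde" is split in two.

```python
def first_subtoken_labels(sub_preds, subtoken_counts, n_leading_special=1):
    """Keep one label per word: the prediction for its first subtoken."""
    word_labels, pos = [], n_leading_special  # skip e.g. [CLS]
    for n in subtoken_counts:
        word_labels.append(sub_preds[pos])
        pos += n  # jump over the remaining subtokens of this word
    return word_labels

words = "My name is Wayde and I live in San Diego".split()
sub_preds = ['O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O']
counts = [2 if w == "Wayde" else 1 for w in words]  # assumed splits: "Wayde" -> "Way", "##de"

word_labels = first_subtoken_labels(sub_preds, counts)
print(list(zip(words, word_labels)))
```

With these assumptions, "Wayde" maps to `B-PER`, "San" to `B-LOC`, and "Diego" to `I-LOC`, and the labels for the special tokens at both ends are discarded.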

{% raw %}

Learner.blurr_predict_tokens[source]

Learner.blurr_predict_tokens(inp, **kargs)

Remove all the unnecessary predicted tokens after calling Learner.predict, so that, in addition to the input, you only get the predicted labels, label ids, and probabilities for what you passed into it.

{% endraw %} {% raw %}
txt = "Hi! My name is Wayde Gilliam from ohmeow.com. I live in California."
{% endraw %} {% raw %}
res = learn.blurr_predict_tokens(txt.split())
print([(tok, lbl) for tok,lbl in zip(res[0],res[1])])
[('Hi!', 'O'), ('My', 'O'), ('name', 'O'), ('is', 'O'), ('Wayde', 'B-PER'), ('Gilliam', 'I-PER'), ('from', 'O'), ('ohmeow.com.', 'B-ORG'), ('I', 'O'), ('live', 'O'), ('in', 'O'), ('California.', 'B-LOC')]
{% endraw %}

It's interesting (and very cool) how well this model performs on English even though it was trained against a German corpus.

Tests

The tests below ensure the token classification training code above works for all pretrained token classification models available in Hugging Face. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data they would require downloading.

Note: Feel free to modify the code below to test whatever pretrained token classification models you are working with ... and if any of your pretrained token classification models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

{% raw %}
try: del learn; torch.cuda.empty_cache()
except NameError: pass  # learn was never defined, so there is nothing to free
{% endraw %} {% raw %}
BLURR_MODEL_HELPER.get_models(task='TokenClassification')
[transformers.modeling_albert.AlbertForTokenClassification,
 transformers.modeling_auto.AutoModelForTokenClassification,
 transformers.modeling_bert.BertForTokenClassification,
 transformers.modeling_camembert.CamembertForTokenClassification,
 transformers.modeling_distilbert.DistilBertForTokenClassification,
 transformers.modeling_electra.ElectraForTokenClassification,
 transformers.modeling_longformer.LongformerForTokenClassification,
 transformers.modeling_mobilebert.MobileBertForTokenClassification,
 transformers.modeling_roberta.RobertaForTokenClassification,
 transformers.modeling_xlm.XLMForTokenClassification,
 transformers.modeling_xlm_roberta.XLMRobertaForTokenClassification,
 transformers.modeling_xlnet.XLNetForTokenClassification]
{% endraw %} {% raw %}
pretrained_model_names = [
    'albert-base-v1',
    'bert-base-multilingual-cased',
    'camembert-base',
    'distilbert-base-uncased',
    #'<electra>', # currently no pre-trained electra model works for token classification
    'allenai/longformer-base-4096',
    'google/mobilebert-uncased',
    'roberta-base',
    'xlm-mlm-ende-1024',
    'xlm-roberta-base',
    'xlnet-base-cased'
]
{% endraw %} {% raw %}
#hide_output
task = HF_TASKS_AUTO.TokenClassification
bsz = 2

test_results = []
for model_name in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    config = AutoConfig.from_pretrained(model_name)
    config.num_labels = len(labels)
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, 
                                                                                   task=task, 
                                                                                   config=config)
    
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')
    
    hf_batch_tfm = HF_TokenClassBatchTransform(hf_arch, hf_tokenizer)

    blocks = (
        HF_TextBlock(hf_arch, hf_tokenizer, is_pretokenized=True, max_length=32, padding='max_length',
                     hf_batch_tfm=hf_batch_tfm,
                     tok_kwargs={ 'return_special_tokens_mask': True }), 
        HF_TokenCategoryBlock(vocab=labels)
    )

    dblock = DataBlock(blocks=blocks, 
                       get_x=ColReader('tokens'),
                       get_y= lambda inp: [ (label, len(hf_tokenizer.tokenize(str(entity)))) for entity, label in zip(inp.tokens, inp.labels) ],
                       splitter=RandomSplitter())
    
    dls = dblock.dataloaders(germ_eval_df, bs=bsz)

    model = HF_BaseModelWrapper(hf_model)
    learn = Learner(dls, 
                model,
                opt_func=partial(Adam),
                cbs=[HF_TokenClassCallback],
                splitter=hf_splitter)

    learn.create_opt()             # -> will create your layer groups based on your "splitter" function
    learn.unfreeze()
    
    b = dls.one_batch()
    
    try:
        print('*** TESTING DataLoaders ***')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, 32]))
        test_eq(len(b[1]), bsz)

        print('*** TESTING One pass through the model ***')
        preds = learn.model(b[0])
        test_eq(len(preds[0]), bsz)
        test_eq(preds[0].shape, torch.Size([bsz, 32, len(labels)]))

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max= 3e-5, moms=(0.8,0.7,0.8))

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))
    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()
{% endraw %} {% raw %}
arch tokenizer model_name result error
0 albert AlbertTokenizer AlbertForTokenClassification PASSED
1 bert BertTokenizer BertForTokenClassification PASSED
2 camembert CamembertTokenizer CamembertForTokenClassification PASSED
3 distilbert DistilBertTokenizer DistilBertForTokenClassification PASSED
4 longformer LongformerTokenizer LongformerForTokenClassification PASSED
5 mobilebert MobileBertTokenizer MobileBertForTokenClassification PASSED
6 roberta RobertaTokenizer RobertaForTokenClassification PASSED
7 xlm XLMTokenizer XLMForTokenClassification PASSED
8 xlm_roberta XLMRobertaTokenizer XLMRobertaForTokenClassification PASSED
9 xlnet XLNetTokenizer XLNetForTokenClassification PASSED
{% endraw %}

Cleanup