---
title: text.modeling.token_classification
keywords: fastai
sidebar: home_sidebar
summary: "This module contains custom models, loss functions, custom splitters, etc... for token classification tasks (e.g., Named entity recognition (NER), Part-of-speech tagging (POS), etc...). The objective of token classification is to predict the correct label for each token provided in the input. In the computer vision world, this is akin to what we do in segmentation tasks whereby we attempt to predict the class/label for each pixel in an image."
description: "This module contains custom models, loss functions, custom splitters, etc... for token classification tasks (e.g., Named entity recognition (NER), Part-of-speech tagging (POS), etc...). The objective of token classification is to predict the correct label for each token provided in the input. In the computer vision world, this is akin to what we do in segmentation tasks whereby we attempt to predict the class/label for each pixel in an image."
nb_path: "nbs/13_text-modeling-token-classification.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
 
{% endraw %} {% raw %}
{% endraw %} {% raw %}
What we're running with at the time this documentation was generated:
torch: 1.10.1+cu111
fastai: 2.5.6
transformers: 4.16.2
{% endraw %}

Setup

We'll use a subset of the conll2003 dataset to demonstrate how to configure your BLURR code for token classification.

Note: Make sure you set the config.num_labels attribute to the number of labels your model is predicting. The model will update its last layer accordingly, à la transfer learning.

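The cells throughout this page assume the usual notebook imports have already been run. Here's a minimal sketch of what that typically looks like; the BLURR module paths (and the PyPI package name) are assumptions and may differ between versions:

{% raw %}
# Minimal import sketch -- BLURR module paths are assumptions; adjust for your installed version
# !pip install ohmeow-blurr datasets seqeval   # package names assumed

import pandas as pd

from datasets import load_dataset
from fastai.text.all import *
from fastcore.test import *
from transformers import AutoConfig, AutoModelForTokenClassification

# expected to provide NLP.get_hf_objects, TextBlock, TokenClassBatchTokenizeTransform,
# TokenCategoryBlock, TokenClassTextInput, BaseModelWrapper, BlearnerForTokenClassification, etc.
from blurr.text.data.all import *
from blurr.text.modeling.all import *
{% endraw %}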
{% raw %}
raw_datasets = load_dataset("conll2003")

labels = raw_datasets["train"].features["ner_tags"].feature.names
print(f"Labels: {labels}")

conll2003_df = pd.DataFrame(raw_datasets["train"])
conll2003_df.head()
Reusing dataset conll2003 (/home/wgilliam/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)
Labels: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
| | chunk_tags | id | ner_tags | pos_tags | tokens |
|---|---|---|---|---|---|
| 0 | [11, 21, 11, 12, 21, 22, 11, 12, 0] | 0 | [3, 0, 7, 0, 0, 0, 7, 0, 0] | [22, 42, 16, 21, 35, 37, 16, 21, 7] | [EU, rejects, German, call, to, boycott, British, lamb, .] |
| 1 | [11, 12] | 1 | [1, 2] | [22, 22] | [Peter, Blackburn] |
| 2 | [11, 12] | 2 | [5, 0] | [22, 11] | [BRUSSELS, 1996-08-22] |
| 3 | [11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 11, 12, 12, 21, 22, 22, 13, 11, 0] | 3 | [0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | [12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 16, 21, 21, 20, 37, 40, 35, 21, 7] | [The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, to, consumers, to, shun, British, lamb, until, scientists, determine, whether, mad, cow, disease, can, be, transmitted, to, sheep, .] |
| 4 | [11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1, 13, 11, 17, 11, 12, 12, 21, 1, 0] | 4 | [5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0] | [22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20, 37, 21, 15, 24, 16, 15, 22, 15, 12, 16, 21, 38, 17, 7] | [Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Werner, Zwingmann, said, on, Wednesday, consumers, should, buy, sheepmeat, from, countries, other, than, Britain, until, the, scientific, advice, was, clearer, .] |
{% endraw %} {% raw %}
model_cls = AutoModelForTokenClassification
pretrained_model_name = "roberta-base"
config = AutoConfig.from_pretrained(pretrained_model_name)

config.num_labels = len(labels)
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls, config=config)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('roberta',
 transformers.models.roberta.configuration_roberta.RobertaConfig,
 transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification)
{% endraw %} {% raw %}
test_eq(hf_config.num_labels, len(labels))
{% endraw %} {% raw %}
batch_tok_tfm = TokenClassBatchTokenizeTransform(hf_arch, hf_config, hf_tokenizer, hf_model)
blocks = (TextBlock(batch_tokenize_tfm=batch_tok_tfm, input_return_type=TokenClassTextInput), TokenCategoryBlock(vocab=labels))

dblock = DataBlock(blocks=blocks, get_x=ColReader("tokens"), get_y=ColReader("ner_tags"), splitter=RandomSplitter())
{% endraw %} {% raw %}
dls = dblock.dataloaders(conll2003_df, bs=4)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2)
word / target label
0 [('MARKET', 'O'), ('TALK', 'O'), ('-', 'O'), ('USDA', 'B-ORG'), ('net', 'O'), ('change', 'O'), ('in', 'O'), ('weekly', 'O'), ('export', 'O'), ('commitments', 'O'), ('for', 'O'), ('the', 'O'), ('week', 'O'), ('ended', 'O'), ('August', 'O'), ('22', 'O'), (',', 'O'), ('includes', 'O'), ('old', 'O'), ('crop', 'O'), ('and', 'O'), ('new', 'O'), ('crop', 'O'), (',', 'O'), ('were', 'O'), (':', 'O'), ('wheat', 'O'), ('up', 'O'), ('595,400', 'O'), ('tonnes', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('corn', 'O'), ('up', 'O'), ('1,900', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('319,600', 'O'), ('new', 'O'), (';', 'O'), ('soybeans', 'O'), ('down', 'O'), ('12,300', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('300,800', 'O'), ('new', 'O'), (';', 'O'), ('upland', 'O'), ('cotton', 'O'), ('up', 'O'), ('50,400', 'O'), ('bales', 'O'), ('new', 'O'), (',', 'O'), ('nil', 'O'), ('old', 'O'), (';', 'O'), ('soymeal', 'O'), ('54,800', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('100,600', 'O'), ('new', 'O'), (',', 'O'), ('soyoil', 'O'), ('nil', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('75,000', 'O'), ('new', 'O'), (';', 'O'), ('barley', 'O'), ('up', 'O'), ('1,700', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('sorghum', 'O'), ('6,200', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('156,700', 'O'), ('new', 'O'), (';', 'O'), ('pima', 'O'), ('cotton', 'O'), ('up', 'O'), ('4,000', 'O'), ('bales', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('rice', 'O'), ('up', 'O'), ('49,900', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), ('...', 'O')]
1 [('A', 'O'), ('chain-smoking', 'O'), ('former', 'O'), ('paratroop', 'O'), ('general', 'O'), ('with', 'O'), ('a', 'O'), ('sharp', 'O'), ('line', 'O'), ('in', 'O'), ('deadpan', 'O'), ('putdowns', 'O'), ('and', 'O'), ('a', 'O'), ('soldier', 'O'), ("'s", 'O'), ('knack', 'O'), ('for', 'O'), ('making', 'O'), ('life', 'O'), ('sound', 'O'), ('simple', 'O'), (',', 'O'), ('Lebed', 'B-PER'), ('managed', 'O'), ('to', 'O'), ('arrange', 'O'), ('an', 'O'), ('ambitious', 'O'), ('ceasefire', 'O'), ('in', 'O'), ('the', 'O'), ('region', 'O'), ('last', 'O'), ('week', 'O'), (',', 'O'), ('days', 'O'), ('after', 'O'), ('the', 'O'), ('Russian', 'B-MISC'), ('army', 'O'), ('threatened', 'O'), ('to', 'O'), ('bomb', 'O'), ('its', 'O'), ('way', 'O'), ('back', 'O'), ('into', 'O'), ('the', 'O'), ('rebel-held', 'O'), ('Chechen', 'B-MISC'), ('capital', 'O'), ('Grozny', 'B-LOC'), ('.', 'O')]
{% endraw %}
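Under the hood, TokenClassBatchTokenizeTransform has to align each example's word-level labels to the model's subword tokens. With a fast tokenizer this alignment relies on word_ids(); below is a quick illustration of that mapping (an illustration only, not BLURR's actual implementation):

{% raw %}
# Illustration: how pre-tokenized words map to subword tokens via a fast tokenizer's word_ids()
words = conll2003_df.loc[0, "tokens"]  # e.g., ['EU', 'rejects', 'German', ...]
enc = hf_tokenizer(words, is_split_into_words=True)

subword_toks = hf_tokenizer.convert_ids_to_tokens(enc["input_ids"])
print(list(zip(subword_toks, enc.word_ids())))  # `None` marks special tokens like <s> and </s>
{% endraw %}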

Mid-level API

In this section, we'll add helpful metrics for token classification tasks.

{% raw %}
{% endraw %} {% raw %}

calculate_token_class_metrics[source]

calculate_token_class_metrics(pred_toks, targ_toks, metric_key)

{% endraw %} {% raw %}

class TokenClassMetricsCallback[source]

TokenClassMetricsCallback(tok_metrics=['accuracy', 'precision', 'recall', 'f1'], **kwargs) :: Callback

A fastai-friendly callback that includes accuracy, precision, recall, and f1 metrics using the seqeval library. Additionally, this metric knows not to include your 'ignore_token' in its calculations.

See the seqeval library for more information.

{% endraw %} {% raw %}
{% endraw %}
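As a quick sanity check, you can call calculate_token_class_metrics directly on label sequences. The toy example below assumes the helper simply wraps the corresponding seqeval scorer for the given metric_key (the same keys used by TokenClassMetricsCallback above):

{% raw %}
# Toy sketch (assumes `calculate_token_class_metrics` wraps the matching seqeval scorer)
pred_toks = [["B-PER", "I-PER", "O", "B-LOC"], ["O", "B-ORG"]]
targ_toks = [["B-PER", "I-PER", "O", "B-LOC"], ["O", "B-MISC"]]

for metric in ["accuracy", "precision", "recall", "f1"]:
    print(metric, calculate_token_class_metrics(pred_toks, targ_toks, metric_key=metric))
{% endraw %}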

Example

Training

{% raw %}
model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [TokenClassMetricsCallback()]

learn = Learner(dls, model, opt_func=partial(Adam), loss_func=PreCalculatedCrossEntropyLoss(), cbs=learn_cbs, splitter=blurr_splitter)

learn.freeze()
{% endraw %} {% raw %}
learn.summary()
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])
len(preds), type(preds), preds.keys()
(2,
 transformers.modeling_outputs.TokenClassifierOutput,
 odict_keys(['loss', 'logits']))
{% endraw %} {% raw %}
len(b), len(b[0]), b[0]["input_ids"].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([4, 156]), 4, torch.Size([4, 156]))
{% endraw %} {% raw %}
preds.logits.shape
torch.Size([4, 156, 9])
{% endraw %} {% raw %}
print(preds.logits.view(-1, preds.logits.shape[-1]).shape, b[1].view(-1).shape)
test_eq(preds.logits.view(-1, preds.logits.shape[-1]).shape[0], b[1].view(-1).shape[0])
torch.Size([624, 9]) torch.Size([624])
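{% endraw %}

Because the logits and targets flatten to the same number of positions, you can reproduce the cross-entropy calculation by hand. A minimal sketch, assuming BLURR follows the Hugging Face convention of -100 as the ignore index for special/padded positions:

{% raw %}
import torch.nn.functional as F

flat_logits = preds.logits.view(-1, preds.logits.shape[-1])  # (bs * seq_len, n_labels)
flat_targs = b[1].view(-1)                                   # (bs * seq_len,)

# -100 (assumed default) masks special tokens and padding out of the loss
print(F.cross_entropy(flat_logits, flat_targs, ignore_index=-100))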
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
3
{% endraw %} {% raw %}
learn.unfreeze()
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=0.00020892962347716094, steep=0.00015848931798245758, valley=5.248074739938602e-05, slide=0.0014454397605732083)
{% endraw %} {% raw %}
learn.fit_one_cycle(1, lr_max=3e-5, moms=(0.8, 0.7, 0.8), cbs=fit_cbs)
| epoch | train_loss | valid_loss | accuracy | precision | recall | f1 | time |
|---|---|---|---|---|---|---|---|
| 0 | 0.056628 | 0.054247 | 0.987888 | 0.934055 | 0.926676 | 0.930351 | 03:17 |
{% endraw %} {% raw %}
print(learn.token_classification_report)
              precision    recall  f1-score   support

         LOC       0.96      0.94      0.95      1456
        MISC       0.85      0.86      0.85       702
         ORG       0.91      0.89      0.90      1346
         PER       0.98      0.97      0.97      1433

   micro avg       0.93      0.93      0.93      4937
   macro avg       0.92      0.92      0.92      4937
weighted avg       0.93      0.93      0.93      4937

{% endraw %}

Showing results

Below we'll add additional functionality to show the results of our model more intuitively.

{% raw %}
{% endraw %} {% raw %}
learn.show_results(learner=learn, max_n=2, trunc_at=10)
token / target label / predicted label
0 [('15', 'O', 'O'), ('-', 'O', 'O'), ('Christian', 'B-PER', 'B-PER'), ('Cullen', 'I-PER', 'I-PER'), (',', 'O', 'O'), ('14', 'O', 'O'), ('-', 'O', 'O'), ('Jeff', 'B-PER', 'B-PER'), ('Wilson', 'I-PER', 'I-PER'), (',', 'O', 'O')]
1 [('"', 'O', 'O'), ('I', 'O', 'O'), ('still', 'O', 'O'), ('feel', 'O', 'O'), ('it', 'O', 'O'), ("'s", 'O', 'O'), ('embarrassing', 'O', 'O'), ('what', 'O', 'O'), ('happened', 'O', 'O'), ('and', 'O', 'O')]
{% endraw %}

Prediction

The default Learner.predict method returns a prediction per subtoken, including the special tokens for each architecture's tokenizer. Starting with version 2.0 of BLURR, token prediction is brought in line with Hugging Face's token classification pipeline, both in supporting the same aggregation strategies (via BLURR's TokenAggregationStrategies class) and in the format of the output (via BLURR's @patched Learner method, blurr_predict_tokens).

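Before diving into the API details below, here's a quick sketch comparing aggregation strategies. Only 'simple' (the default) and 'max' appear elsewhere on this page; the 'first' and 'average' names are assumed to mirror Hugging Face's pipeline:

{% raw %}
# Compare aggregation strategies ('first' and 'average' are assumed to mirror the HF pipeline)
for strategy in ["simple", "first", "average", "max"]:
    res = learn.blurr_predict_tokens(items="Bayern Munich is a soccer team in Germany", aggregation_strategy=strategy)
    print(strategy, "->", res[0])
{% endraw %}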
{% raw %}

class TokenAggregationStrategies[source]

TokenAggregationStrategies(hf_tokenizer:PreTrainedTokenizerBase, labels:List[str], non_entity_label:str='O')

Provides the equivalent of the aggregation_strategy support in Hugging Face's token classification pipeline across various token classification tasks (e.g., NER, POS, chunking, etc...)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

Learner.blurr_predict_tokens[source]

Learner.blurr_predict_tokens(items:Union[str, typing.List[str]], aggregation_strategy:str='simple', non_entity_label:str='O', slow_word_ids_func:Optional[typing.Callable]=None)

| | Type | Default | Details |
|---|---|---|---|
| items | typing.Union[str, typing.List[str]] | | The str (or list of strings) you want to get token classification predictions for |
| aggregation_strategy | str | simple | How entities are grouped and scored |
| non_entity_label | str | O | The label used to identify non-entity related words/tokens |
| slow_word_ids_func | typing.Optional[typing.Callable] | None | If using a slow tokenizer, users will need to provide a slow_word_ids_func that accepts a tokenizer, example index, and a batch encoding as arguments and in turn returns the equivalent of a fast tokenizer's `word_ids` |
{% endraw %} {% raw %}
res = learn.blurr_predict_tokens(
    items=["My name is Wayde and I live in San Diego and using Hugging Face", "Bayern Munich is a soccer team in Germany"],
    aggregation_strategy="max",
)

print(len(res))
print(res[1])
2
[{'entity_group': 'ORG', 'score': 0.9898321628570557, 'word': 'Bayern Munich', 'start': 0, 'end': 13}, {'entity_group': 'LOC', 'score': 0.9972659349441528, 'word': 'Germany', 'start': 34, 'end': 41}]
{% endraw %} {% raw %}
txt = "Hi! My name is Wayde Gilliam from ohmeow.com. I live in California."
txt2 = "I wish covid was over so I could go to Germany and watch Bayern Munich play in the Bundesliga."
{% endraw %} {% raw %}
res = learn.blurr_predict_tokens(txt)
print(res)
[[{'entity_group': 'PER', 'score': 0.9760623723268509, 'word': 'Wayde Gilliam', 'start': 15, 'end': 28}, {'entity_group': 'ORG', 'score': 0.5695258180300394, 'word': 'ohmeow', 'start': 34, 'end': 40}, {'entity_group': 'ORG', 'score': 0.4891514480113983, 'word': 'com', 'start': 41, 'end': 44}, {'entity_group': 'MISC', 'score': 0.16531413793563843, 'word': '.', 'start': 44, 'end': 45}, {'entity_group': 'LOC', 'score': 0.9955629110336304, 'word': 'California', 'start': 56, 'end': 66}, {'entity_group': 'MISC', 'score': 0.16531400382518768, 'word': '.', 'start': 66, 'end': 67}]]
{% endraw %} {% raw %}
results = learn.blurr_predict_tokens([txt, txt2])
for res in results:
    print(f"{res}\n")
[{'entity_group': 'PER', 'score': 0.9760623723268509, 'word': 'Wayde Gilliam', 'start': 15, 'end': 28}, {'entity_group': 'ORG', 'score': 0.5695258180300394, 'word': 'ohmeow', 'start': 34, 'end': 40}, {'entity_group': 'ORG', 'score': 0.4891514480113983, 'word': 'com', 'start': 41, 'end': 44}, {'entity_group': 'MISC', 'score': 0.16531413793563843, 'word': '.', 'start': 44, 'end': 45}, {'entity_group': 'LOC', 'score': 0.9955629110336304, 'word': 'California', 'start': 56, 'end': 66}, {'entity_group': 'MISC', 'score': 0.16531400382518768, 'word': '.', 'start': 66, 'end': 67}]

[{'entity_group': 'LOC', 'score': 0.9958367347717285, 'word': 'Germany', 'start': 39, 'end': 46}, {'entity_group': 'ORG', 'score': 0.9850218594074249, 'word': 'Bayern Munich', 'start': 57, 'end': 70}, {'entity_group': 'MISC', 'score': 0.9450671672821045, 'word': 'Bundesliga', 'start': 83, 'end': 93}, {'entity_group': 'MISC', 'score': 0.17792454361915588, 'word': '.', 'start': 93, 'end': 94}]

{% endraw %}

Inference

{% raw %}
export_fname = "tok_class_learn_export"
{% endraw %} {% raw %}
learn.export(fname=f"{export_fname}.pkl")
inf_learn = load_learner(fname=f"{export_fname}.pkl")

results = inf_learn.blurr_predict_tokens([txt, txt2])
for res in results:
    print(f"{res}\n")
[{'entity_group': 'PER', 'score': 0.9760623574256897, 'word': 'Wayde Gilliam', 'start': 15, 'end': 28}, {'entity_group': 'ORG', 'score': 0.5695257385571798, 'word': 'ohmeow', 'start': 34, 'end': 40}, {'entity_group': 'ORG', 'score': 0.4891512393951416, 'word': 'com', 'start': 41, 'end': 44}, {'entity_group': 'MISC', 'score': 0.16531416773796082, 'word': '.', 'start': 44, 'end': 45}, {'entity_group': 'LOC', 'score': 0.9955630302429199, 'word': 'California', 'start': 56, 'end': 66}, {'entity_group': 'MISC', 'score': 0.16531400382518768, 'word': '.', 'start': 66, 'end': 67}]

[{'entity_group': 'LOC', 'score': 0.9958367347717285, 'word': 'Germany', 'start': 39, 'end': 46}, {'entity_group': 'ORG', 'score': 0.9850217401981354, 'word': 'Bayern Munich', 'start': 57, 'end': 70}, {'entity_group': 'MISC', 'score': 0.9450671672821045, 'word': 'Bundesliga', 'start': 83, 'end': 93}, {'entity_group': 'MISC', 'score': 0.17792455852031708, 'word': '.', 'start': 93, 'end': 94}]

{% endraw %}

High-level API

{% raw %}

class BlearnerForTokenClassification[source]

BlearnerForTokenClassification(dls:DataLoaders, hf_model:PreTrainedModel, base_model_cb:BaseModelCallback=BaseModelCallback, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Blearner

Group together a model, some dls and a loss_func to handle training

{% endraw %} {% raw %}
{% endraw %}
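Since the constructor mirrors fastai's Learner, you can also wrap an existing DataLoaders and Hugging Face model directly (a minimal sketch using the mid-level objects created earlier; it assumes the Blearner fills in sensible defaults for the loss function and parameter-group splitter when they aren't provided):

{% raw %}
# Minimal sketch: build a Blearner from existing objects rather than via `from_data`
# (assumes default loss_func/splitter are handled by the Blearner when not provided)
blearner = BlearnerForTokenClassification(dls, hf_model)
{% endraw %}

In most cases, though, the from_data factory method shown below is the more convenient route.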

Example

Define your Blearner

{% raw %}
learn = BlearnerForTokenClassification.from_data(
    conll2003_df,
    "distilroberta-base",
    tokens_attr="tokens",
    token_labels_attr="ner_tags",
    labels=labels,
    dl_kwargs={"bs": 2},
)

learn.unfreeze()
{% endraw %} {% raw %}
learn.dls.show_batch(dataloaders=learn.dls, max_n=2)
word / target label
0 [('MARKET', 'O'), ('TALK', 'O'), ('-', 'O'), ('USDA', 'B-ORG'), ('net', 'O'), ('change', 'O'), ('in', 'O'), ('weekly', 'O'), ('export', 'O'), ('commitments', 'O'), ('for', 'O'), ('the', 'O'), ('week', 'O'), ('ended', 'O'), ('August', 'O'), ('22', 'O'), (',', 'O'), ('includes', 'O'), ('old', 'O'), ('crop', 'O'), ('and', 'O'), ('new', 'O'), ('crop', 'O'), (',', 'O'), ('were', 'O'), (':', 'O'), ('wheat', 'O'), ('up', 'O'), ('595,400', 'O'), ('tonnes', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('corn', 'O'), ('up', 'O'), ('1,900', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('319,600', 'O'), ('new', 'O'), (';', 'O'), ('soybeans', 'O'), ('down', 'O'), ('12,300', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('300,800', 'O'), ('new', 'O'), (';', 'O'), ('upland', 'O'), ('cotton', 'O'), ('up', 'O'), ('50,400', 'O'), ('bales', 'O'), ('new', 'O'), (',', 'O'), ('nil', 'O'), ('old', 'O'), (';', 'O'), ('soymeal', 'O'), ('54,800', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('100,600', 'O'), ('new', 'O'), (',', 'O'), ('soyoil', 'O'), ('nil', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('75,000', 'O'), ('new', 'O'), (';', 'O'), ('barley', 'O'), ('up', 'O'), ('1,700', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('sorghum', 'O'), ('6,200', 'O'), ('old', 'O'), (',', 'O'), ('up', 'O'), ('156,700', 'O'), ('new', 'O'), (';', 'O'), ('pima', 'O'), ('cotton', 'O'), ('up', 'O'), ('4,000', 'O'), ('bales', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), (';', 'O'), ('rice', 'O'), ('up', 'O'), ('49,900', 'O'), ('old', 'O'), (',', 'O'), ('nil', 'O'), ('new', 'O'), ('...', 'O')]
1 [('"', 'O'), ('This', 'O'), ('finding', 'O'), ('is', 'O'), ('important', 'O'), ('because', 'O'), ('one', 'O'), ('of', 'O'), ('the', 'O'), ('jars', 'O'), ('still', 'O'), ('contains', 'O'), ('substances', 'O'), ('and', 'O'), ('materials', 'O'), ('used', 'O'), ('in', 'O'), ('the', 'O'), ('conservation', 'O'), ('of', 'O'), ('mummies', 'O'), ('and', 'O'), ('the', 'O'), ('conservation', 'O'), ('of', 'O'), ('the', 'O'), ('intestines', 'O'), ('and', 'O'), ('all', 'O'), ('the', 'O'), ('things', 'O'), ('which', 'O'), ('were', 'O'), ('in', 'O'), ('the', 'O'), ('cavity', 'O'), ('of', 'O'), ('a', 'O'), ('person', 'O'), ('we', 'O'), ('have', 'O'), ('not', 'O'), ('identified', 'O'), ('yet', 'O'), (',', 'O'), ('"', 'O'), ('Saleh', 'B-PER'), ('said', 'O'), ('.', 'O')]
{% endraw %}

Train

{% raw %}
learn.fit_one_cycle(1, lr_max=3e-5, moms=(0.8, 0.7, 0.8), cbs=[BlearnerForTokenClassification.get_metrics_cb()])
| epoch | train_loss | valid_loss | accuracy | precision | recall | f1 | time |
|---|---|---|---|---|---|---|---|
| 0 | 0.075006 | 0.053937 | 0.987387 | 0.932157 | 0.928798 | 0.930474 | 04:03 |
{% endraw %} {% raw %}
learn.show_results(learner=learn, max_n=2, trunc_at=10)
token / target label / predicted label
0 [('Squad', 'O', 'O'), (':', 'O', 'O'), ('Alan', 'B-PER', 'B-PER'), ('Kelly', 'I-PER', 'I-PER'), (',', 'O', 'O'), ('Shay', 'B-PER', 'B-PER'), ('Given', 'I-PER', 'I-PER'), (',', 'O', 'O'), ('Denis', 'B-PER', 'B-PER'), ('Irwin', 'I-PER', 'I-PER')]
1 [('The', 'O', 'O'), ('newspaper', 'O', 'O'), ('said', 'O', 'O'), ('Bamerindus', 'B-ORG', 'B-ORG'), ('has', 'O', 'O'), ('sent', 'O', 'O'), ('to', 'O', 'O'), ('the', 'O', 'O'), ('Central', 'B-ORG', 'B-ORG'), ('Bank', 'I-ORG', 'I-ORG')]
{% endraw %} {% raw %}
print(learn.token_classification_report)
              precision    recall  f1-score   support

         LOC       0.96      0.96      0.96      1420
        MISC       0.85      0.87      0.86       670
         ORG       0.91      0.89      0.90      1297
         PER       0.97      0.97      0.97      1332

   micro avg       0.93      0.93      0.93      4719
   macro avg       0.92      0.92      0.92      4719
weighted avg       0.93      0.93      0.93      4719

{% endraw %}

Prediction

{% raw %}
txt = "Hi! My name is Wayde Gilliam from ohmeow.com. I live in California."
txt2 = "I wish covid was over so I could watch Lewandowski score some more goals for Bayern Munich in the Bundesliga."
{% endraw %} {% raw %}
results = learn.predict([txt, txt2])
for res in results:
    print(f"{res}\n")
[{'entity_group': 'PER', 'score': 0.9957719445228577, 'word': 'Way', 'start': 15, 'end': 18}, {'entity_group': 'PER', 'score': 0.9542139172554016, 'word': 'de Gilliam', 'start': 18, 'end': 28}, {'entity_group': 'ORG', 'score': 0.379508301615715, 'word': 'ohme', 'start': 34, 'end': 38}, {'entity_group': 'LOC', 'score': 0.9953558444976807, 'word': 'California', 'start': 56, 'end': 66}]

[{'entity_group': 'PER', 'score': 0.7999598979949951, 'word': 'cov', 'start': 7, 'end': 10}, {'entity_group': 'PER', 'score': 0.990045577287674, 'word': 'Lewandowski', 'start': 39, 'end': 50}, {'entity_group': 'ORG', 'score': 0.988516092300415, 'word': 'Bayern Munich', 'start': 77, 'end': 90}, {'entity_group': 'MISC', 'score': 0.9711833596229553, 'word': 'Bundesliga', 'start': 98, 'end': 108}]

{% endraw %}

Tests

The tests below ensure that the token classification training code above works for all the pretrained token classification models available in Hugging Face. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained token classification models you are working with ... and if any of your pretrained token classification models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

{% raw %}
raw_datasets = load_dataset("conll2003")
labels = raw_datasets["train"].features["ner_tags"].feature.names
conll2003_df = pd.DataFrame(raw_datasets["train"])
Reusing dataset conll2003 (/home/wgilliam/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)
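{% endraw %}

The loop these tests run looks roughly like the sketch below: build a BlearnerForTokenClassification for each pretrained checkpoint, train briefly, and record pass/fail. This is a hypothetical reconstruction (the checkpoint names here are only examples), not the notebook's actual test code:

{% raw %}
# Hypothetical smoke-test sketch (checkpoint names are examples, not the full list tested below)
pretrained_model_names = ["bert-base-uncased", "distilroberta-base", "xlm-roberta-base"]

test_results = []
for model_name in pretrained_model_names:
    try:
        learn = BlearnerForTokenClassification.from_data(
            conll2003_df.iloc[:1000],  # a small subset keeps each smoke test quick
            model_name,
            tokens_attr="tokens",
            token_labels_attr="ner_tags",
            labels=labels,
            dl_kwargs={"bs": 2},
        )
        learn.fit_one_cycle(1, lr_max=3e-5, cbs=[BlearnerForTokenClassification.get_metrics_cb()])
        test_results.append((model_name, "PASSED", ""))
    except Exception as err:
        test_results.append((model_name, "FAILED", str(err)))

pd.DataFrame(test_results, columns=["model_name", "result", "error"])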
{% endraw %} {% raw %}
| | arch | tokenizer | model_name | result | error |
|---|---|---|---|---|---|
| 0 | albert | AlbertTokenizerFast | AlbertForTokenClassification | PASSED | |
| 1 | bert | BertTokenizerFast | BertForTokenClassification | PASSED | |
| 2 | big_bird | BigBirdTokenizerFast | BigBirdForTokenClassification | PASSED | |
| 3 | camembert | CamembertTokenizerFast | CamembertForTokenClassification | PASSED | |
| 4 | convbert | ConvBertTokenizerFast | ConvBertForTokenClassification | PASSED | |
| 5 | deberta | DebertaTokenizerFast | DebertaForTokenClassification | PASSED | |
| 6 | bert | BertTokenizerFast | BertForTokenClassification | PASSED | |
| 7 | electra | ElectraTokenizerFast | ElectraForTokenClassification | PASSED | |
| 8 | funnel | FunnelTokenizerFast | FunnelForTokenClassification | PASSED | |
| 9 | gpt2 | GPT2TokenizerFast | GPT2ForTokenClassification | PASSED | |
| 10 | layoutlm | LayoutLMTokenizerFast | LayoutLMForTokenClassification | PASSED | |
| 11 | longformer | LongformerTokenizerFast | LongformerForTokenClassification | PASSED | |
| 12 | mpnet | MPNetTokenizerFast | MPNetForTokenClassification | PASSED | |
| 13 | ibert | RobertaTokenizerFast | IBertForTokenClassification | PASSED | |
| 14 | mobilebert | MobileBertTokenizerFast | MobileBertForTokenClassification | PASSED | |
| 15 | rembert | RemBertTokenizerFast | RemBertForTokenClassification | PASSED | |
| 16 | roformer | RoFormerTokenizerFast | RoFormerForTokenClassification | PASSED | |
| 17 | roberta | RobertaTokenizerFast | RobertaForTokenClassification | PASSED | |
| 18 | squeezebert | SqueezeBertTokenizerFast | SqueezeBertForTokenClassification | PASSED | |
| 19 | xlm_roberta | XLMRobertaTokenizerFast | XLMRobertaForTokenClassification | PASSED | |
| 20 | xlnet | XLNetTokenizerFast | XLNetForTokenClassification | PASSED | |
{% endraw %}