--- title: data.token_classification keywords: fastai sidebar: home_sidebar summary: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for token classification tasks (e.g., named entity recognition, or NER)." description: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for token classification tasks (e.g., named entity recognition, or NER)." nb_path: "nbs/01a_data-token-classification.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %}

Token classification tokenization, batch transform, and DataBlock methods

Token classification tasks attempt to predict a class for each token. The idea is similar to image segmentation, where the objective is to predict a class for each pixel. Such models are commonly used to build named entity recognition (NER) systems.
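
For example (the sentence and tags below are made up purely for illustration), a token classification dataset pairs each pre-tokenized word with a label, typically using the BIO tagging scheme:

{% raw %}
# a made-up example: one tag per token, using the BIO scheme
tokens = ['Angela', 'Merkel', 'besuchte', 'Berlin', '.']
tags   = ['B-PER',  'I-PER',  'O',        'B-LOC',  'O']

for token, tag in zip(tokens, tags):
    print(f'{token:10} -> {tag}')
{% endraw %}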

{% raw %}
df_converters = {'tokens': ast.literal_eval, 'labels': ast.literal_eval, 'nested-labels': ast.literal_eval}

path = Path('./')
germ_eval_df = pd.read_csv(path/'germeval2014_sample.csv', converters=df_converters); len(germ_eval_df)
1000
{% endraw %} {% raw %}
labels = sorted(list(set([lbls for sublist in germ_eval_df.labels.tolist() for lbls in sublist])))
print(labels)
['B-LOC', 'B-LOCderiv', 'B-LOCpart', 'B-ORG', 'B-ORGpart', 'B-OTH', 'B-OTHderiv', 'B-OTHpart', 'B-PER', 'B-PERderiv', 'B-PERpart', 'I-LOC', 'I-LOCderiv', 'I-ORG', 'I-ORGpart', 'I-OTH', 'I-PER', 'O']
{% endraw %} {% raw %}
model_cls = AutoModelForTokenClassification

# pretrained_model_name = "bert-base-multilingual-cased"
pretrained_model_name = 'roberta-base'
n_labels = len(labels)

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=model_cls,
                                                                  config_kwargs={'num_labels': n_labels})
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('roberta',
 transformers.models.roberta.configuration_roberta.RobertaConfig,
 transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification)
{% endraw %}

Below, we define a new class and transform for token classification targets/predictions.

{% raw %}

class HF_TokenTensorCategory[source]

HF_TokenTensorCategory(x, **kwargs) :: TensorBase

A Tensor which supports subclass pickling and maintains metadata when casting or after methods

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class HF_TokenCategorize[source]

HF_TokenCategorize(vocab=None, ignore_token=None, ignore_token_id=None) :: Transform

Reversible transform of a list of category strings to vocab ids

{% endraw %} {% raw %}
{% endraw %}

HF_TokenCategorize modifies the fastai Categorize transform in a couple of ways. First, it allows your targets to consist of a Category per token, and second, it uses an ignore_token_id to mask sub-tokens that don't need a prediction. For example, the targets of special tokens (e.g., pad, cls, sep) are set to ignore_token_id, as are any trailing sub-tokens of a word that the tokenizer splits into more than one sub-token.
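
To make that concrete, below is a minimal sketch of the alignment idea (HF_TokenCategorize does this for you inside the pipeline), assuming the hf_tokenizer and labels vocab loaded above: only the first sub-token of each word keeps its label id, while any remaining sub-tokens get the ignore id (-100 by default). The special tokens added during batch-time tokenization (cls, sep, pad) are masked the same way.

{% raw %}
# minimal sketch of label/sub-token alignment; not blurr's actual implementation
IGNORE_TOKEN_ID = -100

words       = ['Helbig', 'et', 'al', '.']
word_labels = ['B-OTH', 'I-OTH', 'I-OTH', 'O']
label2id    = {lbl: idx for idx, lbl in enumerate(labels)}

aligned_targets = []
for word, lbl in zip(words, word_labels):
    n_subtoks = len(hf_tokenizer.tokenize(str(word)))
    # the first sub-token keeps the real label id, any trailing sub-tokens are ignored
    aligned_targets += [label2id[lbl]] + [IGNORE_TOKEN_ID] * (n_subtoks - 1)

print(aligned_targets)
{% endraw %}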

{% raw %}
{% endraw %} {% raw %}

HF_TokenCategoryBlock[source]

HF_TokenCategoryBlock(vocab=None, ignore_token=None, ignore_token_id=None)

TransformBlock for per-token categorical targets

{% endraw %}

Again, we define a custom class, HF_TokenClassInput, for the @typedispatch methods to use so that we can override how token classification inputs/targets are assembled, as well as how the data is shown via methods like show_batch and show_results.
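
As a rough illustration of that mechanism (this is the generic fastai @typedispatch pattern, not blurr's actual show_batch signature or body), having a distinct input type is what allows a token-classification-specific renderer to be registered:

{% raw %}
# hedged sketch of the fastai typedispatch pattern; blurr's real implementation differs
@typedispatch
def show_batch(x: HF_TokenClassInput, y, samples, ctxs=None, max_n=6, trunc_at=None, **kwargs):
    # build a (token, label) view for each sample and display it (e.g., as a DataFrame)
    ...
{% endraw %}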

{% raw %}

class HF_TokenClassInput[source]

HF_TokenClassInput(x, **kwargs) :: HF_BaseInput

A Tensor which supports subclass pickling and maintains metadata when casting or after methods

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class HF_TokenClassBeforeBatchTransform[source]

HF_TokenClassBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model, ignore_token_id=-100, max_length=None, padding=True, truncation=True, is_split_into_words=True, tok_kwargs={}, **kwargs) :: HF_BeforeBatchTransform

Handles everything you need to assemble a mini-batch of inputs and targets, as well as decode the dictionary produced as a byproduct of the tokenization process in the encodes method.

{% endraw %} {% raw %}
{% endraw %}

HF_TokenClassBeforeBatchTransform is used to exclude any target tokens we don't want to include in the loss calculation (e.g., padding, cls, sep, etc.).
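
Setting those target positions to ignore_token_id works because -100 is also PyTorch's default ignore_index for cross-entropy loss, so the masked positions contribute nothing to the loss. A standalone sketch:

{% raw %}
import torch
import torch.nn.functional as F

# fake logits for 1 sequence of 4 sub-tokens over 3 classes
logits  = torch.randn(1, 4, 3)
# positions labeled -100 (e.g., cls/sep/pad or trailing sub-tokens) are ignored
targets = torch.tensor([[-100, 0, 2, -100]])

loss = F.cross_entropy(logits.view(-1, 3), targets.view(-1), ignore_index=-100)
print(loss)
{% endraw %}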

{% raw %}
before_batch_tfm = HF_TokenClassBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                     is_split_into_words=True, 
                                                     tok_kwargs={ 'return_special_tokens_mask': True })

blocks = (
    HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_TokenClassInput), 
    HF_TokenCategoryBlock(vocab=labels)
)

def get_y(inp):
    return [ (label, len(hf_tokenizer.tokenize(str(entity)))) for entity, label in zip(inp.tokens, inp.labels) ]

dblock = DataBlock(blocks=blocks, get_x=ColReader('tokens'), get_y=get_y, splitter=RandomSplitter())
{% endraw %}

Note that in the example above we had to define a get_y that returns both the entity we want to predict a category for and how many sub-tokens the hf_tokenizer uses to represent it. This is necessary for the input/target alignment discussed above.
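
To see the alignment information get_y provides, you can inspect its output for a single row (first five tokens shown; the sub-token counts depend on the tokenizer in use):

{% raw %}
# peek at what get_y returns: one (label, number of sub-tokens) pair per original token
row = germ_eval_df.iloc[0]
print(list(zip(row.tokens[:5], row.labels[:5])))
print(get_y(row)[:5])
{% endraw %}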

{% raw %}
 
{% endraw %} {% raw %}
dls = dblock.dataloaders(germ_eval_df, bs=4)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([4, 98]), torch.Size([4, 98]))
{% endraw %} {% raw %}
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2, trunc_at=10)
| | token / target label |
|---|---|
| 0 | [('Helbig', 'B-OTH'), ('et', 'I-OTH'), ('al', 'I-OTH'), ('.', 'O'), ('(', 'O'), ('1994', 'O'), (')', 'O'), ('S.', 'O'), ('593.', 'O'), ('Wink', 'B-OTH')] |
| 1 | [('Scenes', 'B-OTH'), ('of', 'I-OTH'), ('a', 'I-OTH'), ('Sexual', 'I-OTH'), ('Nature', 'I-OTH'), ('(', 'O'), ('GB', 'O'), ('2006', 'O'), (')', 'O'), ('-', 'O')] |
{% endraw %}

Tests

The tests below ensure that the core DataBlock code above works for all the pretrained token classification models available in Hugging Face. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained token classification models you are working with ... and if any of them fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

{% raw %}
[ model_type for model_type in BLURR.get_models(task='TokenClassification') 
 if (not model_type.__name__.startswith('TF')) ]
[transformers.models.albert.modeling_albert.AlbertForTokenClassification,
 transformers.models.bert.modeling_bert.BertForTokenClassification,
 transformers.models.big_bird.modeling_big_bird.BigBirdForTokenClassification,
 transformers.models.camembert.modeling_camembert.CamembertForTokenClassification,
 transformers.models.convbert.modeling_convbert.ConvBertForTokenClassification,
 transformers.models.deberta.modeling_deberta.DebertaForTokenClassification,
 transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2ForTokenClassification,
 transformers.models.distilbert.modeling_distilbert.DistilBertForTokenClassification,
 transformers.models.electra.modeling_electra.ElectraForTokenClassification,
 transformers.models.flaubert.modeling_flaubert.FlaubertForTokenClassification,
 transformers.models.funnel.modeling_funnel.FunnelForTokenClassification,
 transformers.models.ibert.modeling_ibert.IBertForTokenClassification,
 transformers.models.layoutlm.modeling_layoutlm.LayoutLMForTokenClassification,
 transformers.models.longformer.modeling_longformer.LongformerForTokenClassification,
 transformers.models.mpnet.modeling_mpnet.MPNetForTokenClassification,
 transformers.models.mobilebert.modeling_mobilebert.MobileBertForTokenClassification,
 transformers.models.roberta.modeling_roberta.RobertaForTokenClassification,
 transformers.models.squeezebert.modeling_squeezebert.SqueezeBertForTokenClassification,
 transformers.models.xlm.modeling_xlm.XLMForTokenClassification,
 transformers.models.xlm_roberta.modeling_xlm_roberta.XLMRobertaForTokenClassification,
 transformers.models.xlnet.modeling_xlnet.XLNetForTokenClassification]
{% endraw %} {% raw %}
pretrained_model_names = [
    'albert-base-v1',
    'bert-base-multilingual-cased',
    'camembert-base',
    'distilbert-base-uncased',
    'monologg/electra-small-finetuned-imdb',
    'flaubert/flaubert_small_cased',
    'huggingface/funnel-small-base',
    'allenai/longformer-base-4096',
    'microsoft/mpnet-base',
    'google/mobilebert-uncased',
    'roberta-base',
    'squeezebert/squeezebert-uncased',
    'xlm-mlm-en-2048',
    'xlm-roberta-base',
    'xlnet-base-cased'
]
{% endraw %} {% raw %}
#hide_output
model_cls = AutoModelForTokenClassification
bsz = 2
seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(model_name, model_cls=model_cls)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')
    
    # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
    if (hf_tokenizer.pad_token is None): 
        hf_tokenizer.add_special_tokens({'pad_token': '<pad>'})  
        hf_config.pad_token_id = hf_tokenizer.get_vocab()['<pad>']
        hf_model.resize_token_embeddings(len(hf_tokenizer))   
    
    before_batch_tfm = HF_TokenClassBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                         padding='max_length', 
                                                         max_length=seq_sz, 
                                                         is_split_into_words=True, 
                                                         tok_kwargs={ 'return_special_tokens_mask': True })

    blocks = (
        HF_TextBlock(before_batch_tfm=before_batch_tfm, input_return_type=HF_TokenClassInput), 
        HF_TokenCategoryBlock(vocab=labels)
    )

    dblock = DataBlock(blocks=blocks, 
                       get_x=ColReader('tokens'),
                       get_y= lambda inp: [ (label, len(hf_tokenizer.tokenize(str(entity)))) 
                                           for entity, label in zip(inp.tokens, inp.labels) ],
                       splitter=RandomSplitter())
    
    dls = dblock.dataloaders(germ_eval_df, bs=bsz)
    b = dls.one_batch()
    
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)

        if (hasattr(hf_tokenizer, 'add_prefix_space')):
             test_eq(hf_tokenizer.add_prefix_space, True)
                
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))
        dls.show_batch(dataloaders=dls, max_n=2, trunc_at=10)
        
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))
{% endraw %} {% raw %}
| | arch | tokenizer | model_name | result | error |
|---|---|---|---|---|---|
| 0 | albert | AlbertTokenizerFast | albert-base-v1 | PASSED | |
| 1 | bert | BertTokenizerFast | bert-base-multilingual-cased | PASSED | |
| 2 | camembert | CamembertTokenizerFast | camembert-base | PASSED | |
| 3 | distilbert | DistilBertTokenizerFast | distilbert-base-uncased | PASSED | |
| 4 | electra | ElectraTokenizerFast | monologg/electra-small-finetuned-imdb | PASSED | |
| 5 | flaubert | FlaubertTokenizer | flaubert/flaubert_small_cased | PASSED | |
| 6 | funnel | FunnelTokenizerFast | huggingface/funnel-small-base | PASSED | |
| 7 | longformer | LongformerTokenizerFast | allenai/longformer-base-4096 | PASSED | |
| 8 | mpnet | MPNetTokenizerFast | microsoft/mpnet-base | PASSED | |
| 9 | mobilebert | MobileBertTokenizerFast | google/mobilebert-uncased | PASSED | |
| 10 | roberta | RobertaTokenizerFast | roberta-base | PASSED | |
| 11 | squeezebert | SqueezeBertTokenizerFast | squeezebert/squeezebert-uncased | PASSED | |
| 12 | xlm | XLMTokenizer | xlm-mlm-en-2048 | PASSED | |
| 13 | xlm_roberta | XLMRobertaTokenizerFast | xlm-roberta-base | PASSED | |
| 14 | xlnet | XLNetTokenizerFast | xlnet-base-cased | PASSED | |
{% endraw %}

Cleanup