---
title: data.seq2seq.translation
keywords: fastai
sidebar: home_sidebar
summary: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for translation tasks"
description: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for translation tasks"
nb_path: "nbs/01zd_data-seq2seq-translation.ipynb"
---
{% raw %}
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %} {% raw %}
ds = load_dataset('wmt16', 'de-en', split='train[:1%]')
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/7b2c4443a7d34c2e13df267eaa8cab4c62dd82f6b62b0d9ecc2e3a673ce17308)
{% endraw %}

Translation tokenization, batch transform, and DataBlock methods

Translation tasks attempt to convert text in one language into another.

{% raw %}
path = Path('./')
wmt_df = pd.DataFrame(ds['translation'], columns=['de', 'en']); len(wmt_df)
45489
{% endraw %} {% raw %}
wmt_df.head(2)
| | de | en |
|---|---|---|
| 0 | Wiederaufnahme der Sitzungsperiode | Resumption of the session |
| 1 | Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten. | I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period. |
{% endraw %} {% raw %}
pretrained_model_name = "facebook/bart-large-cnn"
task = HF_TASKS_AUTO.Seq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, task=task)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('bart',
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
{% endraw %} {% raw %}
blocks = (HF_Seq2SeqBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('de'), get_y=ColReader('en'), splitter=RandomSplitter())
{% endraw %}

Two lines! Notice that we pass in noop for our targets (e.g., our translations) because the batch transform takes care of both our inputs and our targets.
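
To make that division of labor concrete, here is a minimal sketch in plain Python of why a no-op target block can work: a single batch-level step receives (source, target) pairs and encodes both sides. The `toy_encode` function and the `ToyBeforeBatchTransform` class are hypothetical stand-ins for illustration, not the actual blurr or Hugging Face implementation.

```python
# Toy illustration only: a whitespace "tokenizer" stands in for a real
# Hugging Face tokenizer, and ToyBeforeBatchTransform is a hypothetical name.

def noop(x):
    """Return the input unchanged, as the target block does."""
    return x


def toy_encode(text, vocab):
    """Map each whitespace-separated token to an integer id, growing vocab as needed."""
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]


class ToyBeforeBatchTransform:
    """Encode both the source and the target text of every sample in a batch."""

    def __init__(self):
        self.vocab = {}

    def __call__(self, samples):
        encoded = []
        for src, tgt in samples:
            inputs = {"input_ids": toy_encode(src, self.vocab)}
            labels = toy_encode(tgt, self.vocab)
            encoded.append((inputs, labels))
        return encoded


# The item-level pipeline leaves the targets as raw strings ...
samples = [("Guten Morgen", noop("Good morning")), ("Danke", noop("Thank you"))]

# ... and the batch-level step turns both sides into ids.
batch = ToyBeforeBatchTransform()(samples)
print(batch[0])
```

Because the targets reach the batch step as untouched strings, one component can apply the same tokenizer settings to both sides of each pair.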

{% raw %}
dls = dblock.dataloaders(wmt_df, bs=4)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([4, 483]), torch.Size([4, 132]))
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
| | text | target |
|---|---|---|
| 0 | Was nun die Ergebnisse der Verhandlungen über die Anwendung der Artikel 3, 4, 5, 6 und 12 des Interimsabkommens bezüglich Warenhandel, öffentlicher Aufträge, Wettbewerb, Konsultationsmechanismen bei Fragen des geistigen Eigentums und Beilegung von S | Although for certain sectors there may be flaws - I am thinking specifically of the textiles sector, where the rules of origin issue causes great concern - the effects will be beneficial for both the European Union and Mexico. For the European Union |
| 1 | Dienlich sind in diesem Zusammenhang die heute und hier diskutierte Mitteilung der Kommission und natürlich der Bericht von Herrn Viceconte - zu dem ich ihn beglückwünsche -, der 34 Schlußfolgerungen enthält, von denen ich die 15 Anträge an die Komm | This report contains 34 conclusions, including 15 calls to the Commission to adopt specific measures and actions without forgetting the environmental dimension and cultural protection, particularly in outermost and island regions. It also contains f |
{% endraw %}
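
The input and target tensors in the batch above have different widths (483 and 132). That pattern is what one would expect if each side is padded only to the longest sequence on that side of the batch; this is an assumption about the default padding behavior, not something the output above confirms. The sketch below illustrates the idea with plain lists. The `pad_to_longest` helper, the `shape` helper, and the pad id of 0 are hypothetical.

```python
def pad_to_longest(seqs, pad_id=0):
    """Right-pad every sequence to the length of the longest one in the batch."""
    width = max(len(s) for s in seqs)
    return [s + [pad_id] * (width - len(s)) for s in seqs]


def shape(batch):
    """Return (rows, columns) for a rectangular list of lists."""
    return (len(batch), len(batch[0]))


# Made-up token ids: inputs and targets are padded independently.
input_ids = pad_to_longest([[5, 6, 7, 8, 9], [5, 6], [5, 6, 7], [5]])
labels = pad_to_longest([[3, 4], [3], [3, 4, 2], [3, 4]])

print(shape(input_ids), shape(labels))  # (4, 5) (4, 3)
```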

Tests

The purpose of the following tests is to ensure, as much as possible, that the core DataBlock code above works for the pretrained translation models below. These tests are excluded from the CI workflow because of how long they would take to run and how much data they would need to download.

Note: Feel free to modify the code below to test whatever pretrained translation models you are working with. If any of your pretrained translation models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).
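
One detail to keep in mind when adapting the tests to other models: the test loop prepends `translate German to English: ` to each input when the architecture is `t5`. A hedged sketch of generalizing that step follows, assuming a hypothetical `TASK_PREFIXES` mapping and `make_prefixer` helper (neither is part of blurr).

```python
# Hypothetical mapping from architecture name to the prompt prefix it expects.
# Only the 't5' entry mirrors the test code; extend it for your own models.
TASK_PREFIXES = {"t5": "translate German to English: "}


def make_prefixer(arch, prefixes=TASK_PREFIXES):
    """Return a function that prepends the prefix registered for arch, if any."""
    prefix = prefixes.get(arch, "")

    def add_prefix(text):
        return f"{prefix}{text}"

    return add_prefix


print(make_prefixer("t5")("Guten Morgen"))    # translate German to English: Guten Morgen
print(make_prefixer("bart")("Guten Morgen"))  # Guten Morgen
```

Binding the prefix when the helper is created means each pipeline keeps the prefix of the model it was built for, even after the loop variable moves on to the next model.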

{% raw %}
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')
[transformers.models.bart.modeling_bart.BartForConditionalGeneration,
 transformers.models.blenderbot.modeling_blenderbot.BlenderbotForConditionalGeneration,
 transformers.models.fsmt.modeling_fsmt.FSMTForConditionalGeneration,
 transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration,
 transformers.models.mt5.modeling_mt5.MT5ForConditionalGeneration,
 transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration,
 transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
 transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration,
 transformers.models.blenderbot.modeling_tf_blenderbot.TFBlenderbotForConditionalGeneration,
 transformers.models.mbart.modeling_tf_mbart.TFMBartForConditionalGeneration,
 transformers.models.mt5.modeling_tf_mt5.TFMT5ForConditionalGeneration,
 transformers.models.pegasus.modeling_tf_pegasus.TFPegasusForConditionalGeneration,
 transformers.models.t5.modeling_tf_t5.TFT5ForConditionalGeneration,
 transformers.models.xlm_prophetnet.modeling_xlm_prophetnet.XLMProphetNetForConditionalGeneration]
{% endraw %} {% raw %}
pretrained_model_names = [
    'facebook/bart-base',
    'facebook/wmt19-de-en',                      # FSMT
    'Helsinki-NLP/opus-mt-de-en',                # MarianMT
    'sshleifer/tiny-mbart',
    'google/mt5-small',
    't5-small'
]
{% endraw %} {% raw %}
path = Path('./')
wmt_df = pd.DataFrame(ds['translation'], columns=['de', 'en'])
{% endraw %} {% raw %}
#hide_output
task = HF_TASKS_AUTO.Seq2SeqLM
bsz = 2
seq_sz = 128
trg_seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    print(f'=== {model_name} ===\n')

    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, task=task)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')

    # pad/truncate inputs and targets to fixed lengths so the shape checks below are deterministic
    before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                      padding='max_length',
                                                      max_length=seq_sz,
                                                      max_target_length=trg_seq_sz)

    # only the 't5' architecture gets a task prefix prepended to its inputs
    def add_t5_prefix(inp): return f'translate German to English: {inp}' if (hf_arch == 't5') else inp

    blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks,
                       get_x=Pipeline([ColReader('de'), add_t5_prefix]),
                       get_y=ColReader('en'),
                       splitter=RandomSplitter())

    try:
        print('*** TESTING DataLoaders ***\n')
        # build the DataLoaders inside the try block so that failures here are recorded too
        dls = dblock.dataloaders(wmt_df, bs=bsz)
        b = dls.one_batch()

        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if hasattr(hf_tokenizer, 'add_prefix_space'):
            test_eq(hf_tokenizer.add_prefix_space, True)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)

    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))
{% endraw %} {% raw %}
| | arch | tokenizer | model_name | result | error |
|---|---|---|---|---|---|
| 0 | bart | BartTokenizerFast | facebook/bart-base | PASSED | |
| 1 | fsmt | FSMTTokenizer | facebook/wmt19-de-en | PASSED | |
| 2 | marian | MarianTokenizer | Helsinki-NLP/opus-mt-de-en | PASSED | |
| 3 | mbart | MBartTokenizerFast | sshleifer/tiny-mbart | PASSED | |
| 4 | mt5 | T5TokenizerFast | google/mt5-small | PASSED | |
| 5 | t5 | T5TokenizerFast | t5-small | PASSED | |
{% endraw %}
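
The pattern used by the test loop (run the checks for each model, record a PASSED or FAILED row, then tabulate the rows) can be sketched without downloading any models. In this hedged example `pad_or_truncate`, `run_checks`, and the made-up batches are hypothetical stand-ins; the fixed-width check mirrors the shape assertions that `padding='max_length'` makes possible.

```python
def pad_or_truncate(seq, max_length, pad_id=0):
    """Force a sequence to exactly max_length items."""
    return (seq + [pad_id] * max_length)[:max_length]


def run_checks(name, batch, bsz, seq_sz):
    """Return a (name, result, error) row, catching any failed assertion."""
    try:
        assert len(batch) == bsz, f"expected {bsz} rows, got {len(batch)}"
        for row in batch:
            assert len(row) == seq_sz, f"expected width {seq_sz}, got {len(row)}"
        return (name, "PASSED", "")
    except AssertionError as err:
        return (name, "FAILED", str(err))


# Made-up batches: the first is padded to a fixed width, the second is not.
good = [pad_or_truncate(s, 4) for s in ([1, 2], [1, 2, 3, 4, 5])]
bad = [[1, 2], [1, 2, 3]]

results = [run_checks("fixed-width", good, 2, 4), run_checks("ragged", bad, 2, 4)]
for row in results:
    print(row)
```

Recording the error message instead of raising lets one slow run report on every model, rather than stopping at the first failure.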

Cleanup