---
title: text.modeling.seq2seq.translation
keywords: fastai
sidebar: home_sidebar
summary: "This module contains custom models, custom splitters, etc., for translation tasks."
description: "This module contains custom models, custom splitters, etc., for translation tasks."
nb_path: "nbs/22_text-modeling-seq2seq-translation.ipynb"
---
{% raw %}
What we're running with at the time this documentation was generated:
torch: 1.10.1+cu111
fastai: 2.5.6
transformers: 4.16.2
{% endraw %}

Mid-level API

Prepare the data

Example

The objective of translation is to generate a version of a given text in another language or style. For example, we may want to translate German into English, or modern English into old English.

{% raw %}
dataset = load_dataset("wmt16", "de-en", split="train")
dataset = dataset.shuffle(seed=32).select(range(1200))
wmt_df = pd.DataFrame(dataset["translation"], columns=["de", "en"])
len(wmt_df)
wmt_df.head(2)
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f)
Loading cached shuffled indices for dataset at /home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f/cache-8fc54b133c8c43b7.arrow
| | de | en |
|---|---|---|
| 0 | Tada se dio stanovništva preselio uz samu obalu - Pristan, gdje je i nastao Novi grad početkom XX vijeka. | In that period the majority of the population moved close to the seaside, where the first sea port was founded at the beginning of the 20th century, and later a new city was built. |
| 1 | "Dieses Video ist nicht verfügbar loger" bitch, daß das Böse, der sein Video auf YouTube hochgeladen hatte nearsyx? | "This video is no loger available" that evil bitch, who had uploaded his video on youtube nearsyx? |
{% endraw %} {% raw %}
pretrained_model_name = "Helsinki-NLP/opus-mt-de-en"
model_cls = AutoModelForSeq2SeqLM

hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('marian',
 transformers.models.marian.tokenization_marian.MarianTokenizer,
 transformers.models.marian.configuration_marian.MarianConfig,
 transformers.models.marian.modeling_marian.MarianMTModel)
{% endraw %} {% raw %}
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("de"), get_y=ColReader("en"), splitter=RandomSplitter())
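{% endraw %}

If you need more control over how the source and target texts are tokenized (for example, to cap the source/target sequence lengths or pass text generation kwargs), you can build the block from an explicit Seq2SeqBatchTokenizeTransform, as is done in the Tests section at the bottom of this page. A minimal sketch, reusing the hf_* objects created above:

{% raw %}
batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model, padding="max_length", max_length=128, max_target_length=128
)

blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("de"), get_y=ColReader("en"), splitter=RandomSplitter())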
{% endraw %} {% raw %}
dls = dblock.dataloaders(wmt_df, bs=2)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]["input_ids"].shape, b[1].shape
(2, torch.Size([2, 168]), torch.Size([2, 140]))
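{% endraw %}

To sanity-check what the batch contains, you can decode the tokenized tensors back into text. A minimal sketch (not part of the original notebook); it assumes the targets may be padded with -100, the usual "ignore" label id for Hugging Face seq2seq models:

{% raw %}
# decode the first source sequence back into text
print(hf_tokenizer.decode(b[0]["input_ids"][0], skip_special_tokens=True))

# drop any -100 padding from the first target before decoding it
tgt_ids = [tok for tok in b[1][0].tolist() if tok != -100]
print(hf_tokenizer.decode(tgt_ids, skip_special_tokens=True))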
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
text target
0 "In▁Erwägung▁nachstehender▁Gründe▁sollte das▁Europäische▁Parlament▁keinerlei▁Doppelmoral tolerieren. Indessen und um▁politischen Druck auf▁Journalisten▁auszuüben, die▁Korruptionsfälle aufdecken, die in▁Verbindung mit▁hochrangigen▁Beamten und▁regieren 'whereas the European Parliament shall not accept double standards; whereas, in order to put political pressure on journalists disclosing corruption cases linked to high-ranking officials and ruling party politicians, the Government administration in
1 Die Oberligamannschaft Empor▁Lauter wurde 1954 nach Rostock▁delegiert und bestritt am▁14. November 1954 vor 17 000▁Zuschauern▁gegen Chemie Karl-Marx-Stadt (0:0) das▁erste Oberligapunktspiel im▁Ostseestadion. Die▁Gründung des FC Hansa Rostock▁fand▁dan Bundesliga, was founded on December 28, 1965, when the football department of SC Empor was made independent of their parent sports club under a government sanctioned program that would groom young talent and provide the East German (DDR) national tea
{% endraw %}

Training

{% raw %}
seq2seq_metrics = {"bleu": {"returns": "bleu"}, "meteor": {"returns": "meteor"}, "sacrebleu": {"returns": "score"}}

model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(
    dls,
    model,
    opt_func=partial(Adam),
    loss_func=PreCalculatedCrossEntropyLoss(),  # CrossEntropyLossFlat()
    cbs=learn_cbs,
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
)

# learn = learn.to_native_fp16() #.to_fp16()
learn.freeze()
[nltk_data] Downloading package wordnet to /home/wgilliam/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/wgilliam/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/wgilliam/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
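{% endraw %}

Each key in seq2seq_metrics names a metric from the Hugging Face datasets library, and "returns" indicates which field of that metric's computed result should be reported. As a rough sketch of what the callback computes under the hood (an assumption about its internals, shown here with the datasets metrics API directly):

{% raw %}
from datasets import load_metric

# sacrebleu expects one list of references per prediction and reports its result under "score"
sacrebleu = load_metric("sacrebleu")
results = sacrebleu.compute(predictions=["I like to drink beer"], references=[["I like drinking beer"]])
print(results["score"])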
{% endraw %} {% raw %}
learn.summary()
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])

len(preds), preds["loss"].shape, preds["logits"].shape
(4, torch.Size([]), torch.Size([2, 140, 58101]))
{% endraw %} {% raw %}
len(b), len(b[0]), b[0]["input_ids"].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 168]), 2, torch.Size([2, 140]))
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
3
{% endraw %} {% raw %}
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=6.309573450380412e-08, steep=1.0964781722577754e-06, valley=0.0003981071640737355, slide=1.4454397387453355e-05)
{% endraw %} {% raw %}
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
| epoch | train_loss | valid_loss | bleu | meteor | sacrebleu | time |
|---|---|---|---|---|---|---|
| 0 | 2.066144 | 2.020720 | 0.294544 | 0.547619 | 28.638813 | 00:56 |
{% endraw %}

Showing results

And here we create a @typedispatched implementation of Learner.show_results.
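
fastcore's @typedispatch decorator selects an implementation based on the types of the arguments, which is how fastai routes show_results (and show_batch) to task-specific logic. A toy illustration of the mechanism (not blurr's actual implementation):

{% raw %}
from fastcore.dispatch import typedispatch

@typedispatch
def describe(x: int):
    return f"got an int: {x}"

@typedispatch
def describe(x: str):
    return f"got a str: {x}"

describe(3), describe("hello")  # ('got an int: 3', 'got a str: hello')
{% endraw %}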

{% raw %}
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=500)
text target prediction
0 Die Messe Frankfurt hat im▁Jahr 2006 in▁Kooperation mit Stylepark ein▁neues und▁bisher so▁noch nicht▁existierendes Format▁etabliert: The Design Annual.Bei The Design Annual▁handelt es sich um eine▁jährlich▁stattfindende Veranstaltung für High-End Design, die Frankfurt am Main▁auch 2008▁wieder zum Zentrum des▁internationalen Designgeschehens▁werden▁lässt. In der▁Frankfurter Festhalle▁präsentieren die▁beiden Partner die▁besten▁Hersteller des High-End Designsegments▁aus den▁Bereichen▁Möbel und▁Auße In 2006, Messe Frankfurt in cooperation with Stylepark established a new and quite unprecedented format for a trade fair, namely The Design Annual. [Messe Frankfurt established a new format in 2006 in cooperation with Stylepark: The Design Annual.The Design Annual is an annual high-end design event that will once again make Frankfurt the centre of international design in 2008. In the Frankfurt Festhalle, the two partners will present the best manufacturers of the high-end design segment in the areas of furniture and outdoor furniture, textiles and floor coverings, bathroom and kitchens, office furniture, lighting, accessories, cutlery, electrical installations and home entertainment., If there are eggs that are not legally produced after January 2012, these eggs must not be marketed, and if there is evidence of non-compliance with the provision, the Commission could, of course, take all the measures it is entitled to under the current legal framework - and initiate infringement procedures to ensure that EU legislation is properly implemented.]
{% endraw %}

Prediction

Here we add a Learner.blurr_translate method to bring the results in line with the format returned by Hugging Face's pipeline method.

{% raw %}
test_de = "Ich trinke gerne Bier"
{% endraw %} {% raw %}
outputs = learn.blurr_generate(test_de, key="translation_texts", num_return_sequences=3)
outputs
[{'translation_texts': ['I like to drink beer',
   'I like to drink beer.',
   'I like to drink']}]
{% endraw %} {% raw %}

Learner.blurr_translate[source]

Learner.blurr_translate(inp, **kwargs)

{% endraw %} {% raw %}
learn.blurr_translate(test_de, num_return_sequences=3)
[{'translation_texts': ['I like to drink beer',
   'I like to drink beer.',
   'I like to drink']}]
{% endraw %}
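
For comparison, the equivalent Hugging Face pipeline call looks roughly like this (a sketch using the base pretrained checkpoint rather than our fine-tuned Learner; note the pipeline returns a 'translation_text' key rather than 'translation_texts'):

{% raw %}
from transformers import pipeline

translator = pipeline("translation", model="Helsinki-NLP/opus-mt-de-en")
translator("Ich trinke gerne Bier")
# e.g. [{'translation_text': 'I like to drink beer'}]
{% endraw %}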

Inference

Using fast.ai Learner.export and load_learner

{% raw %}
export_fname = "translation_export"
{% endraw %} {% raw %}
learn.metrics = None
learn.export(fname=f"{export_fname}.pkl")
{% endraw %} {% raw %}
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_translate(test_de)
[{'translation_texts': 'I like to drink beer'}]
{% endraw %}
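
If you want to use the fine-tuned model outside of fastai altogether, you can also save the underlying Hugging Face model and tokenizer with their native save_pretrained methods. A minimal sketch, assuming the wrapped model is reachable as learn.model.hf_model (which is how BaseModelWrapper stores it) and using a hypothetical output directory:

{% raw %}
save_dir = "opus-mt-de-en-finetuned"  # hypothetical output directory

learn.model.hf_model.save_pretrained(save_dir)
hf_tokenizer.save_pretrained(save_dir)
{% endraw %}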

High-level API

{% raw %}

class BlearnerForTranslation[source]

BlearnerForTranslation(dls:DataLoaders, hf_model:PreTrainedModel, base_model_cb:BaseModelCallback=BaseModelCallback, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Blearner

Group together a model, some dls and a loss_func to handle training

{% endraw %} {% raw %}
learn = BlearnerForTranslation.from_data(
    wmt_df,
    "Helsinki-NLP/opus-mt-de-en",
    src_lang_name="German",
    src_lang_attr="de",
    trg_lang_name="English",
    trg_lang_attr="en",
    dl_kwargs={"bs": 2},
)
{% endraw %} {% raw %}
metrics_cb = BlearnerForTranslation.get_metrics_cb()
learn.fit_one_cycle(1, lr_max=4e-5, cbs=[metrics_cb])
[nltk_data] Downloading package wordnet to /home/wgilliam/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/wgilliam/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/wgilliam/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
| epoch | train_loss | valid_loss | bleu | meteor | sacrebleu | time |
|---|---|---|---|---|---|---|
| 0 | 2.060747 | 2.086900 | 0.332818 | 0.555242 | 32.818199 | 00:52 |
{% endraw %} {% raw %}
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 Die▁Zunahme der▁Gewalttätigkeiten▁bei▁nationalen und▁internationalen Sportveranstaltungen▁resultiert nicht▁aus dem▁Fehlen von▁Informationsnetzen und▁ausreichenden▁Unterdrückungsmechanismen,▁sondern▁ist auf die▁Kommerzialisierung des Sports, die▁damit▁zusammenhängenden▁enormen▁wirtschaftlichen▁Interessen, die▁Förderung▁eines desorientierenden'sportlichen'▁Geistes des Fanatismus (Fußballrowdytums)▁sowie die Propagierung▁einer▁Psychologie der▁Gewalt▁insbesondere▁unter den▁Jugendlichen▁zurückzuführe The increase in violent clashes at national and international sporting events is not due to a lack of information networks or adequate suppression mechanisms; it is due to the commercialisation of sport, the huge financial interests tied up in it, th [The increase in violence at national and international sporting events is not due to the lack of information networks and adequate mechanisms of repression, but to the commercialisation of sport, the associated enormous economic interests, the promotion of a disorienting 'sporty' spirit of fanaticism (football rovdytums) and the promotion of a psychology of violence, especially among young people., Perhaps, however, you would like to ask someone - as Mr Corbett will always be in such cases - to look at our Rules of Procedure, because it is rather strange in this respect: if more than 50 amendments are tabled in plenary on a report, the President, after consultation with the chairman, can ask the committee responsible to hold a meeting to consider the amendments.]
{% endraw %} {% raw %}
test_de = "Ich trinke gerne Bier"
{% endraw %} {% raw %}
learn.blurr_translate(test_de)
[{'translation_texts': 'I like to drink beer'}]
{% endraw %} {% raw %}
export_fname = "translation_export"

learn.metrics = None
learn = learn.to_fp32()
learn.export(fname=f"{export_fname}.pkl")

inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_generate(test_de)
[{'generated_texts': 'I like to drink beer'}]
{% endraw %}

Tests

The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained translation models below. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained translation models you are working with, and if any of them fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

{% raw %}
try:
    del learn
    torch.cuda.empty_cache()
except:
    pass
{% endraw %} {% raw %}
[model_type for model_type in NLP.get_models(task="ConditionalGeneration") if (not model_type.startswith("TF"))]
['BartForConditionalGeneration',
 'BigBirdPegasusForConditionalGeneration',
 'BlenderbotForConditionalGeneration',
 'BlenderbotSmallForConditionalGeneration',
 'FSMTForConditionalGeneration',
 'LEDForConditionalGeneration',
 'M2M100ForConditionalGeneration',
 'MBartForConditionalGeneration',
 'MT5ForConditionalGeneration',
 'PegasusForConditionalGeneration',
 'ProphetNetForConditionalGeneration',
 'Speech2TextForConditionalGeneration',
 'T5ForConditionalGeneration',
 'XLMProphetNetForConditionalGeneration']
{% endraw %} {% raw %}
pretrained_model_names = [
    "facebook/bart-base",
    "facebook/wmt19-de-en",  # FSMT
    "Helsinki-NLP/opus-mt-de-en",  # MarianMT
    #'sshleifer/tiny-mbart',
    #'google/mt5-small',
    "t5-small",
]
{% endraw %} {% raw %}
dataset = load_dataset("wmt16", "de-en", split="train")
dataset = dataset.shuffle(seed=32).select(range(1200))
wmt_df = pd.DataFrame(dataset["translation"], columns=["de", "en"])
len(wmt_df)
Reusing dataset wmt16 (/home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f)
Loading cached shuffled indices for dataset at /home/wgilliam/.cache/huggingface/datasets/wmt16/de-en/1.0.0/af3c5d746b307726d0de73ebe7f10545361b9cb6f75c83a1734c000e48b6264f/cache-8fc54b133c8c43b7.arrow
1200
{% endraw %} {% raw %}
# hide_output
model_cls = AutoModelForSeq2SeqLM
bsz = 2
inp_seq_sz = 128
trg_seq_sz = 128

test_results = []
for model_name in pretrained_model_names:
    error = None

    print(f"=== {model_name} ===\n")

    hf_tok_kwargs = {}
    if model_name == "sshleifer/tiny-mbart":
        hf_tok_kwargs["src_lang"], hf_tok_kwargs["tgt_lang"] = "de_DE", "en_XX"

    hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(model_name, model_cls=model_cls, tokenizer_kwargs=hf_tok_kwargs)

    print(f"architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n")

    # 1. build your DataBlock
    text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task="translation")

    # T5 checkpoints are trained with a task prefix, so prepend one only for t5 architectures
    def add_t5_prefix(inp):
        return f"translate German to English: {inp}" if (hf_arch == "t5") else inp

    batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
        hf_arch,
        hf_config,
        hf_tokenizer,
        hf_model,
        padding="max_length",
        max_length=inp_seq_sz,
        max_target_length=trg_seq_sz,
        text_gen_kwargs=text_gen_kwargs,
    )

    blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
    dblock = DataBlock(blocks=blocks, get_x=Pipeline([ColReader("de"), add_t5_prefix]), get_y=ColReader("en"), splitter=RandomSplitter())

    dls = dblock.dataloaders(wmt_df, bs=bsz)
    b = dls.one_batch()

    # 2. build your Learner
    seq2seq_metrics = {}

    model = BaseModelWrapper(hf_model)
    fit_cbs = [ShortEpochCallback(0.05, short_valid=True), Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

    learn = Learner(
        dls,
        model,
        opt_func=ranger,
        loss_func=PreCalculatedCrossEntropyLoss(),
        cbs=[BaseModelCallback],
        splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
    ).to_fp16()

    learn.create_opt()
    learn.freeze()

    # 3. Run your tests
    try:
        print("*** TESTING DataLoaders ***\n")
        test_eq(len(b), 2)
        test_eq(len(b[0]["input_ids"]), bsz)
        test_eq(b[0]["input_ids"].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

        #         print('*** TESTING One pass through the model ***')
        #         preds = learn.model(b[0])
        #         test_eq(preds[1].shape[0], bsz)
        #         test_eq(preds[1].shape[2], hf_config.vocab_size)

        print("*** TESTING Training/Results ***")
        learn.fit_one_cycle(1, lr_max=1e-3, cbs=fit_cbs)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, "PASSED", ""))
        learn.show_results(learner=learn, max_n=2, input_trunc_at=500, target_trunc_at=250)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, "FAILED", err))
    finally:
        # cleanup
        del learn
        torch.cuda.empty_cache()
{% endraw %} {% raw %}
| | arch | tokenizer | model_name | result | error |
|---|---|---|---|---|---|
| 0 | bart | BartTokenizerFast | BartForConditionalGeneration | PASSED | |
| 1 | fsmt | FSMTTokenizer | FSMTForConditionalGeneration | PASSED | |
| 2 | marian | MarianTokenizer | MarianMTModel | PASSED | |
| 3 | t5 | T5TokenizerFast | T5ForConditionalGeneration | PASSED | |
{% endraw %}