---
title: text.data.seq2seq.translation
keywords: fastai
sidebar: home_sidebar
summary: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for translation tasks"
description: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for translation tasks"
nb_path: "nbs/22_text-data-seq2seq-translation.ipynb"
---
raw_dataset = load_dataset("wmt16", "de-en", split="train[:1%]")
raw_dataset
print(raw_dataset[0].keys())
print(raw_dataset[0])
wmt_df = pd.DataFrame(raw_dataset["translation"], columns=["de", "en"])
print(len(wmt_df))
wmt_df.head(2)
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
This class can be used for preprocessing translation tasks. It adds `proc_{your_text_attr}` and `proc_{target_text_attr}` attributes containing your modified input and target texts as a result of tokenization (e.g., if you specify a `max_length`, the `proc_{your_text_attr}` may contain truncated text).
preprocessor = TranslationPreprocessor(
    hf_tokenizer, text_attr="de", target_text_attr="en", max_input_tok_length=128, max_target_tok_length=128
)
proc_df = preprocessor.process_df(wmt_df)
proc_df.columns, len(proc_df)
proc_df.head(2)
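As a quick sanity check on the preprocessing step, the short sketch below (an illustration only, assuming the `proc_df`, `hf_tokenizer`, and 128-token limits created above, and that the original `de` column is kept alongside `proc_de`) verifies that over-length source texts come back truncated:

```python
# a rough check: rows whose original German text tokenizes to more than 128 tokens
# should have a shorter (truncated) `proc_de` string than `de`
sample = proc_df.head(100)
long_rows = sample[sample["de"].apply(lambda t: len(hf_tokenizer.encode(t)) > 128)]
print(f"{len(long_rows)} of 100 sampled rows exceed 128 source tokens")
print((long_rows["proc_de"].str.len() < long_rows["de"].str.len()).all())
```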
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("de"), get_y=ColReader("en"), splitter=RandomSplitter())
dls = dblock.dataloaders(wmt_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[0]["labels"].shape, b[1].shape
b[0]["labels"][0], b[1][0]
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
preprocessor = TranslationPreprocessor(
    hf_tokenizer, text_attr="de", target_text_attr="en", max_input_tok_length=128, max_target_tok_length=128
)
proc_df = preprocessor.process_df(wmt_df)
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("proc_de"), get_y=ColReader("proc_en"), splitter=RandomSplitter())
dls = dblock.dataloaders(proc_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[0]["labels"].shape, b[1].shape
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=250, target_trunc_at=250)
The purpose of the following tests is to ensure, as much as possible, that the core DataBlock code above works for the pretrained translation models below. These tests are excluded from the CI workflow because of how long they take to run and how much data they would need to download.
Note: Feel free to modify the code below to test whatever pretrained translation models you are working with ... and if any of your pretrained translation models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself)
[model_type for model_type in NLP.get_models(task="ConditionalGeneration") if (not model_type.startswith("TF"))]
pretrained_model_names = [
    "facebook/bart-base",
    "facebook/wmt19-de-en",  # FSMT
    "Helsinki-NLP/opus-mt-de-en",  # MarianMT
    "sshleifer/tiny-mbart",
    "google/mt5-small",
    "t5-small",
]
path = Path("./")
wmt_df = pd.DataFrame(raw_dataset["translation"], columns=["de", "en"])
model_cls = AutoModelForSeq2SeqLM
bsz = 2
seq_sz = 128
trg_seq_sz = 128
test_results = []
for model_name in pretrained_model_names:
    error = None
    print(f"=== {model_name} ===\n")

    hf_tok_kwargs = {}
    if model_name == "sshleifer/tiny-mbart":
        hf_tok_kwargs["src_lang"], hf_tok_kwargs["tgt_lang"] = "de_DE", "en_XX"

    hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(model_name, model_cls=model_cls, tokenizer_kwargs=hf_tok_kwargs)
    print(f"architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n")

    # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc.), so we add one here
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
        hf_model.resize_token_embeddings(len(hf_tokenizer))

    batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
        hf_arch, hf_config, hf_tokenizer, hf_model, padding="max_length", max_length=seq_sz, max_target_length=trg_seq_sz
    )

    def add_t5_prefix(inp):
        return f"translate German to English: {inp}" if (hf_arch == "t5") else inp

    blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
    dblock = DataBlock(blocks=blocks, get_x=Pipeline([ColReader("de"), add_t5_prefix]), get_y=ColReader("en"), splitter=RandomSplitter())
    dls = dblock.dataloaders(wmt_df, bs=bsz)

    b = dls.one_batch()

    try:
        print("*** TESTING DataLoaders ***\n")
        test_eq(len(b), 2)
        test_eq(len(b[0]["input_ids"]), bsz)
        test_eq(b[0]["input_ids"].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if hasattr(hf_tokenizer, "add_prefix_space"):
            test_eq(hf_tokenizer.add_prefix_space, True)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "PASSED", ""))
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "FAILED", err))