---
title: text.data.seq2seq.core
keywords: fastai
sidebar: home_sidebar
summary: "This module contains the core seq2seq (e.g., language modeling, summarization, translation) bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data in a way modelable by Hugging Face transformer implementations."
description: "This module contains the core seq2seq (e.g., language modeling, summarization, translation) bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data in a way modelable by Hugging Face transformer implementations."
nb_path: "nbs/20_text-data-seq2seq-core.ipynb"
---
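from transformers import BartForConditionalGeneration
# `NLP` comes from the blurr library; its import path depends on the installed blurr version

pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)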
A `Seq2SeqTextInput` object is returned from the `decodes` method of `Seq2SeqBatchTokenizeTransform` as a means to customize `@typedispatch`ed functions like `DataLoaders.show_batch` and `Learner.show_results`. Its value will be your `input_ids`.
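The idea behind dispatching on a dedicated input type can be sketched with the standard library's `functools.singledispatch`, used here only as a stand-in for fastai's `@typedispatch`. The `Seq2SeqTextInput` below is a hypothetical plain `str` subclass, not blurr's class; the point is that a value which is "just text" gets different display logic purely because of its type.

```python
from functools import singledispatch


class Seq2SeqTextInput(str):
    """Hypothetical stand-in: a str subclass whose only job is to carry a type."""


@singledispatch
def show_item(x) -> str:
    # Fallback used for any type without a more specific registration
    return f"generic: {x}"


@show_item.register
def _(x: Seq2SeqTextInput) -> str:
    # Picked because type(x) is Seq2SeqTextInput, even though x is still a str
    return f"seq2seq input: {x}"


decoded = Seq2SeqTextInput("summarize this article")
print(show_item("plain text"))  # generic: plain text
print(show_item(decoded))       # seq2seq input: summarize this article
```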
We create a subclass of `BatchTokenizeTransform` for summarization tasks to add `decoder_input_ids` and `labels` (if we want Hugging Face to calculate the loss for us) to our inputs during training. See here and here for more information on these additional inputs used in summarization, translation, and conversational training tasks. How they should look for a particular architecture can be found in that model's `forward` function docs (see here for BART, for example).
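A minimal sketch, in plain Python with made-up token ids, of what those extra inputs can look like for one padded batch. `build_seq2seq_inputs` is a hypothetical helper, not blurr's API. It masks label padding with `-100`, the value PyTorch's `CrossEntropyLoss` ignores by default and the convention Hugging Face models follow when computing the loss. It leaves out `decoder_input_ids` because BART-style models build them from `labels` when they are not passed explicitly.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label positions with this value are skipped by the loss


def build_seq2seq_inputs(
    src_ids: List[List[int]], tgt_ids: List[List[int]], pad_id: int
) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: pad a batch and add the label key a seq2seq forward() can take."""
    src_len = max(len(s) for s in src_ids)
    tgt_len = max(len(t) for t in tgt_ids)
    # Right-pad every source sequence to the longest one in the batch
    input_ids = [s + [pad_id] * (src_len - len(s)) for s in src_ids]
    # 1 for real tokens, 0 for padding
    attention_mask = [[1] * len(s) + [0] * (src_len - len(s)) for s in src_ids]
    # Target ids, with padding positions masked so they do not contribute to the loss
    labels = [t + [IGNORE_INDEX] * (tgt_len - len(t)) for t in tgt_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = build_seq2seq_inputs([[0, 5, 6, 2], [0, 7, 2]], [[0, 9, 2], [0, 2]], pad_id=1)
print(batch["labels"])  # [[0, 9, 2], [0, 2, -100]]
```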
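Note also that `labels` and `decoder_input_ids` are offset by one position: `decoder_input_ids` is the target sequence shifted one position to the right (starting with the decoder start token), because the task is to predict the next token based on the current (and all previous) `decoder_input_ids`. Put the other way around, `labels` is `decoder_input_ids` shifted one position to the left.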
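A minimal plain-Python sketch of that one-position offset for a single sequence. `shift_right` is a hypothetical helper that follows the same idea as Hugging Face's `shift_tokens_right` (prepend the decoder start token, drop the last token, replace `-100` with the pad id); the token ids and the start/pad ids are made up for illustration.

```python
from typing import List


def shift_right(labels: List[int], decoder_start_token_id: int, pad_token_id: int) -> List[int]:
    """Hypothetical helper: build decoder_input_ids from labels for one sequence."""
    # Prepend the decoder start token and drop the last label so lengths match
    shifted = [decoder_start_token_id] + labels[:-1]
    # -100 only marks ignored label positions; it is not a token id, so feed padding instead
    return [pad_token_id if tok == -100 else tok for tok in shifted]


labels = [0, 713, 16, 2]  # made-up target ids
decoder_input_ids = shift_right(labels, decoder_start_token_id=2, pad_token_id=1)
print(decoder_input_ids)  # [2, 0, 713, 16]

# At step t the decoder has seen decoder_input_ids[: t + 1] and must predict labels[t]
for t in range(len(labels) - 1):
    assert decoder_input_ids[t + 1] == labels[t]
```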
And lastly, we also update our targets to be just the `input_ids` of our target sequence so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).
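A minimal sketch of that split, assuming plain Python lists in place of tensors: the model still receives everything it needs as one dict, while fastai sees a single target object. `to_fastai_sample` is a hypothetical helper, not blurr's API.

```python
from typing import Dict, List, Tuple


def to_fastai_sample(
    inputs: Dict[str, List[int]], target_ids: List[int]
) -> Tuple[Dict[str, List[int]], List[int]]:
    """Hypothetical helper: return (model inputs, fastai target)."""
    # Everything Hugging Face needs, including labels for its built-in loss, stays in one dict
    x = {**inputs, "labels": list(target_ids)}
    # fastai gets a single, flat target: just the target sequence's token ids
    y = list(target_ids)
    return x, y


x, y = to_fastai_sample({"input_ids": [0, 5, 2], "attention_mask": [1, 1, 1]}, [0, 9, 2])
print(sorted(x))  # ['attention_mask', 'input_ids', 'labels']
print(y)          # [0, 9, 2]
```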
default_text_gen_kwargs(hf_config, hf_model)
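Judging by its name and arguments, `default_text_gen_kwargs` gathers default keyword arguments for text generation from the model's configuration. As a rough, hypothetical illustration of that kind of lookup (not blurr's implementation), the sketch below reads generation settings from a config-like object, skips unset ones, and lets explicit overrides win. `collect_gen_kwargs`, `GEN_KEYS`, and the config values are all assumptions.

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable

# Hypothetical subset of generation settings; the keys blurr actually uses may differ
GEN_KEYS = ("max_length", "min_length", "num_beams", "length_penalty", "no_repeat_ngram_size", "early_stopping")


def collect_gen_kwargs(config: Any, keys: Iterable[str] = GEN_KEYS, **overrides: Any) -> Dict[str, Any]:
    """Hypothetical helper: config-derived generation defaults plus caller overrides."""
    # Keep only settings that exist on the config and are not None
    kwargs = {k: getattr(config, k) for k in keys if getattr(config, k, None) is not None}
    # Caller-supplied values take precedence over config defaults
    kwargs.update(overrides)
    return kwargs


cfg = SimpleNamespace(max_length=142, num_beams=4, length_penalty=None)  # made-up config values
print(collect_gen_kwargs(cfg, num_beams=2))  # {'max_length': 142, 'num_beams': 2}
```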