---
title: text.data.seq2seq.summarization
keywords: fastai
sidebar: home_sidebar
summary: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5. Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capture the meaning of a larger document in 1-3 sentences)."
description: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5. Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capture the meaning of a larger document in 1-3 sentences)."
nb_path: "nbs/21_text-data-seq2seq-summarization.ipynb"
---
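The snippets on this page use `datasets`, `pandas`, fastai, and blurr objects without repeating their imports. As a minimal sketch (the exact blurr module paths are an assumption here and may differ between library versions), the following imports cover everything referenced below:

```python
import pandas as pd
import torch
from pathlib import Path

from datasets import load_dataset                 # Hugging Face datasets loader
from transformers import AutoModelForSeq2SeqLM    # seq2seq model class used throughout
from fastai.text.all import *                     # DataBlock, ColReader, RandomSplitter, noop, ...
from fastcore.test import test_eq                 # used in the model smoke tests at the bottom of the page

# blurr's text data/utility namespaces; the `all` modules re-export the pieces used
# below (Seq2SeqTextBlock, Seq2SeqBatchTokenizeTransform, SummarizationPreprocessor,
# and the NLP helper behind get_hf_objects / get_models). Paths assumed, not verified.
from blurr.text.data.all import *
from blurr.text.utils import *
```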
raw_datasets = load_dataset("cnn_dailymail", "3.0.0", split=["train", "validation"])
raw_datasets
print(raw_datasets[0][0].keys())
print(raw_datasets[0][0]["highlights"])
print(raw_datasets[1][0].keys())
print(raw_datasets[1][0]["highlights"])
raw_train_ds = raw_datasets[0].shuffle(seed=42).select(range(1000))
raw_valid_ds = raw_datasets[1].shuffle(seed=42).select(range(200))
len(raw_train_ds) + len(raw_valid_ds)
raw_train_df = pd.DataFrame(raw_train_ds)
raw_valid_df = pd.DataFrame(raw_valid_ds)
raw_train_df.head(2)
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
model_cls = AutoModelForSeq2SeqLM
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
This class can be used for preprocessing summarization tasks. It adds `proc_{your_text_attr}` and `proc_{target_text_attr}` attributes containing your modified input and target texts as a result of tokenization (e.g., if you specify a `max_length`, the `proc_{your_text_attr}` attribute may contain truncated text).
preprocessor = SummarizationPreprocessor(
hf_tokenizer,
id_attr="id",
text_attr="article",
target_text_attr="highlights",
max_input_tok_length=128,
max_target_tok_length=30,
min_summary_char_length=10,
)
proc_df = preprocessor.process_df(raw_train_df, raw_valid_df)
proc_df.columns, len(proc_df)
proc_df.head(2)
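Not part of the original notebook, but as a quick hedged sanity check you can re-tokenize the processed article text to confirm it respects the `max_input_tok_length` budget set above (re-tokenizing decoded text can shift counts slightly, so treat the bound as approximate):

```python
# Hypothetical check: token counts of the processed articles vs. the 128-token limit.
proc_tok_lens = proc_df["proc_article"].apply(lambda txt: len(hf_tokenizer(txt)["input_ids"]))
print(proc_tok_lens.describe())  # max should sit at (or very near) max_input_tok_length
```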
pretrained_model_name = "facebook/bart-large-cnn"
model_cls = AutoModelForSeq2SeqLM
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("article"), get_y=ColReader("highlights"), splitter=RandomSplitter())
dls = dblock.dataloaders(raw_train_df, bs=4)
b = dls.one_batch()
len(b), b[0]["input_ids"].shape, b[0]["labels"].shape, b[1].shape
b[0]["labels"][0], b[1][0]
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000, target_trunc_at=250)
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
model_cls = AutoModelForSeq2SeqLM
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=model_cls)
preprocessor = SummarizationPreprocessor(
hf_tokenizer,
id_attr="id",
text_attr="article",
target_text_attr="highlights",
max_input_tok_length=128,
max_target_tok_length=30,
min_summary_char_length=10,
)
proc_df = preprocessor.process_df(raw_train_df, raw_valid_df)
blocks = (Seq2SeqTextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader("proc_article"), get_y=ColReader("proc_highlights"), splitter=ColSplitter())
dls = dblock.dataloaders(proc_df, bs=4)
dls.show_batch(dataloaders=dls, max_n=2, trunc_at=500)
The purpose of the following tests is to ensure, as much as possible, that the core DataBlock code above works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.
Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).
[model_type for model_type in NLP.get_models(task="ConditionalGeneration") if (not model_type.startswith("TF"))]
pretrained_model_names = [
"facebook/bart-base",
"facebook/blenderbot_small-90M",
"allenai/led-base-16384",
"google/mt5-small",
"google/pegasus-cnn_dailymail",
"t5-small",
"microsoft/prophetnet-large-uncased",
"microsoft/xprophetnet-large-wiki100-cased", # XLMProphetNet
]
path = Path("./")
cnndm_df = pd.read_csv(path / "cnndm_sample.csv")
# hide_output
model_cls = AutoModelForSeq2SeqLM
bsz = 2
seq_sz = 256
trg_seq_sz = 40
test_results = []
for model_name in pretrained_model_names:
    error = None
    print(f"=== {model_name} ===\n")

    hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(model_name, model_cls=model_cls)
    print(f"architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n")

    # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
        hf_model.resize_token_embeddings(len(hf_tokenizer))

    batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
        hf_arch, hf_config, hf_tokenizer, hf_model, padding="max_length", max_length=seq_sz, max_target_length=trg_seq_sz
    )

    def add_t5_prefix(inp):
        return f"summarize: {inp}" if (hf_arch == "t5") else inp

    blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)
    dblock = DataBlock(
        blocks=blocks, get_x=Pipeline([ColReader("article"), add_t5_prefix]), get_y=ColReader("highlights"), splitter=RandomSplitter()
    )

    dls = dblock.dataloaders(cnndm_df, bs=bsz)
    b = dls.one_batch()

    try:
        print("*** TESTING DataLoaders ***\n")
        test_eq(len(b), 2)
        test_eq(len(b[0]["input_ids"]), bsz)
        test_eq(b[0]["input_ids"].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if hasattr(hf_tokenizer, "add_prefix_space") and hf_arch not in ["led"]:
            test_eq(hf_tokenizer.add_prefix_space, True)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "PASSED", ""))
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, "FAILED", err))
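The loop above only accumulates results; a small follow-up (not in the original cells) makes the outcome for each pretrained model easier to scan by turning `test_results` into a DataFrame:

```python
# Summarize the smoke-test outcomes per pretrained model.
test_results_df = pd.DataFrame(test_results, columns=["arch", "tokenizer", "model_name", "result", "error"])
test_results_df
```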