---
title: text.modeling.seq2seq.summarization
keywords: fastai
sidebar: home_sidebar
summary: "This module contains custom models, custom splitters, etc. for summarization tasks."
description: "This module contains custom models, custom splitters, etc. for summarization tasks."
nb_path: "nbs/21_text-modeling-seq2seq-summarization.ipynb"
---
{% raw %}
What we're running with at the time this documentation was generated:
torch: 1.10.1+cu111
fastai: 2.5.6
transformers: 4.16.2
{% endraw %}

Mid-level API

Example

The objective of summarization is to generate a concise and accurate representation of a much larger body of text. For example, we may want to summarize an article in a single sentence.

{% raw %}
import pandas as pd
from datasets import load_dataset

dataset = load_dataset("cnn_dailymail", "3.0.0", split="train[:1000]")
cnndm_df = pd.DataFrame(dataset)
cnndm_df.head(2)
Reusing dataset cnn_dailymail (/home/wgilliam/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)
article highlights id
0 It's official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons. The proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It's a step that is set to turn an internat... Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"\nObama sends a letter to the heads of the House and Senate .\nObama to seek congressional approval on military action against Syria .\nAim is to determine whether CW were used, not by whom, says U.N. spokesman . 0001d1afc246a7964130f43ae940af6bc6c57f01
1 (CNN) -- Usain Bolt rounded off the world championships Sunday by claiming his third gold in Moscow as he anchored Jamaica to victory in the men's 4x100m relay. The fastest man in the world charged clear of United States rival Justin Gatlin as the Jamaican quartet of Nesta Carter, Kemar Bailey-Cole, Nickel Ashmeade and Bolt won in 37.36 seconds. The U.S finished second in 37.56 seconds with Canada taking the bronze after Britain were disqualified for a faulty handover. The 26-year-old Bolt has now collected eight gold medals at world championships, equaling the record held by American trio... Usain Bolt wins third gold of world championship .\nAnchors Jamaica to 4x100m relay victory .\nEighth gold at the championships for Bolt .\nJamaica double up in women's 4x100m relay . 0002095e55fcbd3a2f366d9bf92a95433dc305ef
{% endraw %} {% raw %}
from transformers import BartForConditionalGeneration

pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
hf_arch, hf_config, hf_tokenizer, hf_model = NLP.get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
{% endraw %} {% raw %}
text_gen_kwargs = {}
if hf_arch in ["bart", "t5"]:
    text_gen_kwargs = {**hf_config.task_specific_params["summarization"], **{"max_length": 30, "min_length": 10}}

# not all "summarization" parameters are for the model.generate method ... remove them here
generate_func_args = list(inspect.signature(hf_model.generate).parameters.keys())
for k in text_gen_kwargs.copy():
    if k not in generate_func_args:
        del text_gen_kwargs[k]

if hf_arch == "mbart":
    text_gen_kwargs["decoder_start_token_id"] = hf_tokenizer.get_vocab()["en_XX"]
{% endraw %} {% raw %}
tok_kwargs = {}
if hf_arch == "mbart":
    tok_kwargs["src_lang"], tok_kwargs["tgt_lang"] = "en_XX", "en_XX"
{% endraw %} {% raw %}
batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch,
    hf_config,
    hf_tokenizer,
    hf_model,
    max_length=256,
    max_target_length=130,
    tok_kwargs=tok_kwargs,
    text_gen_kwargs=text_gen_kwargs,
)

blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm), noop)

dblock = DataBlock(blocks=blocks, get_x=ColReader("article"), get_y=ColReader("highlights"), splitter=RandomSplitter())
{% endraw %} {% raw %}
dls = dblock.dataloaders(cnndm_df, bs=2)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]["input_ids"].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 75]))
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 <s> Washington (CNN) -- New details emerged of what the White House knew about the Internal Revenue Service targeting of conservative groups, with spokesman Jay Carney disclosing Chief of Staff Denis McDonough was among the top officials made aware of the matter late last month. In a new timeline provided by Carney to reporters on Monday, General Counsel Kathryn Ruemmler learned on April 24 of a pending Treasury inspector general's report on how IRS staff used criteria targeting conservative groups in assessing eligibility for tax-exempt status. According to Carney, Ruemmler told McDonough as well as other Treasury officials about the pending report. It was the first time the White House acknowledged that McDonough was aware of the report before it became public in early May. IRS: By the numbers. In addition, Carney made clear that the information Ruemmler received on April 24 included details of improper acts by IRS officials. At the same time, Carney emphasized that the information was preliminary and could have changed before the inspector general released his final report on May 14. Carney insisted no one -- including Ruemmler and McDonough -- told President Barack Obama anything about the inspector general's pending report before media reports about it began appearing on May 10. "We knew the subject of the investigation</s> A Senate committee holds a hearing Tuesday on the IRS targeting.\nWhite House discloses new details of what it knew about the IRS targeting report.\nWhite House spokesman says President Obama wasn't told of the pending report.\nNEW: First lawsuit filed over IRS targeting.
1 <s> (CNN) -- Last year we published a list of quintessential Americana experiences. You can find it here. They weren't necessarily the most patriotic, obvious or agreeable choices. NASCAR, bourbon, state fairs, Vegas, what's not to love? Apparently, plenty. There was scandal. There was outrage. There was name calling. Because we're gluttons for punishment -- or maybe just because we think we actually can please all of the people all of the time -- we're back for round two. Here's our Volume II of the most authentically American experiences this country has to offer. 1. Seaside boardwalks. Boardwalks have been enhancing beachside amusement since long before the Drifters' released their classic "Under the Boardwalk" in 1964. The first boardwalk was built in Atlantic City in 1870, when a railroad conductor was asked to find a way to prevent sand from filling shorefront hotel entryways. The innovation remains America's favorite wooden path, showing up everywhere from Monopoly, which was inspired by "America's Favorite Playground," to the HBO series "Boardwalk Empire," which takes place in Prohibition-era Atlantic City. Of course, you don't have to travel to Jersey to experience the joy of</s> Pueblos and powwows in New Mexico highlight Native American culture.\nBaseball, football and the Derby made the cut -- yes, we love sports.\nWhat would the country be without soul food?\nWe tried to resist including Burning Man, but failed.
{% endraw %}
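
The `text_gen_kwargs` filtering step in the example above drops any "summarization" parameter that `hf_model.generate` does not accept. It can be tried in isolation. The following is a minimal sketch, assuming a stand-in function in place of the real model: `fake_generate`, `filter_kwargs`, and the values in `task_params` are hypothetical and are not part of blurr or transformers.

```python
import inspect


def fake_generate(input_ids, max_length=20, min_length=0, num_beams=1, length_penalty=1.0):
    """Hypothetical stand-in for `hf_model.generate`; only its signature matters here."""
    return input_ids


def filter_kwargs(func, kwargs):
    """Keep only the entries of `kwargs` that `func` accepts as named parameters."""
    accepted = inspect.signature(func).parameters.keys()
    return {k: v for k, v in kwargs.items() if k in accepted}


# Illustrative values only, not the actual task_specific_params of any checkpoint
task_params = {"max_length": 142, "min_length": 56, "num_beams": 4, "prefix": "summarize: "}
overrides = {"max_length": 30, "min_length": 10}

# Later dictionaries win when merging, so the overrides replace the task defaults
text_gen_kwargs = filter_kwargs(fake_generate, {**task_params, **overrides})
print(text_gen_kwargs)
```

In this sketch the `prefix` entry is dropped because `fake_generate` has no parameter of that name, while the overridden `max_length` and `min_length` values are kept.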

Training

{% raw %}
seq2seq_metrics = {
    "rouge": {
        "compute_kwargs": {"rouge_types": ["rouge1", "rouge2", "rougeL", "rougeLsum"], "use_stemmer": True},
        "returns": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
    },
    "bertscore": {"compute_kwargs": {"lang": "en"}, "returns": ["precision", "recall", "f1"]},
}
{% endraw %} {% raw %}
model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(
    dls,
    model,
    opt_func=partial(Adam),
    loss_func=CrossEntropyLossFlat(),
    cbs=learn_cbs,
    splitter=partial(blurr_seq2seq_splitter, arch=hf_arch),
)

# learn = learn.to_native_fp16() #.to_fp16()
learn.freeze()
{% endraw %} {% raw %}
learn.summary()
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])

len(preds), preds["loss"].shape, preds["logits"].shape
(4, torch.Size([]), torch.Size([2, 67, 50264]))
{% endraw %} {% raw %}
len(b), len(b[0]), b[0]["input_ids"].shape, len(b[1]), b[1].shape
(2, 3, torch.Size([2, 256]), 2, torch.Size([2, 67]))
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
3
{% endraw %} {% raw %}
learn.lr_find(suggest_funcs=[minimum, steep, valley, slide])
SuggestedLRs(minimum=2.7542287716642023e-05, steep=1.9054607491852948e-06, valley=5.248074739938602e-05, slide=1.737800812406931e-05)
{% endraw %} {% raw %}
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch  train_loss  valid_loss  rouge1    rouge2    rougeL    rougeLsum  bertscore_precision  bertscore_recall  bertscore_f1  time
0      2.186599    2.114544    0.303138  0.128061  0.225260  0.282693   0.889253             0.862950          0.875796      01:16
{% endraw %}
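
To give a feel for what the `rouge1` and `rouge2` columns measure, here is a simplified sketch of ROUGE-N as an F1 score over overlapping n-grams. It assumes lowercased, whitespace-split tokens and does no stemming, so it will not reproduce the numbers reported by the metrics callback (which is configured with `use_stemmer=True`). The function name `rouge_n_f1` and the example sentences are made up for illustration.

```python
from collections import Counter


def rouge_n_f1(prediction, reference, n=1):
    """Simplified ROUGE-N F1: n-gram overlap on lowercased, whitespace-split tokens."""

    def ngrams(text):
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    pred, ref = ngrams(prediction), ngrams(reference)
    overlap = sum((pred & ref).values())  # clipped count of shared n-grams
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


reference = "usain bolt wins third gold of world championship"
prediction = "usain bolt wins gold in moscow"

score1 = rouge_n_f1(prediction, reference, n=1)  # 4 shared unigrams
score2 = rouge_n_f1(prediction, reference, n=2)  # 2 shared bigrams
print(round(score1, 4), round(score2, 4))
```

Higher values mean the generated summary shares more n-grams with the reference, which is why `rouge2` is typically lower than `rouge1` for the same prediction.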

Showing results

And here we create a @typedispatched implementation of Learner.show_results.

{% raw %}
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN) -- When Ji Yeqing awakened, she was already in the recovery room. Chinese authorities had dragged her out of her home and down four flights of stairs, she said, restraining and beating her husband as he tried to come to her aid. They whisked her into a clinic, held her down on a bed and forced her to undergo an abortion. Her offense? Becoming pregnant with a second child, in violation of China's one-child policy. "After the abortion, I felt empty, as if something was scooped out of me," J China's one-child policy results in forced abortions and sterilizations, activists say.\nWomen tell of emotional and physical consequences from the procedures.\nActivist Chen Guangcheng works to advocate for victims of such practices. [ Ji Yeqing says she was forced to have an abortion after becoming pregnant with a second child .\nShe says she felt, Gay baby boomers are finding more acceptance in mainstream society .\nThey are pushing to make a better world for the LGBT community .\n]
{% endraw %}

Prediction

Here we add a Learner.blurr_summarize method that brings the results in line with the format returned by Hugging Face's pipeline method.

{% raw %}
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan 
contributed to this report.
"""
{% endraw %} {% raw %}
outputs = learn.blurr_generate(test_article, key="summary_texts", num_return_sequences=3)
outputs
[{'summary_texts': [' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level but could not get in .\nThe second group robbed the cashier of the money that was not secured .\nThere were about 600 people in the casino at the time of the robbery .',
   ' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level but could not get in .\nThe second group robbed the cashier of the money that was not secured .\nThere were no serious injuries, although one guest was kicked in the head .',
   ' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level but could not get in .\nThe second group robbed the cashier of the money that was not secured .\nThere were about 600 people in the casino at the time .']}]
{% endraw %} {% raw %}

Learner.blurr_summarize

Learner.blurr_summarize(inp, **kwargs)

{% endraw %} {% raw %}
learn.blurr_summarize(test_article, num_return_sequences=3)
[{'summary_texts': [' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level but could not get in .\nThe second group robbed the cashier of the money that was not secured .\nThere were about 600 people in the casino at the time of the robbery .',
   ' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level but could not get in .\nThe second group robbed the cashier of the money that was not secured .\nThere were no serious injuries, although one guest was kicked in the head .',
   ' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level but could not get in .\nThe second group robbed the cashier of the money that was not secured .\nThere were about 600 people in the casino at the time .']}]
{% endraw %}
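
In the outputs shown on this page, `summary_texts` holds a list of strings when several sequences are requested and a single string when only one is returned. The sketch below shows one way to flatten either shape into a plain list of strings. The helper name `flatten_summaries` is hypothetical (it is not part of blurr), and the sample data is abbreviated for illustration.

```python
def flatten_summaries(outputs, key="summary_texts"):
    """Return a flat list of summary strings from blurr_summarize-style output."""
    summaries = []
    for item in outputs:
        value = item[key]
        # A single returned sequence appears as a str, several as a list of str
        if isinstance(value, str):
            summaries.append(value.strip())
        else:
            summaries.extend(text.strip() for text in value)
    return summaries


# Abbreviated sample structures mirroring the two output shapes shown on this page
single = [{"summary_texts": " 10 men raid Swiss casino in early hours of Sunday morning ."}]
multiple = [{"summary_texts": [" First candidate summary .", " Second candidate summary ."]}]

print(flatten_summaries(single))
print(flatten_summaries(multiple))
```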

Inference

Using fast.ai Learner.export and load_learner

{% raw %}
export_fname = "summarize_export"
{% endraw %} {% raw %}
learn.metrics = None
learn.export(fname=f"{export_fname}.pkl")
{% endraw %} {% raw %}
inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_summarize(test_article)
[{'summary_texts': ' 10 men raid Swiss casino in early hours of Sunday morning .\nOne group tried to break into the vault on the lower level .\n'}]
{% endraw %}

High-level API

{% raw %}

class BlearnerForSummarization

BlearnerForSummarization(dls:DataLoaders, hf_model:PreTrainedModel, base_model_cb:BaseModelCallback=BaseModelCallback, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Blearner

Group together a model, some dls and a loss_func to handle training

{% endraw %}

Example

{% raw %}
learn = BlearnerForSummarization.from_data(
    cnndm_df,
    "sshleifer/distilbart-cnn-6-6",
    text_attr="article",
    summary_attr="highlights",
    max_length=256,
    max_target_length=130,
    dblock_splitter=RandomSplitter(),
    dl_kwargs={"bs": 2},
).to_fp16()
{% endraw %} {% raw %}
learn.fit_one_cycle(1, lr_max=4e-5, cbs=[BlearnerForSummarization.get_metrics_cb()])
epoch  train_loss  valid_loss  rouge1    rouge2    rougeL    rougeLsum  bertscore_precision  bertscore_recall  bertscore_f1  time
0      2.215153    2.136206    0.358332  0.138019  0.237203  0.332453   0.875117             0.883866          0.879403      02:44
{% endraw %} {% raw %}
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN) -- The generation of gays and lesbians that literally created the modern LGBT movement -- from the heroes of the 1969 Stonewall riots to their slightly younger friends -- is at, or nearing, retirement age. That used to mean the beginning of an extremely difficult time in an LGBT person's life. But as gay baby boomers find more acceptance in mainstream society and continue to do what they've always done -- push to make a better world for the LGBT community -- their retirement options are s LGBT baby boomers changed the visibility of the gay community.\nAs they approach retirement, they face different obstacles than their straight counterparts.\nWithout marriage equality, same-sex couples may face financial hardships.\nAdvocates say the s [ Gay baby boomers are pushing to make a better world for LGBT community .\nBob Witeck, CEO and co-founder of a communications firm, plans to keep working .\nWiteck: "The notion of retirement has never been a part of my vocabulary", Mitt Romney's gaffes reinforce a sitcom-like caricature of the candidate, expert says .\nHe says he was trying to point out differences between his and Obama's campaigns .\nRomney says his comments were "off the cuff" and "not elegantly stated"\nRomney has said he could have made them "more clearly"]
{% endraw %} {% raw %}
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan 
contributed to this report.
"""
{% endraw %} {% raw %}
learn.predict(test_article, num_return_sequences=3)
[{'summary_texts': [' 10 men armed with pistols and small machine guns raided a Swiss casino .\nThey robbed several hundred thousand Swiss francs in the early hours of Sunday morning .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino .\nThere were no serious injuries, although one guest on the casino floor was kicked in the head by one of the robbers when he moved .',
   ' 10 men armed with pistols and small machine guns raid a Swiss casino .\nThey robbed several hundred thousand Swiss francs in the early hours of Sunday morning .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .\nThere were no serious injuries, although one guest on the casino floor was kicked in the head by one of the robbers .',
   ' 10 men armed with pistols and small machine guns raided a Swiss casino .\nThey robbed several hundred thousand Swiss francs in the early hours of Sunday morning .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino .\nThere were no serious injuries, although one guest on the casino floor was kicked in the head by one of the robbers .']}]
{% endraw %} {% raw %}
export_fname = "summarize_export"

learn.metrics = None
learn = learn.to_fp32()
learn.export(fname=f"{export_fname}.pkl")

inf_learn = load_learner(fname=f"{export_fname}.pkl")
inf_learn.blurr_summarize(test_article)
[{'summary_texts': ' 10 men armed with pistols and small machine guns raided a Swiss casino .\nThey robbed several hundred thousand Swiss francs in the early hours of Sunday morning .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino .\nThere were no serious injuries, although one guest on the casino floor was kicked in the head by one of the robbers when he moved .'}]
{% endraw %}

Tests

The purpose of the following tests is to ensure, as much as possible, that the core training code works for the pretrained summarization models below. These tests are excluded from the CI workflow because of how long they take to run and how much data they would need to download.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

{% raw %}
   arch     tokenizer             model_name                       result  error
0  bart     BartTokenizerFast     BartForConditionalGeneration     PASSED
1  led      LEDTokenizerFast      LEDForConditionalGeneration      PASSED
2  mbart    MBartTokenizerFast    MBartForConditionalGeneration    PASSED
3  mt5      T5TokenizerFast       MT5ForConditionalGeneration      PASSED
4  pegasus  PegasusTokenizerFast  PegasusForConditionalGeneration  PASSED
5  t5       T5TokenizerFast       T5ForConditionalGeneration       PASSED
{% endraw %}
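
If you want to run a similar check against your own checkpoints, a results table like the one above can be collected with a simple loop that records a pass or a failure per model. This is only a sketch under stated assumptions: `run_model_tests` and `fake_train` are hypothetical names, `fake_train` stands in for the real work of building dataloaders and fitting, and `"broken/checkpoint"` is a made-up name used to trigger a failure.

```python
def run_model_tests(model_names, train_fn):
    """Call `train_fn` on each model name and record PASSED/FAILED with any error."""
    rows = []
    for name in model_names:
        row = {"model_name": name, "result": "PASSED", "error": ""}
        try:
            train_fn(name)
        except Exception as err:  # record the failure and keep testing the rest
            row["result"] = "FAILED"
            row["error"] = f"{type(err).__name__}: {err}"
        rows.append(row)
    return rows


def fake_train(name):
    """Hypothetical stand-in for building dataloaders and fitting one epoch."""
    if name == "broken/checkpoint":
        raise ValueError("could not load checkpoint")


results = run_model_tests(["sshleifer/distilbart-cnn-6-6", "broken/checkpoint"], fake_train)
for row in results:
    print(row)
```

Catching the exception per model means one failing checkpoint does not stop the remaining models from being tested, and the recorded error text can be pasted directly into an issue.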