--- title: modeling.summarization keywords: fastai sidebar: home_sidebar summary: "This module contains custom models, loss functions, custom splitters, etc... for summarization tasks." description: "This module contains custom models, loss functions, custom splitters, etc... for summarization tasks." nb_path: "nbs/02e_modeling-summarization.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %}

Summarization

The objective of summarization is to generate a concise and accurate representation of a much larger body of text. For example, we may want to summarize an article in a single sentence.

{% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
1000
{% endraw %} {% raw %}
cnndm_df.head(2)
article highlights ds_type
0 (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" train
1 (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . train
{% endraw %} {% raw %}
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.configuration_bart.BartConfig,
 transformers.tokenization_bart.BartTokenizer,
 transformers.modeling_bart.BartForConditionalGeneration)
{% endraw %} {% raw %}
hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)

blocks = ( 
    HF_TextBlock(hf_arch, hf_tokenizer, max_length=256),   # the articles (model inputs)
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, max_length=100, 
                 hf_input_idxs=[0,1])                      # the summaries (targets)
)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
{% endraw %} {% raw %}
dls = dblock.dataloaders(cnndm_df, bs=2)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 51]))
{% endraw %} {% raw %}
dls.show_batch(hf_tokenizer=hf_tokenizer, max_n=2)
text target
0 Washington (CNN) -- Having trouble pronouncing an Italian word? If you sit on the Supreme Court, consult an expert. On Monday, Justice Sonia Sotomayor was announcing the court's opinion in Krupski v. Costa Crociere SpA (09-337), a lesser-known appeal dealing with the scope of the right to file an amended lawsuit to correct a mistake in a party's identity. The newest justice was having trouble pronouncing the name of the cruise ship company at the center of the case. Costa Cruises is a British and American-owned firm based in Genoa, Italy, where it is registered as Costa Crociere SpA. The appeal involved passenger Wanda Krupksi, who tripped over a cable and fractured her leg in 2007 aboard the Costa Magica. At issue was whether Krupski should have sued Costa Cruises or Costa Crociere SpA in federal court. The justice writing the majority ruling typically announces the decision from the bench in a public session, with a brief oral summary that supplements the official written opinion. That's where the fun began. Sotomayor needed help and knew exactly where to turn. "Costa Cruises responded that she should have sued a related company called Newest high court justice gives oral summary of cruise ship case.\nShe turned to fellow justice for pronunciation of Italian company.\nSotomayor was also honored over the weekend in New York.
1 (CNN) -- An earthquake in central Oklahoma prompted a few calls to the police but no apparent damage Saturday. The U.S. Geological Survey said the quake had a preliminary 4.5 magnitude to the quake and said it was centered near Jones, Oklahoma, 14 miles northeast of Oklahoma City. The shaking lasted 3 to 5 seconds at 12:15 p.m. Central Time, said Oklahoma City police Lt. Jason Samuel. He said a few people called the department to ask what happened. No damage had been reported. Samuel said the quake was strong enough to wake him from a nap at his home. He said it seemed stronger and longer-lasting that other earthquakes in the area in recent years. Oklahoma had a stronger earthquake -- 5.6 magnitude -- on November 5, 2011. Although damage was not widespread, it did buckle U.S. Highway 62 in Lincoln County. CNN's David Simpson and Janet DiGiacomo contributed to this report. Only a few calls to Oklahoma City police after 4.5 quake.\nStronger quake in 2011 buckled highway in Oklahoma.
{% endraw %}

Metrics

In this section, we'll add helpful metrics for summarization tasks.

{% raw %}
{% endraw %} {% raw %}

calculate_rouge[source]

calculate_rouge(predicted_txts, reference_txts, rouge_keys=['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

{% endraw %}
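
For example, a quick sanity check could look like this (a minimal sketch; the texts below are invented for illustration):

{% raw %}
preds = ['the cat sat on the mat']
refs = ['a cat was sitting on the mat']

# per the signature above, this should compute rouge1/rougeL (with stemming) for the pair
calculate_rouge(preds, refs, rouge_keys=['rouge1', 'rougeL'])
{% endraw %}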

Training

Here we create a summarization-specific subclass of HF_BaseModelCallback in order to include custom, summarization-specific metrics and also handle the pre-calculated loss during training.

{% raw %}
{% endraw %} {% raw %}

class HF_SummarizationModelCallback[source]

HF_SummarizationModelCallback(rouge_metrics=['rouge1', 'rouge2', 'rougeL'], text_gen_kwargs={}, **kwargs) :: HF_BaseModelCallback

Basic class handling tweaks of the training loop by changing a Learner in various events

{% endraw %}
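
For example, using the signature above you could limit which ROUGE variants get logged and control text generation like so (a sketch; the values are illustrative):

{% raw %}
model_cb = HF_SummarizationModelCallback(
    rouge_metrics=['rouge1', 'rougeL'],                  # which ROUGE scores to track
    text_gen_kwargs={'num_beams': 4, 'max_length': 130}  # forwarded to generate()
)
{% endraw %}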

We add a custom param splitter to give us a bit more depth in applying discriminative learning rates for summarization.

{% raw %}
{% endraw %} {% raw %}

summarization_splitter[source]

summarization_splitter(m, arch)

Custom param splitter for summarization models

{% endraw %}
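
Because the splitter yields multiple parameter groups (three for BART, as len(learn.opt.param_groups) confirms further down), you can pass fastai a slice of learning rates so that earlier, pretrained layers are updated more gently. A sketch (the values are illustrative, and learn is the Learner built in the code below):

{% raw %}
# lowest LR goes to the first parameter group, highest to the last
learn.fit_one_cycle(1, lr_max=slice(1e-6, 1e-4))
{% endraw %}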

Even though we don't really need a loss function, we have to provide a custom loss class/function for fastai to function properly (e.g., one with decodes and activation methods). Why? Because these methods will get called in methods like show_results to get the actual predictions.
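
As a rough sketch of that pattern (this is not blurr's actual implementation, just the shape fastai expects):

{% raw %}
import torch.nn.functional as F

class ExamplePreCalculatedLoss:
    # huggingface models return the loss themselves, so the "loss function"
    # can simply pass that pre-calculated value through
    def __call__(self, output, *targs, **kwargs): return output
    # fastai calls these (e.g., in show_results) to turn logits into predictions
    def activation(self, x): return F.softmax(x, dim=-1)
    def decodes(self, x): return x.argmax(dim=-1)
{% endraw %}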

{% raw %}
{% endraw %} {% raw %}

class HF_MaskedLMLoss[source]

HF_MaskedLMLoss()

{% endraw %} {% raw %}
# start from the model's recommended summarization settings and override a couple of values
text_gen_kwargs = { **hf_config.task_specific_params['summarization'], **{'max_length': 130, 'min_length': 30} }
text_gen_kwargs
{'early_stopping': True,
 'length_penalty': 2.0,
 'max_length': 130,
 'min_length': 30,
 'no_repeat_ngram_size': 3,
 'num_beams': 4}
{% endraw %} {% raw %}
model = HF_BaseModelWrapper(hf_model)
model_cb = HF_SummarizationModelCallback(text_gen_kwargs=text_gen_kwargs)

learn = Learner(dls, 
                model,
                opt_func=ranger,
                loss_func=HF_MaskedLMLoss(),
                cbs=[model_cb],
                splitter=partial(summarization_splitter, arch=hf_arch))#.to_fp16()

learn.create_opt() 
learn.freeze()
{% endraw %} {% raw %}
 
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds[0], preds[1].shape
(3,
 tensor(4.2707, device='cuda:1', grad_fn=<NllLossBackward>),
 torch.Size([2, 53, 50264]))
{% endraw %} {% raw %}
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 4, torch.Size([2, 256]), 2, torch.Size([2, 54]))
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
3
{% endraw %} {% raw %}
learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=0.0001737800776027143, lr_steep=6.309573450380412e-07)
{% endraw %} {% raw %}
learn.fit_one_cycle(3, lr_max=3e-5)
| epoch | train_loss | valid_loss | rouge1 | rouge2 | rougeL | time |
|-------|------------|------------|--------|--------|--------|------|
| 0 | 1.684707 | 1.903918 | 0.366860 | 0.151517 | 0.237965 | 04:15 |
| 1 | 1.418471 | 1.897334 | 0.378602 | 0.162281 | 0.251070 | 04:03 |
| 2 | 1.189154 | 1.904099 | 0.379188 | 0.161845 | 0.255042 | 03:55 |
{% endraw %}

Showing results

Below we'll add additional functionality to take advantage of huggingface's PreTrainedModel.generate method, which can be used to easily implement beam search, top-k/nucleus sampling, etc., so that we get more human-sounding results.

{% raw %}
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan 
contributed to this report.
"""
{% endraw %} {% raw %}
res = learn.blurr_predict(test_article)
print(res[0][:20])
 Gun GunAbout 10 men
{% endraw %}

That doesn't look much like a human-generated summary. Let's use huggingface's PreTrainedModel.generate method to create something more human-like.

{% raw %}
test_input_ids = dls.train_ds[0][0]['input_ids'].unsqueeze(0).to(learn.model.hf_model.device)
gen_text = learn.model.hf_model.generate(test_input_ids, num_beams=4, max_length=130, min_length=30)

print('=== Target ===')
print(f'{hf_tokenizer.decode(dls.train_ds[0][1]["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False)}\n')

print('=== Prediction ===')
print(hf_tokenizer.decode(gen_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
=== Target ===
 Newest high court justice gives oral summary of cruise ship case .
She turned to fellow justice for pronunciation of Italian company .
Sotomayor was also honored over the weekend in New York .

=== Prediction ===
 Sotomayor was announcing the court's opinion in Krupski v. Costa Crociere SpA .
The appeal involved a woman who tripped over a cable and fractured her leg aboard a Costa Cruises ship .
At issue was whether she should have sued the cruise line or a related company in federal court .
Costa Cruises is a British and American-owned firm based in Genoa, Italy .
{% endraw %}
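
generate supports other decoding strategies besides beam search; for example, a top-k/nucleus-sampling variant of the call above might look like this (the sampling parameters are illustrative):

{% raw %}
gen_text = learn.model.hf_model.generate(test_input_ids, do_sample=True, 
                                         top_k=50, top_p=0.95, max_length=130)
{% endraw %}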

We'll add a blurr_summarize method to Learner that uses huggingface's PreTrainedModel.generate to create our predictions. For the full list of arguments you can pass in, see the huggingface documentation for generate. You can also check out their "How To Generate" notebook for more information about how it all works.

{% raw %}
{% endraw %} {% raw %}

Learner.blurr_summarize[source]

Learner.blurr_summarize(inp, **kwargs)

Uses the built-in generate method to generate the text (see here for a list of arguments you can pass in)

{% endraw %} {% raw %}
outputs = learn.blurr_summarize(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
One group tried to break into the casino's vault on the lower level but could not get in .
A woman driving by unknowingly blocked the armed robbers' vehicles and was beaten to death .

=== Prediction 2 ===
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
About 600 people were in the casino at the time of the robbery .
There were no serious injuries, although one guest was kicked in the head by one of the robbers .
The robbers spoke French and drove vehicles with French license plates, a police officer says .

=== Prediction 3 ===
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
About 600 people were in the casino at the time of the robbery .
There were no serious injuries, although one guest was kicked in the head by one of the robbers .
The robbers spoke French and drove vehicles with French license plates, police said .

{% endraw %}

Much nicer!!! Now, we can update our @typedispatched show_results to use this new method.

{% raw %}
{% endraw %} {% raw %}
learn.show_results(learner=learn)
text target prediction
0 (CNN) -- It's one win for Darrell Wallace Jr., but what will it mean for other African-American race car drivers -- present and future? The answer to that question might not come for years. Nonetheless, NASCAR wasted no time Saturday in hailing Wallace's on-track success at Martinsville Speedway in southern Virginia. "We congratulate Darrell Wallace Jr. on his first national series victory, one that will be remembered as a remarkable moment in our sport's history," said NASCAR chairman and CEO Brian France. Wallace took the Kroger 200 on the racing circuit's Camping World Truck Series, which is on NASCAR's third tier. Still, it is notable given that no African-American has won any NASCAR national series race since December 1, 1963, when Wendell Scott became the first ever to win a race at NASCAR's top level, in a victory at Speedway Park in Jacksonville, Florida. Scott, a Virginia native who served in the Army during World War II, raced in more than 500 races during his career -- finishing in the top five 20 times, though that would be his only victory. Plus, the 20-year-old Wallace isn't just any driver. He's a highly touted graduate of the NASCAR Drive for Diversity, having been featured NEW: "We Came. We Saw. We Conquered," Wallace tweets.\nDarrell Wallace Jr. wins a third-tier NASCAR race in Martinsville.\nIt's the first NASCAR national series win for an African-African since 1963.\nNASCAR's CEO says the win "will be remembered... in our sport's history" NASCAR hails Darrell Wallace Jr.'s win as a "remarkable moment" in the sport's history .\nNo African-American has won a NASCAR national series race since Wendell Scott in 1963 .\nWallace is a graduate of the NASCAR Drive for Diversity program .
1 (CNN) -- Israeli President Shimon Peres said he had an amicable phone conversation with Turkish Prime Minister Recep Tayyip Erdogan, a day after Erdogan stormed offstage during an angry exchange with Peres at the World Economic Forum in Davos, Switzerland. Turkish PM Recep Tayyip Erdogan leaves the stage Thursday, as Israeli President Shimon Peres sits, left. Peres said he and Erdogan did not take the spat personally. "I called him up and said, yes, it's nothing against you, nothing against Turkey. We consider you as a friend," Peres said. He said Erdogan reciprocated. Although there was no mention of an apology, Peres said there was a polite exchange between the two leaders. "I didn't take it personally. I didn't go for a personal fight. I answered unfounded accusations. It was my duty. And they didn't change my mind," he said. Watch Shimon Peres on the Gaza conflict ยป. Turkey, a predominantly Muslim nation, has long been the Jewish state's closest military and economic partner in the region, and Turkey recently mediated indirect peace talks between Israel and Syria. But many Turks have been incensed with Israel over its three-week military operation that ended there earlier this month. And in Turkish Prime Minister angered during debate on Gaza at World Economic Forum.\nRecep Tayyip Erdogan called Israel's Gaza campaign "barbaric," stormed off stage.\nIsraeli president Shimon Peres said he and Erdogan did not take spat personally.\nErdogan returned home to a hero's welcome in Istanbul. Turkish PM Erdogan stormed off stage during angry exchange with Israeli President Peres at Davos summit .\nPeres says he and Erdogan did not take spat personally, had amicable phone conversation .\nTurkey has long been Israel's closest military and economic partner in region .\nBut many Turks have been incensed with Israel over its three-week military operation in Gaza .\nErdogan recently mediated indirect peace talks between Israel and Syria .
{% endraw %}

Inference

{% raw %}
learn.export(fname='summarize_export.pkl')
{% endraw %} {% raw %}
inf_learn = load_learner(fname='summarize_export.pkl')
inf_learn.blurr_summarize(test_article)
[" Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .\nOne group tried to break into the casino's vault on the lower level but could not get in .\nA woman driving by unknowingly blocked the armed robbers' vehicles and was beaten to death ."]
{% endraw %}

Tests

The tests below ensure that the core training code above works for all pretrained summarization models available in huggingface. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of your pretrained summarization models fail, please submit a github issue (or a PR if you'd like to fix it yourself).

{% raw %}
try: del learn; torch.cuda.empty_cache()
except: pass
{% endraw %} {% raw %}
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')
[transformers.modeling_bart.BartForConditionalGeneration,
 transformers.modeling_mbart.MBartForConditionalGeneration,
 transformers.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.modeling_t5.T5ForConditionalGeneration]
{% endraw %} {% raw %}
pretrained_model_names = [
    ('facebook/bart-large-cnn',BartForConditionalGeneration),
    ('t5-small', T5ForConditionalGeneration),
    #('google/pegasus-cnn_dailymail', PegasusForConditionalGeneration), ... doesn't fit on my 1080Ti :(
]
{% endraw %} {% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
{% endraw %} {% raw %}
bsz = 2
inp_seq_sz = 128

test_results = []
for model_name, model_cls in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, 
                                                                                   model_cls=model_cls)
    
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n')
    
    # 1. build your DataBlock
    hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)

    blocks = ( 
        HF_TextBlock(hf_arch, hf_tokenizer, padding='max_length', max_length=inp_seq_sz), 
        HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, padding='max_length', max_length=30, 
                     hf_input_idxs=[0,1])
    )
    
    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp

    dblock = DataBlock(blocks=blocks, 
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]), 
                       get_y=ColReader('highlights'), 
                       splitter=RandomSplitter())
    
    dls = dblock.dataloaders(cnndm_df, bs=bsz)

    # 2. build your Learner
    text_gen_kwargs = {}
    if (hf_arch in ['bart', 't5']):
        text_gen_kwargs = { 
            **hf_config.task_specific_params['summarization'], 
            **{'max_length': 30, 'min_length': 10} 
        }
    
    model = HF_BaseModelWrapper(hf_model)
    model_cb = HF_SummarizationModelCallback(text_gen_kwargs=text_gen_kwargs)

    learn = Learner(dls, 
                    model,
                    opt_func=ranger,
                    loss_func=HF_MaskedLMLoss(),
                    cbs=[model_cb],
                    splitter=partial(summarization_splitter, arch=hf_arch))#.to_fp16()

    learn.create_opt() 
    learn.freeze()
    
    # 3. Run your tests
    b = dls.one_batch()

    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

        print('*** TESTING One pass through the model ***')
        preds = learn.model(b[0])
        test_eq(preds[1].shape[0], bsz)
        test_eq(preds[1].shape[2], hf_config.vocab_size)

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max=1e-3)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))
    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()
{% endraw %} {% raw %}
|   | arch | tokenizer | model_name | result | error |
|---|------|-----------|------------|--------|-------|
| 0 | bart | BartTokenizer | BartForConditionalGeneration | PASSED | |
| 1 | t5 | T5Tokenizer | T5ForConditionalGeneration | PASSED | |
{% endraw %}

Cleanup