--- title: modeling.summarization keywords: fastai sidebar: home_sidebar summary: "This module contains custom models, loss functions, custom splitters, etc... for summarization tasks." description: "This module contains custom models, loss functions, custom splitters, etc... for summarization tasks." nb_path: "nbs/02e_modeling-summarization.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %}

Summarization

The objective of summarization is to generate a concise and accurate representation of a much larger body of text. For example, we may want to summarize an article in a single sentence.

{% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
1000
{% endraw %} {% raw %}
cnndm_df.head(2)
article highlights ds_type
0 (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" train
1 (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . train
{% endraw %} {% raw %}
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.configuration_bart.BartConfig,
 transformers.tokenization_bart.BartTokenizer,
 transformers.modeling_bart.BartForConditionalGeneration)
{% endraw %} {% raw %}
hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer, max_length=[256, 130])

blocks = (HF_TextBlock(hf_batch_tfm=hf_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
{% endraw %} {% raw %}
dls = dblock.dataloaders(cnndm_df, bs=2)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 84]))
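{% endraw %}

The inputs are a dictionary of tensors assembled by the batch transform (the two max_length values above cap the input and target sequence lengths, respectively). You can inspect exactly what the transform produces; the snippet below is a minimal sketch (the exact key names depend on the tokenizer/model):

{% raw %}
# b[0] is a dict of model inputs (e.g., input_ids, attention_mask, ...);
# b[1] holds the target (summary) token ids
list(b[0].keys())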
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 (CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 varieties of birds and 26,000 classifications of plants. Pronatura, a non-profit organization that works to promote conservation and sustainable development in Mexico, has selected six species which it says symbolize the problems faced by the Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is under threat in spite of being the country's national symbol.
1 Michael Zehaf-Bibeau, the 32-year-old gunman who attacked the Canadian Parliament and killed a soldier last week, had a familiar profile. It is that of a young man alienated from mainstream society, with few friends, without a steady job, drifting from one place to another -- often with a history of petty crime and drug abuse. Then comes the conversion to or rediscovery of Islam, and the adoption of a jihadist mindset, fed by media and online coverage of the West's involvement in wars in Iraq and Afghanistan, and by the well-oiled propaganda machine of groups like ISIS. Whether these young men acting alone (and they are almost always men) should be better described as "lone-wolf" terrorists or deranged criminals is debatable. "There is no single, universally accepted definition of terrorism," says the FBI. In many cases, their conversion to militant Islam is about seeking identity and purpose, or even a sense of adventure. Few of these men have a deep understanding of Salafism, the deeply conservative brand of Islam that's the philosophical underpinning of groups like al Qaeda, and jihad; their writings are often incoherent. Frequently they see radical Islam as a form of redemption from past misdeeds Like many "lone wolf" terrorists, Ottawa gunman was alienated drifter who converted to Islam.\nConversion to militant Islam is often about seeking identity, purpose or adventure.\nSome countries have tried "de-radicalization" programs to help prevent violence.\nBut with resources stretched thin, the focus is often on increased law enforcement.
{% endraw %}

Metrics

In this section, we'll add helpful metrics for summarization tasks.

{% raw %}
{% endraw %} {% raw %}

calculate_rouge[source]

calculate_rouge(predicted_txts, reference_txts, rouge_keys=['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

{% endraw %}
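
As a quick sanity check, calculate_rouge can be called directly on lists of predicted and reference texts. A minimal sketch (the exact structure of the returned scores depends on the underlying rouge scorer):

{% raw %}
preds = ['Armenian president declares a state of emergency after clashes in Yerevan .']
refs = ['Armenian president declared a state of emergency after a day of clashes between police and protesters .']

# returns rouge1/rouge2/rougeL scores comparing each prediction to its reference
calculate_rouge(preds, refs, rouge_keys=['rouge1', 'rouge2', 'rougeL'])
{% endraw %}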

Training

Here we create a summarization-specific subclass of HF_BaseModelCallback in order to include custom summarization metrics, and also to handle the pre-calculated loss returned by the huggingface model during training.

{% raw %}
{% endraw %} {% raw %}

class HF_SummarizationModelCallback[source]

HF_SummarizationModelCallback(rouge_metrics=['rouge1', 'rouge2', 'rougeL'], text_gen_kwargs={}, **kwargs) :: HF_BaseModelCallback

Basic class handling tweaks of the training loop by changing a Learner in various events

{% endraw %}
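
Both the rouge metrics to compute and the text generation arguments are configurable at construction time. For example (a sketch based on the signature above; num_beams is one of the generate arguments we pass along further down):

{% raw %}
# track only rouge1/rougeL, and generate with 4 beams during validation
model_cb = HF_SummarizationModelCallback(rouge_metrics=['rouge1', 'rougeL'],
                                         text_gen_kwargs={'num_beams': 4})
{% endraw %}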

We add a custom param splitter to give us a bit more depth in applying discriminative learning rates for summarization.

{% raw %}
{% endraw %} {% raw %}

summarization_splitter[source]

summarization_splitter(m, arch)

Custom param splitter for summarization models

{% endraw %}
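
With the splitter in place, discriminative learning rates can be applied in the usual fastai way once the Learner below is built. A sketch (the splitter yields three parameter groups here, as we verify further down):

{% raw %}
# a sliced lr_max spreads learning rates across the parameter groups,
# giving the earlier (pretrained) groups smaller learning rates
learn.fit_one_cycle(1, lr_max=slice(1e-5, 1e-3))
{% endraw %}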

Even though we don't really need a loss function, we have to provide a custom loss class/function for fastai to function properly (e.g., one with decodes and activation methods). Why? Because these methods get called in methods like show_results to grab the actual predictions.
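
Conceptually, such a loss is just a pass-through (the huggingface model already computes the LM loss when targets are provided) plus the two methods fastai expects. A minimal sketch of the idea, not the actual HF_MaskedLMLoss implementation:

{% raw %}
import torch.nn.functional as F

class PassThroughLoss():
    "Return the loss huggingface pre-computed; provide the decodes/activation fastai expects"
    def __call__(self, inp, targ, **kwargs): return inp    # inp *is* the pre-computed loss
    def decodes(self, x): return x.argmax(dim=-1)          # logits -> predicted token ids
    def activation(self, x): return F.softmax(x, dim=-1)   # logits -> probabilities
{% endraw %}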

{% raw %}
{% endraw %} {% raw %}

class HF_MaskedLMLoss[source]

HF_MaskedLMLoss()

{% endraw %} {% raw %}
text_gen_kwargs = { **hf_config.task_specific_params['summarization'], **{'max_length': 130, 'min_length': 30} }
text_gen_kwargs
{'early_stopping': True,
 'length_penalty': 2.0,
 'max_length': 130,
 'min_length': 30,
 'no_repeat_ngram_size': 3,
 'num_beams': 4}
{% endraw %} {% raw %}
model = HF_BaseModelWrapper(hf_model)
model_cb = HF_SummarizationModelCallback(text_gen_kwargs=text_gen_kwargs)

learn = Learner(dls, 
                model,
                opt_func=ranger,
                loss_func=HF_MaskedLMLoss(),
                cbs=[model_cb],
                splitter=partial(summarization_splitter, arch=hf_arch))#.to_fp16()

learn.create_opt() 
learn.freeze()
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds[0], preds[1].shape
(3,
 tensor(4.0109, device='cuda:1', grad_fn=<NllLossBackward>),
 torch.Size([2, 73, 50264]))
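{% endraw %}

Because the targets are part of the inputs passed to the huggingface model, preds[0] is the pre-computed loss (which our loss function simply passes through) and preds[1] holds the logits, whose final dimension is the model's vocab size (hf_config.vocab_size).

{% raw %}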
{% endraw %} {% raw %}
len(b), len(b[0]), b[0]['input_ids'].shape, len(b[1]), b[1].shape
(2, 4, torch.Size([2, 256]), 2, torch.Size([2, 74]))
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
3
{% endraw %} {% raw %}
learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=0.00020892962347716094, lr_steep=0.00015848931798245758)
{% endraw %} {% raw %}
learn.fit_one_cycle(1, lr_max=4e-5)
epoch train_loss valid_loss rouge1 rouge2 rougeL time
0 1.690670 1.899561 0.357975 0.144255 0.234915 04:04
{% endraw %}

Showing results

Below we'll add additional functionality to take advantage of huggingface's PreTrainedModel.generate method, which can be used to easily implement beam search, top-k/nucleus sampling, etc... so that we get more human-sounding results.

{% raw %}
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French license plates. CNN's Andreena Narayan 
contributed to this report.
"""
{% endraw %} {% raw %}
res = learn.blurr_predict(test_article)
print(hf_tokenizer.decode(res[0][:20]))
 10 10About 10 men armed with pistols and machine machine guns raided a casino in Switzerland and made off
{% endraw %}

That doesn't look much like a human-generated summary. Let's use huggingface's PreTrainedModel.generate method to create something more human-like.

{% raw %}
b = dls.valid.one_batch()

test_input_ids = b[0]['input_ids'][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = b[1][0].unsqueeze(0).to(learn.model.hf_model.device)

gen_text = learn.model.hf_model.generate(test_input_ids, num_beams=4, max_length=130, min_length=30)

print('=== Target ===')
print(f'{hf_tokenizer.decode(test_trg_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)}\n')

print('=== Prediction ===')
print(hf_tokenizer.decode(gen_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
=== Target ===
 NEW: Planes depart Australia to resume their search for airplane debris.
NEW: Official: Passengers' relatives are moved to a different Kuala Lumpur hotel.
Objects seen on satellite spark intensive search in southern Indian Ocean.
U.S. officials: Files were deleted from flight simulator's hard drive after February 3.

=== Prediction ===
 Malaysia Airlines Flight 370 has been missing for two weeks, with no sign of the Boeing 777.
Two large objects detected by satellite Sunday floating on waters over 1,400 miles off Australia's west coast.
Australian military planes resume their search for the objects amid skepticism they will turn up anything.
Malaysia Airlines flight 370 was carrying 227 passengers and 12 crew members, bound for Beijing from Kuala Lumpur.
Search has covered 2.97 million square miles, nearly the size of the continental United States.
{% endraw %}
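
Beam search is just one decoding strategy; generate also supports the top-k/nucleus sampling mentioned above. A sketch using the same inputs (do_sample, top_k, and top_p are standard PreTrainedModel.generate arguments):

{% raw %}
# sample from the top-k tokens / top-p (nucleus) probability mass instead of beam search
gen_text = learn.model.hf_model.generate(test_input_ids,
                                         do_sample=True, top_k=50, top_p=0.95,
                                         max_length=130, min_length=30)

print(hf_tokenizer.decode(gen_text[0], skip_special_tokens=True))
{% endraw %}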

We'll add a blurr_summarize method to Learner that uses huggingface's PreTrainedModel.generate to create our predictions. For the full list of arguments you can pass in see here. You can also check out their "How To Generate" notebook for more information about how it all works.

{% raw %}
{% endraw %} {% raw %}

Learner.blurr_summarize[source]

Learner.blurr_summarize(inp, **kwargs)

Uses the built-in generate method to generate the text (see here for a list of arguments you can pass in)

{% endraw %} {% raw %}
outputs = learn.blurr_summarize(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
A woman driving by unknowingly blocked the armed robbers' vehicles and was beaten to death .
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the robbers .

=== Prediction 2 ===
 About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
A woman driving by unknowingly blocked the armed robbers' vehicles and was beaten to death .
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the robbers when he moved .

=== Prediction 3 ===
 About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
A woman driving by unknowingly blocked the armed robbers' vehicles and was beaten to death .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved .

{% endraw %}

Much nicer!!! Now, we can update our @typedispatched show_results to use this new method.

{% raw %}
{% endraw %} {% raw %}
learn.show_results(learner=learn)
text target prediction
0 (CNN) -- Two weeks. Two gut-wrenching, frustrating, mysterious weeks. That's how long it's been since 227 passengers and 12 crew members boarded Malaysia Airlines Flight 370, destined for Beijing. A routine trip, it seemed, to catch up relatives in time for the weekend, start on a work assignment or just get away. Where they got to, still unknown. An exhaustive search -- covering a mind-boggling 2.97 million square miles, which is nearly the size of the continental United States -- has yielded some clues, but no proof of where the Boeing 777 is or definitively what happened to it. The latest, most notable lead revolved around two large objects detected by satellite Sunday floating on waters over 1,400 miles off of Australia's west coast. The first of several Australian military planes, as well as two long-range commercial jets, resumed their search Saturday morning to find any trace of the objects, amid some skepticism that they or ships in the area ever will and, if they do, that whatever they find will be related to the missing aircraft. Australian Prime Minister Tony Abbott on Friday defended the decision to announce the find, saying Australia owes it to families of those missing "to give them information as soon as it's NEW: Planes depart Australia to resume their search for airplane debris.\nNEW: Official: Passengers' relatives are moved to a different Kuala Lumpur hotel.\nObjects seen on satellite spark intensive search in southern Indian Ocean.\nU.S. officials: Files were deleted from flight simulator's hard drive after February 3. Malaysia Airlines Flight 370 has been missing for two weeks, with no sign of the Boeing 777 .\nTwo large objects detected by satellite Sunday floating on waters over 1,400 miles off Australia's west coast .\nAustralian military planes resume their search for the objects amid skepticism they will turn up anything .\nMalaysia Airlines flight 370 was carrying 227 passengers and 12 crew members, bound for Beijing from Kuala Lumpur .\nSearch has covered 2.97 million square miles, nearly the size of the continental United States .
1 As a growing number of airplanes scoured the southern Indian Ocean in the search for Malaysia Airlines Flight 370, authorities released new details that paint a different picture of what may have happened in the plane's cockpit. Military radar tracking shows that the aircraft changed altitude after making a sharp turn over the South China Sea as it headed toward the Strait of Malacca, a source close to the investigation into the missing flight told CNN. The plane flew as low as 12,000 feet at some point before it disappeared from radar, according to the source. The sharp turn seemed to be intentional, the source said, because executing it would have taken the Boeing 777 two minutes -- a time period during which the pilot or co-pilot could have sent an emergency signal if there had been a fire or other emergency onboard. Authorities say the plane didn't send any emergency signals, though some analysts say it's still unclear whether the pilots tried but weren't able to communicate because of a catastrophic failure. The official, who is not authorized to speak to the media, told CNN that the area the plane flew in after the turn is a heavily trafficked air corridor and that flying at 12,000 feet would have kept the jet well out of the way of that traffic. Earlier Sunday, Malaysian authorities U.S. Navy sending listening device to help find voice and data recorders if wreckage is found.\nSource: Plane changed altitude, flying as low as 12,000 feet after making short turn.\nSchiavo: Altitude information "explains so many pieces that didn't fit together"\n10 aircraft set to comb southern region for missing plane as search resumes Monday. Malaysia Airlines Flight 370 changed altitude after making a sharp turn over the South China Sea, a source says .\nThe plane flew as low as 12,000 feet at some point before it disappeared from radar, the source tells CNN .\nAuthorities say the plane didn't send any emergency signals, though some analysts say it's still unclear .\nIt's unclear whether the pilots tried to communicate because of a catastrophic failure, analysts say .
{% endraw %}

Inference

{% raw %}
learn.export(fname='summarize_export.pkl')
{% endraw %} {% raw %}
inf_learn = load_learner(fname='summarize_export.pkl')
inf_learn.blurr_summarize(test_article)
[" About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off with several hundred thousand Swiss francs .\nThe men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .\nA woman driving by unknowingly blocked the armed robbers' vehicles and was beaten to death .\nThere were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the robbers ."]
{% endraw %}
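
Keep in mind that load_learner un-pickles your Learner, so everything it references (blurr's wrappers, callbacks, loss, etc.) must be importable in the process doing the loading. A sketch for a fresh script (the blurr module path here is an assumption inferred from this notebook's name; adjust it to your install):

{% raw %}
from fastai.learner import load_learner
from blurr.modeling.summarization import *  # assumed module path; makes the HF_* classes importable

inf_learn = load_learner(fname='summarize_export.pkl')
{% endraw %}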

Tests

The tests below ensure the core training code above works for all pretrained summarization models available in huggingface. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of them fail, please submit a github issue (or a PR if you'd like to fix it yourself).

{% raw %}
try: del learn; torch.cuda.empty_cache()
except: pass
{% endraw %} {% raw %}
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')
[transformers.modeling_bart.BartForConditionalGeneration,
 transformers.modeling_mbart.MBartForConditionalGeneration,
 transformers.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.modeling_t5.T5ForConditionalGeneration]
{% endraw %} {% raw %}
pretrained_model_names = [
    ('facebook/bart-large-cnn',BartForConditionalGeneration),
    ('t5-small', T5ForConditionalGeneration),
    #('google/pegasus-cnn_dailymail', PegasusForConditionalGeneration), ... don't fit on my 1080TI :(
]
{% endraw %} {% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
{% endraw %} {% raw %}
#hide_output
bsz = 2
inp_seq_sz = 128; trg_seq_sz = 130

test_results = []
for model_name, model_cls in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, 
                                                                                   model_cls=model_cls)
    
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\nmodel:\t\t{type(hf_model).__name__}\n')
    
    # 1. build your DataBlock
    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp
    
    hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer, max_length=[inp_seq_sz, trg_seq_sz])
    blocks = (HF_TextBlock(hf_batch_tfm=hf_batch_tfm), noop)
    dblock = DataBlock(blocks=blocks, 
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]), 
                       get_y=ColReader('highlights'), 
                       splitter=RandomSplitter())

    dls = dblock.dataloaders(cnndm_df, bs=bsz)

    # 2. build your Learner
    text_gen_kwargs = {}
    if (hf_arch in ['bart', 't5']):
        text_gen_kwargs = { 
            **hf_config.task_specific_params['summarization'], 
            **{'max_length': 30, 'min_length': 10} 
        }
    
    model = HF_BaseModelWrapper(hf_model)
    model_cb = HF_SummarizationModelCallback(text_gen_kwargs=text_gen_kwargs)

    learn = Learner(dls, 
                    model,
                    opt_func=ranger,
                    loss_func=HF_MaskedLMLoss(),
                    cbs=[model_cb],
                    splitter=partial(summarization_splitter, arch=hf_arch))#.to_fp16()

    learn.create_opt() 
    learn.freeze()
    
    # 3. Run your tests
    b = dls.one_batch()

    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, inp_seq_sz]))
        test_eq(len(b[1]), bsz)

        print('*** TESTING One pass through the model ***')
        preds = learn.model(b[0])
        test_eq(preds[1].shape[0], bsz)
        test_eq(preds[1].shape[2], hf_config.vocab_size)

        print('*** TESTING Training/Results ***')
        learn.fit_one_cycle(1, lr_max=1e-3)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'PASSED', ''))
        learn.show_results(learner=learn, max_n=2)
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, type(hf_model).__name__, 'FAILED', err))
    finally:
        # cleanup
        del learn; torch.cuda.empty_cache()
{% endraw %} {% raw %}
arch tokenizer model_name result error
0 bart BartTokenizer BartForConditionalGeneration PASSED
1 t5 T5Tokenizer T5ForConditionalGeneration PASSED
{% endraw %}

Cleanup