---
title: data.summarization
keywords: fastai
sidebar: home_sidebar
summary: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5."
description: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5."
nb_path: "nbs/01e_data-summarization.ipynb"
---
{% raw %}
# run on the second GPU; adjust (or remove) for your machine
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %}

Summarization tokenization, batch transform, and DataBlock methods

Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capturing the meaning of a larger document in 1-3 sentences).

{% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
1000
{% endraw %} {% raw %}
cnndm_df.head(2)
| | article | highlights | ds_type |
| --- | --- | --- | --- |
| 0 | (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... | John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" | train |
| 1 | (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... | NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . | train |
{% endraw %} {% raw %}
pretrained_model_name = "facebook/bart-large-cnn"

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
('bart',
 transformers.tokenization_bart.BartTokenizer,
 transformers.configuration_bart.BartConfig,
 transformers.modeling_bart.BartForConditionalGeneration)
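{% endraw %}

Before building any pipelines, it can be worth sanity-checking the returned objects with a quick round-trip. This snippet is illustrative and not part of the original notebook; it relies only on the standard huggingface encode/decode API.

{% raw %}
# quick smoke test of the tokenizer returned above (illustrative)
ids = hf_tokenizer.encode('A quick check of the tokenizer.', return_tensors='pt')
print(ids.shape, '->', hf_tokenizer.decode(ids[0]))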
{% endraw %} {% raw %}

class HF_SummarizationInput[source]

HF_SummarizationInput(iterable=()) :: list

Built-in mutable sequence.

If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.

{% endraw %}
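
HF_SummarizationInput is nothing more than a list subclass; blurr presumably uses it as a distinct type so that fastai's typed-dispatch methods (e.g., show_batch/show_results) can route summarization batches to the right renderers. A minimal illustration (the token ids below are made up):

{% raw %}
inp = HF_SummarizationInput([[0, 713, 16, 2]])
isinstance(inp, list)  # True -- it behaves exactly like a list of input_ids
{% endraw %}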

We create a subclass of HF_BatchTransform for summarization tasks to add decoder_input_ids and labels to our inputs during training, which will in turn allow the huggingface model to calculate the loss for us. See here for more information on how these additional inputs are used in summarization and conversational training tasks.

Note also that labels is simply target_ids shifted to the right by one, since the task is to predict the next token based on the current (and all previous) decoder_input_ids.

And lastly, we also update our targets to just be the input_ids of our target sequence so that fastai's Learner.show_results works (again, almost all the fastai bits require returning a single tensor to work).

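The one-token offset is easiest to see on a toy tensor. The sketch below is illustrative only (made-up token ids, not blurr's actual implementation), but it shows the relationship described above: at each step the decoder consumes a token of decoder_input_ids and is trained to emit the corresponding token of labels, i.e., the next token of the target sequence.

{% raw %}
import torch

# made-up target token ids for a single summary
target_ids = torch.tensor([[50, 51, 52, 53, 54]])

# the decoder inputs and the labels are the same sequence offset by one token,
# so predicting labels[i] from decoder_input_ids[:i+1] is "predict the next token"
decoder_input_ids = target_ids[:, :-1]   # [[50, 51, 52, 53]]
labels            = target_ids[:,  1:]   # [[51, 52, 53, 54]]
{% endraw %}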
{% raw %}

class HF_SummarizationBatchTransform[source]

HF_SummarizationBatchTransform(hf_arch, hf_tokenizer, **kwargs) :: HF_BatchTransform

Handles everything you need to assemble a mini-batch of inputs and targets, as well as decode HF_TokenizerTransform inputs

{% endraw %}

We had to override the decodes method above because, while our inputs and targets are technically the same kind of thing, we update the latter to consist of only the target input_ids so that methods like Learner.show_results work. Nevertheless, because fastai remembers what they are, HF_TokenizerTransform.decodes will be called for both, and it works on a list of input_ids.

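For reference, an override along these lines is all that is required. This is a rough sketch of the idea rather than blurr's exact source; in particular, the dict check is an assumption about how inputs vs. targets arrive at decode time.

{% raw %}
def decodes(self, encoded_samples):
    # inputs arrive as a dict of tensors, targets as a plain tensor of
    # input_ids; only the dict case needs unwrapping before display
    if isinstance(encoded_samples, dict):
        return HF_SummarizationInput([encoded_samples['input_ids']])
    return encoded_samples
{% endraw %}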
{% raw %}
hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)

blocks = ( 
    HF_TextBlock(hf_arch, hf_tokenizer), 
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, max_length=150, hf_input_idxs=[0,1])
)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
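{% endraw %}

If anything misbehaves, fastai's standard DataBlock.summary can be used to step through the pipeline one transform at a time. This is a general fastai debugging aid rather than anything blurr-specific.

{% raw %}
# walk a single sample through every transform, printing each step
dblock.summary(cnndm_df)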
{% endraw %} {% raw %}
dls = dblock.dataloaders(cnndm_df, bs=4)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([4, 512]), torch.Size([4, 71]))
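{% endraw %}

It is also worth peeking at what actually landed in the batch. The exact keys depend on the architecture, so treat the commented output below as typical for BART rather than guaranteed.

{% raw %}
# the first element is the dict of model inputs, the second the target ids
print(b[0].keys())     # e.g., input_ids, attention_mask, decoder_input_ids, labels
print(b[1][0][:10])    # first 10 target token ids of the first sample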
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2)
| | text | target |
| --- | --- | --- |
| 0 | (CNN) -- Standing outside a courthouse Sunday that the Libyan opposition is using for a base of operations in the town of Misrata, a witness described a sense of jubilation against a backdrop of blood stains and rocket fragments. "I'm standing in the middle of a... battlefield," the witness told CNN by phone from Misrata after a fierce fight between rebels and Libyan leader Moammar Gadhafi's forces. People were holding their hands up, singing, chanting and cheering, he said. "Everyone is hugging everyone." CNN is not identifying witnesses and sources for safety reasons. Videos posted on YouTube and thought to be out of Misrata showed damage to buildings and several shots of people celebrating around the opposition flag -- once being raised on a pole, and another time being waved by a man atop a charred vehicle that had a dead body inside. A doctor at Central Misrata Hospital said 42 people were killed in the fighting -- 17 from the opposition and 25 from the pro-Gadhafi forces. Among the dead was a 3-year-old child, killed from direct fire, the doctor said. At least 85 people were wounded, the doctor said. The fighting continued on the city's outskirts Sunday evening. The witness described the opposition's victory in central Misrata even as people some 200 kilometers (125 miles) west, at a pro-Gadhafi demonstration in Tripoli, insisted the government had taken back the coastal central Libyan city. After reports of the opposition successfully holding onto Misrata, east of Tripoli, Libyan state TV showed a graphic stating that "strict orders have been issued to the armed forces not to enter cities taken by terrorist gangs." On Sunday morning, pro-Gadhafi militias converged on Misrata from three different points, trying to retake control of the city, the witness said. He saw four tanks, though other witnesses told him there were a total of six. Using heavy artillery, the ground forces and tanks headed for the courthouse operations base. Tanks fired rockets at the building, and black smoke could be seen rising from it, he said. The opposition couldn't match the government's weaponry, but rebels took to the streets using what weapons they had, such as machine guns. And some simply picked up whatever they could find, with some resorting to sticks, he said. Speaking to CNN during the battle, he said, "People are willing to die for the cause," describing them as "fearless" and "amazing." Later, after the | NEW: Videos online show damage to buildings and waving of the opposition's flag.\nA doctor at a hospital in the city says 42 people were killed, 85 wounded.\nWitness in Misrata: "Everyone is hugging everyone" despite "blood everywhere"\nPro-Gadhafi demonstrators in Tripoli claimed the government had taken the city. |
| 1 | I have an uncle who has always been a robust and healthy guy. He drank a glass of skim milk every day, bragged about how many pull-ups he was doing and fit into pants he was wearing 20 years before. He didn't take a single medication and retired early. Given that he had no medical problems and ran his own business, he opted to go several years without health insurance. Eventually, when he turned 65, he picked up Medicare. What happened next was a little strange. He fell off the wagon. He exercised only sporadically, and paid hardly any attention to what he was eating. One day, I saw him eat an entire bag of potato chips. He bemoaned the fact that he was forced to buy new, bigger pants, and he stopped drinking his milk. For him, becoming newly insured had nearly the opposite effect on him of what we doctors hope to achieve. He'd become unhealthier. In many ways, my uncle was demonstrating a concept known as the moral hazard. Two economists wrote about this exact scenario in 2006. They found that many men, at the time they obtained Medicare, started behaving badly. Moral, or morale, hazard is a term largely used by economists to describe the actions of people more willing to take risks because they are insulated from the cost of their actions, in this case because of their recently obtained health insurance. In the case of these men, when they got Medicare, they took worse care of themselves; they actually exercised less. Among those who didn't visit the doctor after getting insurance, the effect was dramatic: Their overall physical activity dropped by 40%; they were 16% more likely to smoke cigarettes and 32% more likely to drink alcohol. Even if that seems extreme, it's still worth asking: Does health insurance make us healthier? The past five years have seen a tumultuous battle over Obamacare, or the Affordable Care Act, culminating in the bitter recriminations this fall over lost policies and the disastrous launch of the HealthCare.gov website. When I interviewed Health and Human Services Secretary Kathleen Sebelius at the end of October, she downplayed the concerns and seemed certain the site would be up and running by the end of November. The website may be working better now, but to me that's not the most important issue. In my mind, the real suspense comes from whether Obamacare will really make us a healthier America, even if it succeeds in its ambitions to dramatically expand coverage. A healthier America: That is the goal we should share as Americans | Sanjay Gupta: Moral hazard causes some to neglect health when they get health insurance.\nHe says Obamacare alone won't guarantee good health; personal habits must do that.\nHe says research shows 30 minutes of daily exercise cuts heart attack, stroke risk by a third.\nGupta: It's time to stop playing defense on your health; instead, start optimizing it yourself. |
{% endraw %}

Tests

The tests below ensure that the core DataBlock code above works for all pretrained summarization models available in huggingface. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself).

{% raw %}
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')
[transformers.modeling_bart.BartForConditionalGeneration,
 transformers.modeling_mbart.MBartForConditionalGeneration,
 transformers.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.modeling_t5.T5ForConditionalGeneration]
{% endraw %} {% raw %}
pretrained_model_names = [
    ('facebook/bart-base',BartForConditionalGeneration),
    ('t5-small', T5ForConditionalGeneration),
    ('google/pegasus-cnn_dailymail', PegasusForConditionalGeneration)
]
{% endraw %} {% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
{% endraw %} {% raw %}
#hide_output
task = HF_TASKS_ALL.ConditionalGeneration
bsz = 2

test_results = []
for model_name, model_cls in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, 
                                                                                   task=task,
                                                                                   model_cls=model_cls)
    
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')
    
    hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)

    blocks = ( 
        HF_TextBlock(hf_arch, hf_tokenizer, padding='max_length', max_length=256), 
        HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, padding='max_length', max_length=50, 
                     hf_input_idxs=[0,1])
    )

    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp

    dblock = DataBlock(blocks=blocks, 
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]), 
                       get_y=ColReader('highlights'), 
                       splitter=RandomSplitter())

    try:
        print('*** TESTING DataLoaders ***\n')

        # build the DataLoaders inside the `try` so a failing model is recorded
        # in `test_results` instead of aborting the whole loop
        dls = dblock.dataloaders(cnndm_df, bs=bsz)
        b = dls.one_batch()

        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, 256]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz,50]))

        if (hasattr(hf_tokenizer, 'add_prefix_space')):
            test_eq(dls.tfms[0].kwargs['add_prefix_space'], True)
            
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))
        dls.show_batch(dataloaders=dls, max_n=2)
        
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))
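{% endraw %}

The summary table below is presumably produced from test_results along these lines (the column names are inferred from the output):

{% raw %}
pd.DataFrame(test_results, columns=['arch', 'tokenizer', 'model_name', 'result', 'error'])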
{% endraw %} {% raw %}
| | arch | tokenizer | model_name | result | error |
| --- | --- | --- | --- | --- | --- |
| 0 | bart | BartTokenizer | facebook/bart-base | PASSED | |
| 1 | t5 | T5Tokenizer | t5-small | PASSED | |
| 2 | pegasus | PegasusTokenizer | google/pegasus-cnn_dailymail | PASSED | |
{% endraw %}

Cleanup