--- title: modeling.seq2seq.core keywords: fastai sidebar: home_sidebar summary: "This module contains core custom models, loss functions, etc... for Seq2Seq based tasks (e.g., language modeling, summarization, translation, etc...)" description: "This module contains core custom models, loss functions, etc... for Seq2Seq based tasks (e.g., language modeling, summarization, translation, etc...)" nb_path: "nbs/02za_modeling-seq2seq-core.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
 
{% endraw %} {% raw %}
{% endraw %} {% raw %}
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
Using GPU #1: GeForce GTX 1080 Ti
{% endraw %}

Seq2Seq

{% raw %}
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')

cnndm_df.head(2)
article highlights ds_type
0 (CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el... John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in "century of knowledge" train
1 (CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will "hopefully bring some order" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ... NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says . train
{% endraw %} {% raw %}
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)
{% endraw %} {% raw %}
before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                  max_length=256, max_target_length=130)

blocks = (HF_Seq2SeqBlock(before_batch_tfm=before_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())
{% endraw %} {% raw %}
dls = dblock.dataloaders(cnndm_df, bs=2)
{% endraw %} {% raw %}
b = dls.one_batch()
{% endraw %} {% raw %}
len(b), b[0]['input_ids'].shape, b[1].shape
(2, torch.Size([2, 256]), torch.Size([2, 84]))
{% endraw %} {% raw %}
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 (CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 varieties of birds and 26,000 classifications of plants. Pronatura, a non-profit organization that works to promote conservation and sustainable development in Mexico, has selected six species which it says symbolize the problems faced by the Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is under threat in spite of being the country's national symbol.
1 (CNN) -- It's a congested, sprawling transport hub surrounded by 1950s architecture and predominantly used by commuters or tourists to cross the city of Istanbul. But proposed changes to Taksim Square have seen it become the flashpoint for protests that have swept through Turkey in the past week, leaving thousands injured and focusing the world's attention on the government of Prime Minister Recep Tayyip Erdogan. Taksim has been no stranger to violence. In 1977, at least 34 protesters died during May Day clashes with police. May 1 rallies in the square were banned in 1980 and were only allowed to legally resume in 2010. On May Day this year, there were riots after city authorities again refused to grant trade unions and youth groups permission to demonstrate in Taksim, blaming construction work being carried out in the square. Professor Ersin Kalaycioglu, professor of political science at Istanbul's Sabanci University, said significantly, Taksim Square was also known as "republic square," because it was built by the Republic of Turkey's founding fathers to commemorate the war of liberation. "Taksim Square is connected to Istiklal Caddesi -- Independence Avenue -- and Cumhuriyet Caddesi -- the Avenue of the Republic. So there Taksim Square was where Istanbul's water was distributed -- Taksim means divide.\nThe site is seen as symbolizing the seclar Turkish republic founded by Ataturk.\nErdogan's government's plans to alter Taksim's Gezi Park prompted protests.\nThe police's heavy-handed response saw demonstrators' numbers surge.
{% endraw %}

Training

Here we create a Seq2Seq-specific subclass of HF_BaseModelCallback in order to include custom Seq2Seq metrics, and also to handle the pre-calculated loss during training
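The "pre-calculated loss" refers to the loss that Hugging Face models compute themselves when labels are passed in; target positions set to the ignore_token_id (-100) are excluded from it. The plain-Python sketch below illustrates that masking idea only; `masked_nll` is a hypothetical helper, not blurr's or PyTorch's actual implementation:

```python
import math

IGNORE_TOKEN_ID = -100  # positions with this label are excluded from the loss


def masked_nll(logits, labels, ignore_id=IGNORE_TOKEN_ID):
    """Mean negative log-likelihood over non-ignored positions.

    logits: list of per-position score lists; labels: list of target ids.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_id:
            continue  # padding / ignored target position contributes nothing
        z = sum(math.exp(s) for s in scores)      # softmax denominator
        total += -(scores[label] - math.log(z))   # -log softmax(scores)[label]
        count += 1
    return total / max(count, 1)
```

Appending a padded position labeled -100 leaves the loss unchanged, which is exactly why padding targets with the ignore id is safe.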

seq2seq_metrics

  • {'rouge': { 'compute_kwargs': {'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True}, 'returns': ["rouge1", "rouge2", "rougeL"] }}
  • {'bertscore': { 'compute_kwargs': { 'lang': 'en' }, 'returns': ["precision", "recall", "f1"] }}
  • {'bleu': { 'returns': "bleu" }}
  • {'bleurt': { 'returns': "scores" }}
  • {'meteor': { 'returns': "meteor" }}
  • {'sacrebleu': { 'returns': "score" }}
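To see how such a config turns into the flat metric columns reported during training (e.g. bertscore_precision alongside rouge1), here is one way it could be flattened; `metric_column_names` is a hypothetical helper for illustration, not blurr's own logic:

```python
def metric_column_names(config):
    """Flatten a {metric: {'returns': ...}} config into flat column names.

    Multi-valued metrics get their metric name prefixed unless the return
    name already contains it (e.g. 'rouge1' stays as-is).
    """
    cols = []
    for metric, spec in config.items():
        returns = spec['returns']
        if isinstance(returns, str):
            cols.append(metric)  # single-valued metric: one column
        else:
            for r in returns:    # multi-valued metric: one column per value
                cols.append(r if metric in r else f'{metric}_{r}')
    return cols
```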
{% raw %}

class HF_Seq2SeqMetricsCallback[source]

HF_Seq2SeqMetricsCallback(custom_metrics=None, ignore_token_id=-100, text_gen_kwargs={}, **kwargs) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

{% endraw %} {% raw %}
{% endraw %}

We add a custom param splitter to give us finer-grained control when applying discriminative learning rates to Seq2Seq models.
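As a rough illustration of what a param splitter enables, the hypothetical helpers below group parameter names into embedding/encoder/decoder buckets and assign each group a progressively larger learning rate, ULMFiT-style. This is a sketch of the idea only, not seq2seq_splitter's actual grouping:

```python
def split_param_names(named_params):
    """Group parameter names into embedding / encoder / decoder buckets
    so each group can receive its own learning rate."""
    groups = {'embeddings': [], 'encoder': [], 'decoder': []}
    for name in named_params:
        if 'embed' in name:
            groups['embeddings'].append(name)
        elif 'encoder' in name:
            groups['encoder'].append(name)
        else:
            groups['decoder'].append(name)  # decoder layers + lm head
    return list(groups.values())


def discriminative_lrs(n_groups, lr_max, factor=2.6):
    """Earlier (lower) groups get smaller LRs; the last group gets lr_max."""
    return [lr_max / factor ** (n_groups - 1 - i) for i in range(n_groups)]
```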

{% raw %}
{% endraw %} {% raw %}

seq2seq_splitter[source]

seq2seq_splitter(m, arch)

Custom param splitter for summarization models

{% endraw %} {% raw %}
seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': {
            'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True
        }, 
        'returns': ["rouge1", "rouge2", "rougeL"] 
    },
    'bertscore': {
        'compute_kwargs': { 'lang': 'en' },
        'returns': ["precision", "recall", "f1"]
    }, 
    'bleu': { 'returns': "bleu" },
    'meteor': { 'returns': "meteor" },
    'sacrebleu': { 'returns': "score" }
}

model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=Adam,
                loss_func=CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)) #.to_native_fp16() #.to_fp16()

learn.freeze()
{% endraw %} {% raw %}
b = dls.one_batch()
preds = learn.model(b[0])

len(preds),preds['loss'].shape, preds['logits'].shape
(4, torch.Size([]), torch.Size([2, 69, 50264]))
{% endraw %} {% raw %}
print(len(learn.opt.param_groups))
3
{% endraw %} {% raw %}
learn.lr_find(suggestions=True)
SuggestedLRs(lr_min=8.317637839354575e-05, lr_steep=6.309573450380412e-07)
{% endraw %} {% raw %}
learn.fit_one_cycle(1, lr_max=4e-5, cbs=fit_cbs)
epoch train_loss valid_loss rouge1 rouge2 rougeL bertscore_precision bertscore_recall bertscore_f1 bleu meteor sacrebleu time
0 1.835147 1.645218 0.390021 0.174499 0.265284 0.875935 0.894593 0.885088 0.151584 0.309012 12.513763 03:19
{% endraw %}
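For intuition about the rouge columns above: ROUGE-1 is roughly a unigram-overlap F1 between the generated summary and the reference. The sketch below is a simplified version for illustration only; it does no stemming and none of the rouge_score package's tokenization:

```python
from collections import Counter


def rouge1_f(prediction, reference):
    """Simplified ROUGE-1 F1: unigram overlap between prediction and reference."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    overlap = sum((Counter(pred) & Counter(ref)).values())  # clipped counts
    if not overlap:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)
```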

Showing results

Below we'll add in additional functionality to take advantage of huggingface's PreTrainedModel.generate method, which can be used to easily implement beam search, top-k/nucleus sampling, etc... so that we get more human-sounding results.
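To make the num_beams argument used below concrete, here is a minimal beam search over a toy next-token distribution. It is a sketch of the algorithm only, not huggingface's implementation; `step_fn` and `NEXT` are hypothetical stand-ins for a real model's next-token scores:

```python
import math


def beam_search(step_fn, start, num_beams=4, max_length=5, eos=None):
    """Minimal beam search: step_fn(seq) returns (token, logprob) candidates."""
    beams = [(start, 0.0)]  # (sequence, cumulative log-prob)
    for _ in range(max_length):
        candidates = []
        for seq, score in beams:
            if eos is not None and seq and seq[-1] == eos:
                candidates.append((seq, score))  # finished beam carries over
                continue
            for tok, logp in step_fn(seq):
                candidates.append((seq + [tok], score + logp))
        # keep only the num_beams highest-scoring (partial) sequences
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:num_beams]
    return beams[0][0]


# Toy distribution where greedy decoding is suboptimal: token 0 looks best
# at step one, but token 1 leads to a much more probable continuation.
NEXT = {
    None: [(0, math.log(0.6)), (1, math.log(0.4))],
    0: [(2, math.log(0.1))],
    1: [(2, math.log(0.9))],
    2: [(9, 0.0)],  # log(1.0): always ends with eos token 9
}


def step_fn(seq):
    return NEXT[seq[-1] if seq else None]
```

With num_beams=1 this reduces to greedy decoding and follows the locally best token 0; with num_beams=2 the search keeps the weaker first step alive and recovers the globally better sequence starting with token 1.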

{% raw %}
test_article = """
About 10 men armed with pistols and small machine guns raided a casino in Switzerland and made off 
into France with several hundred thousand Swiss francs in the early hours of Sunday morning, police said. 
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino 
Basel, Chief Inspector Peter Gill told CNN. One group tried to break into the casino's vault on the lower level 
but could not get in, but they did rob the cashier of the money that was not secured, he said. The second group 
of armed robbers entered the upper level where the roulette and blackjack tables are located and robbed the 
cashier there, he said. As the thieves were leaving the casino, a woman driving by and unaware of what was 
occurring unknowingly blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat 
her, and took off for the French border. The other gunmen followed into France, which is only about 100 
meters (yards) from the casino, Gill said. There were about 600 people in the casino at the time of the robbery. 
There were no serious injuries, although one guest on the Casino floor was kicked in the head by one of the 
robbers when he moved, the police officer said. Swiss authorities are working closely with French authorities, 
Gill said. The robbers spoke French and drove vehicles with French lRicense plates. CNN's Andreena Narayan 
contributed to this report.
"""
{% endraw %} {% raw %}
res = learn.blurr_predict(test_article)
print(hf_tokenizer.decode(res[0][0][0][:20]))
<s><s>                About 10 men armed with pistols and machine machine guns raid a casino in Switzerland. made
{% endraw %}

That doesn't look much like human-generated text. Let's use huggingface's PreTrainedModel.generate method to create something more human-like.

{% raw %}
b = dls.valid.one_batch()

b_before_batch_tfm = get_blurr_tfm(dls.before_batch)

b_hf_tokenizer = b_before_batch_tfm.hf_tokenizer
b_ignore_token_id = b_before_batch_tfm.ignore_token_id

test_input_ids = b[0]['input_ids'][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = b[1][0].unsqueeze(0).to(learn.model.hf_model.device)
test_trg_ids = [ trg[trg != b_ignore_token_id] for trg in test_trg_ids ]

gen_text = learn.model.hf_model.generate(test_input_ids, num_beams=4, max_length=130, min_length=30)

print('=== Target ===')
print(f'{b_hf_tokenizer.decode(test_trg_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)}\n')

print('=== Prediction ===')
print(b_hf_tokenizer.decode(gen_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
=== Target ===
 Consider U.S. efforts to offer Afghan citizens an alternative to the Taliban.
Hear how a proposed health care bill addresses the issue of the public option.
Meet a soldier who is making history at the U.S. Army Drill Sergeant School.
Use the Daily Discussion to help students understand today's featured news stories.

=== Prediction ===
 Find out how a member of the military is making history.
President Obama meets with advisers to discuss the U.S. strategy in Afghanistan and Pakistan.
Learn about a crash that led to the deaths of more than 100 people.
Use the Daily Discussion to help students understand today's featured news stories.
{% endraw %}

We'll add a blurr_generate method to Learner that uses huggingface's PreTrainedModel.generate to create our predictions. For the full list of arguments you can pass in see here. You can also check out their "How To Generate" notebook for more information about how it all works.

{% raw %}
{% endraw %} {% raw %}

Learner.blurr_generate[source]

Learner.blurr_generate(inp, task=None, **kwargs)

Uses the built-in generate method to generate the text (see here for a list of arguments you can pass in)

{% endraw %} {% raw %}
outputs = learn.blurr_generate(test_article, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 Police say about 10 men robbed the Grand Casino Basel in the early hours of Sunday morning .
About 600 people were in the casino at the time of the robbery .
There were no serious injuries, but one guest was kicked in the head by one of the robbers .
The robbers made off with several hundred thousand Swiss francs from the casino, police say .

=== Prediction 2 ===
 Police say about 10 men robbed the Grand Casino Basel in the early hours of Sunday morning .
About 600 people were in the casino at the time of the robbery .
There were no serious injuries, but one guest was kicked in the head by one of the robbers .
The robbers made off with several hundred thousand Swiss francs, police say .

=== Prediction 3 ===
 Police say about 10 men robbed the Grand Casino Basel in the early hours of Sunday morning .
About 600 people were in the casino at the time of the robbery .
There were no serious injuries, but one guest was kicked in the head by one of the robbers .
The robbers made off with several hundred thousand Swiss francs from the casino .

{% endraw %}

Much nicer!!! Now, we can update our @typedispatched show_results to use this new method.

{% raw %}
{% endraw %} {% raw %}
learn.show_results(learner=learn, input_trunc_at=500, target_trunc_at=250)
text target prediction
0 (CNN Student News) -- October 27, 2009. Downloadable Maps. Download PDF maps related to today's show: • Afghanistan & Pakistan • Los Angeles & San Diego • Ft. Jackson, South Carolina. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. NATISHA LANCE, CNN STUDENT NEWS ANCHOR: A member of the military is making history. We'll explain how in today's edition of CNN Student News. Hi, everyone. Carl Azuz is off this week. I'm Natisha Lance. First Up: Afg Consider U.S. efforts to offer Afghan citizens an alternative to the Taliban.\nHear how a proposed health care bill addresses the issue of the public option.\nMeet a soldier who is making history at the U.S. Army Drill Sergeant School.\nUse the Daily D Find out how a member of the military is making history .\nPresident Obama meets with advisers to discuss the U.S. strategy in Afghanistan and Pakistan .\nLearn about a crash that led to the deaths of more than 100 people .\nUse the Daily Discussion to
1 Michael Zehaf-Bibeau, the 32-year-old gunman who attacked the Canadian Parliament and killed a soldier last week, had a familiar profile. It is that of a young man alienated from mainstream society, with few friends, without a steady job, drifting from one place to another -- often with a history of petty crime and drug abuse. Then comes the conversion to or rediscovery of Islam, and the adoption of a jihadist mindset, fed by media and online coverage of the West's involvement in wars in Iraq a Like many "lone wolf" terrorists, Ottawa gunman was alienated drifter who converted to Islam.\nConversion to militant Islam is often about seeking identity, purpose or adventure.\nSome countries have tried "de-radicalization" programs to help prevent Michael Zehaf-Bibeau, the gunman who attacked the Canadian Parliament last week, had a familiar profile .\nHe was alienated from mainstream society, with few friends, without a steady job, drifting from one place to another .\nHis conversion to milita
{% endraw %}

Inference

{% raw %}
export_fname = 'summarize_export'
{% endraw %} {% raw %}
learn.metrics = None
learn.export(fname=f'{export_fname}.pkl')
{% endraw %} {% raw %}
inf_learn = load_learner(fname=f'{export_fname}.pkl')
inf_learn.blurr_generate(test_article)
[' Police say about 10 men robbed the Grand Casino Basel in the early hours of Sunday morning .\nAbout 600 people were in the casino at the time of the robbery .\nThere were no serious injuries, but one guest was kicked in the head by one of the robbers .\nThe robbers made off with several hundred thousand Swiss francs from the casino, police say .']
{% endraw %}

Cleanup