---
title: data.summarization
keywords: fastai
sidebar: home_sidebar
summary: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5."
description: "This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5."
nb_path: "nbs/01e_data-summarization.ipynb"
---
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)
cnndm_df.head(2)
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name,
                                                                               model_cls=BartForConditionalGeneration)
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)
We create a subclass of `HF_BatchTransform` for summarization tasks to add `decoder_input_ids` and `labels` to our inputs during training, which in turn allows the huggingface model to calculate the loss for us. See here for more information on how these additional inputs are used in summarization and conversational training tasks. Note also that `labels` is simply the target `input_ids` shifted to the right by one, since the task is to predict the next token based on the current (and all previous) `decoder_input_ids`.
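For illustration, here's a minimal sketch of that right-shift relationship. The token ids and the decoder start token below are assumptions for BART-style models, not blurr's actual implementation:

```python
import torch

decoder_start_token_id = 2                            # assumption: </s> (id 2) starts decoding for BART
target_ids = torch.tensor([[0, 8774, 19, 1176, 2]])   # illustrative target token ids

labels = target_ids                                   # what the model is asked to predict
decoder_input_ids = torch.cat([
    torch.full((target_ids.size(0), 1), decoder_start_token_id, dtype=torch.long),
    target_ids[:, :-1]                                # the targets, shifted right by one position
], dim=-1)
```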
And lastly, we also update our targets to just be the `input_ids` of our target sequence so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).
We had to override the `decodes` method above because, while both our inputs and targets are technically the same things, we update the latter to consist of only the target `input_ids` so that methods like `Learner.show_results` work.
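The gist of that override is captured by this hedged sketch (the names and structure are assumptions, not a copy of blurr's source):

```python
# A sketch of the `decodes` idea (not blurr's actual source): whether the
# encoded sample arrives as the input dict or as a bare tensor of target ids,
# hand fastai back a single tensor of token ids it knows how to display.
def decodes_sketch(encoded_samples):
    if isinstance(encoded_samples, dict):
        return encoded_samples['input_ids']  # inputs: pull the ids out of the dict
    return encoded_samples                   # targets: already just the ids
```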
blocks = (HF_TextBlock(hf_batch_tfm=HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
Two lines! Notice we pass in `noop` for our targets (i.e., our summaries) because the batch transform will take care of both our inputs and targets.
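As a quick illustration of what the two readers feed into that pipeline (plain pandas, nothing blurr-specific):

```python
# ColReader('article') and ColReader('highlights') simply pull these columns
# from each row; the batch transform then tokenizes both sides together.
row = cnndm_df.iloc[0]
x_text, y_text = row['article'], row['highlights']
print(x_text[:100], '...')
print(y_text[:100], '...')
```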
dls = dblock.dataloaders(cnndm_df, bs=4)
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
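If you want to poke at the batch a bit more, a few sanity checks along these lines can help (a sketch; the key names follow the huggingface convention used above):

```python
# The batch is an (inputs, targets) tuple: inputs are a dict of tensors that now
# includes `decoder_input_ids` and `labels`; targets are a single tensor of ids.
assert len(b) == 2
assert {'input_ids', 'attention_mask'}.issubset(b[0].keys())
assert b[0]['input_ids'].size(0) == 4    # our batch size
assert b[1].ndim == 2                    # targets: [bs, target seq len]
```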
dls.show_batch(dataloaders=dls, max_n=2)
The tests below ensure the core DataBlock code above works for all pretrained summarization models available in huggingface. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.
Note: Feel free to modify the code below to test whatever pretrained summarization models you are working with ... and if any of your pretrained summarization models fail, please submit a GitHub issue (or a PR if you'd like to fix it yourself)
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')
pretrained_model_names = [
    ('facebook/bart-base', BartForConditionalGeneration),
    ('t5-small', T5ForConditionalGeneration),
    ('google/pegasus-cnn_dailymail', PegasusForConditionalGeneration)
]
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
#hide_output
task = HF_TASKS_ALL.ConditionalGeneration
bsz = 2
seq_sz = 256
trg_seq_sz = 40
test_results = []
for model_name, model_cls in pretrained_model_names:
    error = None
    print(f'=== {model_name} ===\n')

    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name,
                                                                                   task=task,
                                                                                   model_cls=model_cls)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')

    # Pad/truncate inputs and targets to fixed lengths so batch shapes are predictable
    hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer,
                                                  padding='max_length', max_length=[seq_sz, trg_seq_sz])

    blocks = (
        HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm),
        noop
    )

    # T5 models expect a task prefix (e.g., "summarize: ") in front of the source text
    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp

    dblock = DataBlock(blocks=blocks,
                       get_x=Pipeline([ColReader('article'), add_t5_prefix]),
                       get_y=ColReader('highlights'),
                       splitter=RandomSplitter())

    dls = dblock.dataloaders(cnndm_df, bs=bsz)
    b = dls.one_batch()

    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz]))

        if (hasattr(hf_tokenizer, 'add_prefix_space')):
            test_eq(dls.before_batch[0].tok_kwargs['add_prefix_space'], True)

        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))
        dls.show_batch(dataloaders=dls, max_n=2)

    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))
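With the loop done, something like the following turns `test_results` into a readable report (the column names are my own labeling, not part of blurr):

```python
# Render the pass/fail results collected above as a DataFrame
print(pd.DataFrame(test_results, columns=['arch', 'tokenizer', 'model_name', 'result', 'error']))
```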