Module ktrain.text.summarization.core
Expand source code
from ...torch_base import TorchBase
class TransformerSummarizer(TorchBase):
    """
    interface to Transformer-based text summarization
    """

    def __init__(self, model_name='facebook/bart-large-cnn', device=None):
        """
        ```
        interface to BART-based text summarization using transformers library
        Args:
          model_name(str): name of BART model for summarization
          device(str): device to use (e.g., 'cuda', 'cpu')
        Raises:
          ValueError: if 'bart' does not appear in model_name (case-insensitive)
        ```
        """
        # Compare case-insensitively so names such as 'BART-large' are not rejected.
        if 'bart' not in model_name.lower():
            raise ValueError('TransformerSummarizer currently only accepts BART models')
        super().__init__(device=device)
        # Imported here rather than at module level, so transformers is only
        # loaded when a summarizer is actually constructed.
        from transformers import BartTokenizer, BartForConditionalGeneration
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # NOTE(review): self.torch_device is presumably set by TorchBase.__init__ - confirm.
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.torch_device)

    def summarize(self, doc):
        """
        ```
        summarize document text
        Args:
          doc(str): text of document
        Returns:
          str: summary text ('' if doc is empty or contains only whitespace)
        ```
        """
        # Nothing to summarize: skip generation entirely, since generate() below
        # is asked for min_length=56 tokens even when there is no content.
        if isinstance(doc, str) and not doc.strip():
            return ''
        import torch
        with torch.no_grad():
            # Call the tokenizer directly (the documented replacement for the
            # deprecated batch_encode_plus); input is truncated to 1024 tokens.
            encoded = self.tokenizer([doc],
                                     return_tensors='pt',
                                     truncation=True,
                                     max_length=1024)
            input_ids = encoded['input_ids'].to(self.torch_device)
            summary_ids = self.model.generate(input_ids,
                                              num_beams=4,
                                              length_penalty=2.0,
                                              max_length=142,
                                              min_length=56,
                                              no_repeat_ngram_size=3)
            # A single document was encoded, so take the first (only) sequence
            # explicitly instead of squeezing away every size-1 dimension.
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
Classes
class TransformerSummarizer (model_name='facebook/bart-large-cnn', device=None)
-
interface to Transformer-based text summarization
Interface to BART-based text summarization using the transformers library.

Args:
    model_name(str): name of BART model for summarization
    device(str): device to use (e.g., 'cuda', 'cpu')
Expand source code
class TransformerSummarizer(TorchBase):
    """
    interface to Transformer-based text summarization
    """

    def __init__(self, model_name='facebook/bart-large-cnn', device=None):
        """
        ```
        interface to BART-based text summarization using transformers library
        Args:
          model_name(str): name of BART model for summarization
          device(str): device to use (e.g., 'cuda', 'cpu')
        Raises:
          ValueError: if 'bart' does not appear in model_name (case-insensitive)
        ```
        """
        # Case-insensitive check so that names like 'BART-large' pass validation.
        if 'bart' not in model_name.lower():
            raise ValueError('TransformerSummarizer currently only accepts BART models')
        super().__init__(device=device)
        # Function-scope import: transformers is loaded only on construction.
        from transformers import BartTokenizer, BartForConditionalGeneration
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # NOTE(review): self.torch_device is presumably provided by TorchBase - confirm.
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.torch_device)

    def summarize(self, doc):
        """
        ```
        summarize document text
        Args:
          doc(str): text of document
        Returns:
          str: summary text ('' if doc is empty or contains only whitespace)
        ```
        """
        # An empty document has nothing to summarize; return early rather than
        # requesting min_length=56 generated tokens for no input content.
        if isinstance(doc, str) and not doc.strip():
            return ''
        import torch
        with torch.no_grad():
            # Tokenizer __call__ replaces the deprecated batch_encode_plus;
            # the document is truncated to at most 1024 tokens.
            encoded = self.tokenizer([doc],
                                     return_tensors='pt',
                                     truncation=True,
                                     max_length=1024)
            input_ids = encoded['input_ids'].to(self.torch_device)
            summary_ids = self.model.generate(input_ids,
                                              num_beams=4,
                                              length_penalty=2.0,
                                              max_length=142,
                                              min_length=56,
                                              no_repeat_ngram_size=3)
            # Only one document was submitted, so decode the first sequence.
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
Ancestors
Methods
def summarize(self, doc)
-
Summarize document text.

Args:
    doc(str): text of document

Returns:
    str: summary text
Expand source code
def summarize(self, doc):
    """
    ```
    summarize document text
    Args:
      doc(str): text of document
    Returns:
      str: summary text ('' if doc is empty or contains only whitespace)
    ```
    """
    # Nothing to summarize: return before touching the model, because the
    # generate() call below requests min_length=56 tokens regardless of input.
    if isinstance(doc, str) and not doc.strip():
        return ''
    import torch
    with torch.no_grad():
        # Call the tokenizer directly (documented replacement for the
        # deprecated batch_encode_plus); input is truncated to 1024 tokens.
        encoded = self.tokenizer([doc],
                                 return_tensors='pt',
                                 truncation=True,
                                 max_length=1024)
        input_ids = encoded['input_ids'].to(self.torch_device)
        summary_ids = self.model.generate(input_ids,
                                          num_beams=4,
                                          length_penalty=2.0,
                                          max_length=142,
                                          min_length=56,
                                          no_repeat_ngram_size=3)
        # One document in, one sequence out: index it explicitly rather than
        # squeezing away every size-1 dimension.
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
Inherited members