Module ktrain.text.summarization.core

Expand source code
# 2020-08-10: unnecessary imports removed for ZSL to address #225
#from ...imports import *
#from ... import utils as U

class TransformerSummarizer():
    """
    interface to Transformer-based text summarization
    """

    def __init__(self, model_name='facebook/bart-large-cnn', device=None):
        """
        ```
        interface to BART-based text summarization using transformers library

        Args:
          model_name(str): name of BART model for summarization; it must
                           contain 'bart' (matched case-insensitively)
          device(str): device to use (e.g., 'cuda', 'cpu'). If None, 'cuda'
                       is used when torch reports it available, else 'cpu'.
        Raises:
          ValueError: if model_name does not contain 'bart'
          ImportError: if PyTorch is not installed
        ```
        """
        # Check the name first so an unsupported model fails before any
        # heavy import or model download is attempted.
        if 'bart' not in model_name.lower():
            raise ValueError('TransformerSummarizer currently only accepts BART models')
        try:
            import torch
        except ImportError as err:
            # ImportError is still an Exception subclass, so callers catching
            # Exception keep working; chaining preserves the original cause.
            raise ImportError('TransformerSummarizer requires PyTorch to be installed.') from err
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.torch_device = device
        # Kept function-local, like the torch import above.
        from transformers import BartTokenizer, BartForConditionalGeneration
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.torch_device)

    def summarize(self, doc, *, num_beams=4, length_penalty=2.0,
                  max_length=142, min_length=56, no_repeat_ngram_size=3):
        """
        ```
        summarize document text
        Args:
          doc(str): text of document
          num_beams(int): forwarded to model.generate
          length_penalty(float): forwarded to model.generate
          max_length(int): forwarded to model.generate (limit on the summary)
          min_length(int): forwarded to model.generate
          no_repeat_ngram_size(int): forwarded to model.generate
        Returns:
          str: summary text
        ```
        """
        import torch
        with torch.no_grad():
            # The document is encoded as a batch of one and truncated to
            # 1024 tokens before being moved to the model's device.
            encoded = self.tokenizer.batch_encode_plus([doc],
                                                       return_tensors='pt',
                                                       truncation=True,
                                                       max_length=1024)
            input_ids = encoded['input_ids'].to(self.torch_device)
            summary_ids = self.model.generate(input_ids,
                                              num_beams=num_beams,
                                              length_penalty=length_penalty,
                                              max_length=max_length,
                                              min_length=min_length,
                                              no_repeat_ngram_size=no_repeat_ngram_size)
            # Only one document was submitted, so take row 0 explicitly;
            # squeeze() would also collapse a length-1 sequence dimension.
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

Classes

class TransformerSummarizer (model_name='facebook/bart-large-cnn', device=None)

interface to Transformer-based text summarization

interface to BART-based text summarization using transformers library

Args:
  model_name(str): name of BART model for summarization
  device(str): device to use (e.g., 'cuda', 'cpu')
Expand source code
class TransformerSummarizer():
    """
    interface to Transformer-based text summarization
    """

    def __init__(self, model_name='facebook/bart-large-cnn', device=None):
        """
        ```
        interface to BART-based text summarization using transformers library

        Args:
          model_name(str): name of BART model for summarization; it must
                           contain 'bart' (matched case-insensitively)
          device(str): device to use (e.g., 'cuda', 'cpu'). If None, 'cuda'
                       is used when torch reports it available, else 'cpu'.
        Raises:
          ValueError: if model_name does not contain 'bart'
          ImportError: if PyTorch is not installed
        ```
        """
        # Check the name first so an unsupported model fails before any
        # heavy import or model download is attempted.
        if 'bart' not in model_name.lower():
            raise ValueError('TransformerSummarizer currently only accepts BART models')
        try:
            import torch
        except ImportError as err:
            # ImportError is still an Exception subclass, so callers catching
            # Exception keep working; chaining preserves the original cause.
            raise ImportError('TransformerSummarizer requires PyTorch to be installed.') from err
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.torch_device = device
        # Kept function-local, like the torch import above.
        from transformers import BartTokenizer, BartForConditionalGeneration
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.torch_device)

    def summarize(self, doc, *, num_beams=4, length_penalty=2.0,
                  max_length=142, min_length=56, no_repeat_ngram_size=3):
        """
        ```
        summarize document text
        Args:
          doc(str): text of document
          num_beams(int): forwarded to model.generate
          length_penalty(float): forwarded to model.generate
          max_length(int): forwarded to model.generate (limit on the summary)
          min_length(int): forwarded to model.generate
          no_repeat_ngram_size(int): forwarded to model.generate
        Returns:
          str: summary text
        ```
        """
        import torch
        with torch.no_grad():
            # The document is encoded as a batch of one and truncated to
            # 1024 tokens before being moved to the model's device.
            encoded = self.tokenizer.batch_encode_plus([doc],
                                                       return_tensors='pt',
                                                       truncation=True,
                                                       max_length=1024)
            input_ids = encoded['input_ids'].to(self.torch_device)
            summary_ids = self.model.generate(input_ids,
                                              num_beams=num_beams,
                                              length_penalty=length_penalty,
                                              max_length=max_length,
                                              min_length=min_length,
                                              no_repeat_ngram_size=no_repeat_ngram_size)
            # Only one document was submitted, so take row 0 explicitly;
            # squeeze() would also collapse a length-1 sequence dimension.
            return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

Methods

def summarize(self, doc)
summarize document text
Args:
  doc(str): text of document
Returns:
  str: summary text
Expand source code
def summarize(self, doc, *, num_beams=4, length_penalty=2.0,
              max_length=142, min_length=56, no_repeat_ngram_size=3):
    """
    ```
    summarize document text
    Args:
      doc(str): text of document
      num_beams(int): forwarded to model.generate
      length_penalty(float): forwarded to model.generate
      max_length(int): forwarded to model.generate (limit on the summary)
      min_length(int): forwarded to model.generate
      no_repeat_ngram_size(int): forwarded to model.generate
    Returns:
      str: summary text
    ```
    """
    import torch
    with torch.no_grad():
        # The document is encoded as a batch of one and truncated to
        # 1024 tokens before being moved to the model's device.
        encoded = self.tokenizer.batch_encode_plus([doc],
                                                   return_tensors='pt',
                                                   truncation=True,
                                                   max_length=1024)
        input_ids = encoded['input_ids'].to(self.torch_device)
        summary_ids = self.model.generate(input_ids,
                                          num_beams=num_beams,
                                          length_penalty=length_penalty,
                                          max_length=max_length,
                                          min_length=min_length,
                                          no_repeat_ngram_size=no_repeat_ngram_size)
        # Only one document was submitted, so take row 0 explicitly;
        # squeeze() would also collapse a length-1 sequence dimension.
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)