Module ktrain.text.summarization.core

Expand source code
from ...torch_base import TorchBase


class TransformerSummarizer(TorchBase):
    """
    Interface to Transformer-based text summarization.
    """

    def __init__(self, model_name="facebook/bart-large-cnn", device=None):
        """
        ```
        Interface to BART-based text summarization using the transformers library.

        Args:
          model_name(str): name of BART model for summarization
          device(str): device to use (e.g., 'cuda', 'cpu')
        ```
        """
        # Only BART checkpoints are supported; reject anything else up front.
        if model_name.find("bart") < 0:
            raise ValueError("TransformerSummarizer currently only accepts BART models")
        # TorchBase resolves `device` into self.torch_device.
        super().__init__(device=device)
        # Imported lazily so transformers is only required when this class is used.
        from transformers import BartForConditionalGeneration, BartTokenizer

        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        pretrained = BartForConditionalGeneration.from_pretrained(model_name)
        self.model = pretrained.to(self.torch_device)

    def summarize(self, doc):
        """
        ```
        Summarize document text.
        Args:
          doc(str): text of document
        Returns:
          str: summary text
        ```
        """
        import torch

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            # Tokenize the document; inputs beyond 1024 tokens are truncated.
            batch = self.tokenizer.batch_encode_plus(
                [doc], return_tensors="pt", truncation=True, max_length=1024
            )
            input_ids = batch["input_ids"].to(self.torch_device)
            # Beam search with length penalty and n-gram repetition blocking.
            generated = self.model.generate(
                input_ids,
                num_beams=4,
                length_penalty=2.0,
                max_length=142,
                min_length=56,
                no_repeat_ngram_size=3,
            )
            summary = self.tokenizer.decode(
                generated.squeeze(), skip_special_tokens=True
            )
        return summary

Classes

class TransformerSummarizer (model_name='facebook/bart-large-cnn', device=None)

interface to Transformer-based text summarization

Interface to BART-based text summarization using the Hugging Face transformers library.

Args:
  model_name(str): name of BART model for summarization
  device(str): device to use (e.g., 'cuda', 'cpu')
Expand source code
class TransformerSummarizer(TorchBase):
    """
    Interface to Transformer-based text summarization.
    """

    def __init__(self, model_name="facebook/bart-large-cnn", device=None):
        """
        ```
        Interface to BART-based text summarization using the transformers library.

        Args:
          model_name(str): name of BART model for summarization
          device(str): device to use (e.g., 'cuda', 'cpu')
        ```
        """
        # Guard clause: this wrapper only knows how to drive BART checkpoints.
        if "bart" not in model_name:
            raise ValueError("TransformerSummarizer currently only accepts BART models")
        # Base class turns `device` into the resolved self.torch_device.
        super().__init__(device=device)
        # Deferred import keeps transformers optional until actually needed.
        from transformers import BartForConditionalGeneration, BartTokenizer

        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(
            self.torch_device
        )

    def summarize(self, doc):
        """
        ```
        Summarize document text.
        Args:
          doc(str): text of document
        Returns:
          str: summary text
        ```
        """
        import torch

        # Decoding settings: beam search, length penalty, and 3-gram blocking.
        gen_kwargs = dict(
            num_beams=4,
            length_penalty=2.0,
            max_length=142,
            min_length=56,
            no_repeat_ngram_size=3,
        )
        # Pure inference — gradients disabled.
        with torch.no_grad():
            ids = self.tokenizer.batch_encode_plus(
                [doc], return_tensors="pt", truncation=True, max_length=1024
            )["input_ids"].to(self.torch_device)
            out = self.model.generate(ids, **gen_kwargs)
            text = self.tokenizer.decode(out.squeeze(), skip_special_tokens=True)
        return text

Ancestors

Methods

def summarize(self, doc)
summarize document text
Args:
  doc(str): text of document
Returns:
  str: summary text
Expand source code
def summarize(self, doc):
    """
    ```
    Summarize document text.
    Args:
      doc(str): text of document
    Returns:
      str: summary text
    ```
    """
    import torch

    # Inference only — disable gradient tracking.
    with torch.no_grad():
        # Tokenize the document; anything beyond 1024 tokens is truncated.
        encoded = self.tokenizer.batch_encode_plus(
            [doc], return_tensors="pt", truncation=True, max_length=1024
        )
        input_ids = encoded["input_ids"].to(self.torch_device)
        # Beam-search generation with length penalty and 3-gram repetition blocking.
        output_ids = self.model.generate(
            input_ids,
            num_beams=4,
            length_penalty=2.0,
            max_length=142,
            min_length=56,
            no_repeat_ngram_size=3,
        )
        summary = self.tokenizer.decode(
            output_ids.squeeze(), skip_special_tokens=True
        )
    return summary

Inherited members