# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import random
import torch
import torch.nn as nn
from typing import Optional, Tuple, Any

from openspeech.decoders import OpenspeechDecoder
from openspeech.modules import (
    Linear,
    View,
    LocationAwareAttention,
    MultiHeadAttention,
    AdditiveAttention,
    DotProductAttention,
)


class LSTMDecoder(OpenspeechDecoder):
    r"""
    Converts higher-level features (from encoders) into output utterances
    by specifying a probability distribution over sequences of characters.

    Args:
        num_classes (int): number of classes to classify
        max_length (int, optional): maximum number of decoding steps (default: 150)
        hidden_state_dim (int): the number of features in the decoder's hidden state `h`
        num_layers (int, optional): number of recurrent layers (default: 2)
        rnn_type (str, optional): type of RNN cell (default: lstm)
        pad_id (int, optional): index of the pad symbol (default: 0)
        sos_id (int, optional): index of the start of sentence symbol (default: 1)
        eos_id (int, optional): index of the end of sentence symbol (default: 2)
        attn_mechanism (str, optional): type of attention mechanism (default: multi-head)
        num_heads (int, optional): number of attention heads (default: 4)
        dropout_p (float, optional): dropout probability of the decoder (default: 0.3)

    Inputs: encoder_outputs, targets, teacher_forcing_ratio
        - **encoder_outputs** (batch, seq_len, hidden_state_dim): tensor containing the outputs of the encoders.
          Used for the attention mechanism.
        - **targets** (batch, seq_len): tensor of target token IDs. Used for teacher forcing when provided
          (default: `None`).
        - **teacher_forcing_ratio** (float): the probability that teacher forcing will be used. A random number
          is drawn uniformly from 0-1 once per forward pass, and if the sample is smaller than the given value,
          teacher forcing is used (default: 1.0).

    Returns: logits
        * logits (torch.FloatTensor): log probabilities of the model's predictions
    """
    supported_rnns = {
        'lstm': nn.LSTM,
        'gru': nn.GRU,
        'rnn': nn.RNN,
    }

    def __init__(
        self,
        num_classes: int,
        max_length: int = 150,
        hidden_state_dim: int = 1024,
        pad_id: int = 0,
        sos_id: int = 1,
        eos_id: int = 2,
        attn_mechanism: str = 'multi-head',
        num_heads: int = 4,
        num_layers: int = 2,
        rnn_type: str = 'lstm',
        dropout_p: float = 0.3,
    ) -> None:
        super(LSTMDecoder, self).__init__()
        self.hidden_state_dim = hidden_state_dim
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_length = max_length
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.pad_id = pad_id
        self.attn_mechanism = attn_mechanism.lower()
        self.embedding = nn.Embedding(num_classes, hidden_state_dim)
        self.input_dropout = nn.Dropout(dropout_p)

        rnn_cell = self.supported_rnns[rnn_type.lower()]
        self.rnn = rnn_cell(
            input_size=hidden_state_dim,
            hidden_size=hidden_state_dim,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout_p,
            bidirectional=False,
        )

        if self.attn_mechanism == 'loc':
            self.attention = LocationAwareAttention(hidden_state_dim, attn_dim=hidden_state_dim, smoothing=False)
        elif self.attn_mechanism == 'multi-head':
            self.attention = MultiHeadAttention(hidden_state_dim, num_heads=num_heads)
        elif self.attn_mechanism == 'additive':
            self.attention = AdditiveAttention(hidden_state_dim)
        elif self.attn_mechanism == 'dot':
            self.attention = DotProductAttention(dim=hidden_state_dim)
        elif self.attn_mechanism == 'scaled-dot':
            self.attention = DotProductAttention(dim=hidden_state_dim, scale=True)
        else:
            raise ValueError("Unsupported attention: {0}".format(attn_mechanism))

        self.fc = nn.Sequential(
            Linear(hidden_state_dim << 1, hidden_state_dim),
            nn.Tanh(),
            View(shape=(-1, self.hidden_state_dim), contiguous=True),
            Linear(hidden_state_dim, num_classes),
        )

    def forward_step(
        self,
        input_var: torch.Tensor,
        hidden_states: Optional[torch.Tensor],
        encoder_outputs: torch.Tensor,
        attn: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size, output_lengths = input_var.size(0), input_var.size(1)

        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        if self.training:
            self.rnn.flatten_parameters()

        outputs, hidden_states = self.rnn(embedded, hidden_states)

        if self.attn_mechanism == 'loc':
            context, attn = self.attention(outputs, encoder_outputs, attn)
        else:
            context, attn = self.attention(outputs, encoder_outputs, encoder_outputs)

        outputs = torch.cat((outputs, context), dim=2)

        step_outputs = self.fc(outputs.view(-1, self.hidden_state_dim << 1)).log_softmax(dim=-1)
        step_outputs = step_outputs.view(batch_size, output_lengths, -1).squeeze(1)

        return step_outputs, hidden_states, attn
    def forward(
        self,
        encoder_outputs: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        encoder_output_lengths: Optional[torch.Tensor] = None,
        teacher_forcing_ratio: float = 1.0,
    ) -> torch.Tensor:
        """
        Forward propagate `encoder_outputs` for training.

        Args:
            targets (torch.LongTensor): A target sequence passed to the decoder. `IntTensor` of size
                ``(batch, seq_length)``
            encoder_outputs (torch.FloatTensor): An output sequence of the encoders. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            encoder_output_lengths: The lengths of the encoder outputs. ``(batch)``
            teacher_forcing_ratio (float): ratio of teacher forcing

        Returns:
            * logits (torch.FloatTensor): Log probabilities of the model's predictions.
        """
        logits = list()
        hidden_states, attn = None, None

        targets, batch_size, max_length = self.validate_args(targets, encoder_outputs, teacher_forcing_ratio)
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        if use_teacher_forcing:
            targets = targets[targets != self.eos_id].view(batch_size, -1)

            if self.attn_mechanism == 'loc' or self.attn_mechanism == 'additive':
                # Location-aware and additive attention are applied step by step,
                # feeding the previous attention distribution back into the next step.
                for di in range(targets.size(1)):
                    input_var = targets[:, di].unsqueeze(1)
                    step_outputs, hidden_states, attn = self.forward_step(
                        input_var=input_var,
                        hidden_states=hidden_states,
                        encoder_outputs=encoder_outputs,
                        attn=attn,
                    )
                    logits.append(step_outputs)

            else:
                # Other attention mechanisms can attend over the whole target sequence at once.
                step_outputs, hidden_states, attn = self.forward_step(
                    input_var=targets,
                    hidden_states=hidden_states,
                    encoder_outputs=encoder_outputs,
                    attn=attn,
                )

                for di in range(step_outputs.size(1)):
                    step_output = step_outputs[:, di, :]
                    logits.append(step_output)

        else:
            # Greedy decoding: feed the most probable token back in at every step.
            input_var = targets[:, 0].unsqueeze(1)

            for di in range(max_length):
                step_outputs, hidden_states, attn = self.forward_step(
                    input_var=input_var,
                    hidden_states=hidden_states,
                    encoder_outputs=encoder_outputs,
                    attn=attn,
                )
                logits.append(step_outputs)
                input_var = logits[-1].topk(1)[1]

        logits = torch.stack(logits, dim=1)

        return logits
    def validate_args(
        self,
        targets: Optional[Any] = None,
        encoder_outputs: torch.Tensor = None,
        teacher_forcing_ratio: float = 1.0,
    ) -> Tuple[torch.Tensor, int, int]:
        assert encoder_outputs is not None
        batch_size = encoder_outputs.size(0)

        if targets is None:  # inference
            targets = torch.LongTensor([self.sos_id] * batch_size).view(batch_size, 1)
            max_length = self.max_length

            if torch.cuda.is_available():
                targets = targets.cuda()

            if teacher_forcing_ratio > 0:
                raise ValueError("Teacher forcing has to be disabled (set to 0) when no targets are provided.")

        else:
            max_length = targets.size(1) - 1  # minus the start of sequence symbol

        return targets, batch_size, max_length
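

# ---------------------------------------------------------------------------
# Usage sketch (not part of the library): a minimal, hypothetical example of
# driving the decoder with random tensors, assuming a small vocabulary and an
# encoder feature dimension equal to ``hidden_state_dim``. All names, shapes,
# and values below are illustrative assumptions, not fixtures from openspeech.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, enc_seq_len, tgt_seq_len = 3, 50, 12
    num_classes, hidden_state_dim = 10, 64  # hidden_state_dim must be divisible by num_heads

    decoder = LSTMDecoder(
        num_classes=num_classes,
        hidden_state_dim=hidden_state_dim,
        attn_mechanism='multi-head',
        num_heads=4,
        num_layers=2,
    )

    # Encoder outputs: (batch, seq_len, hidden_state_dim).
    encoder_outputs = torch.randn(batch_size, enc_seq_len, hidden_state_dim)

    # Targets: token IDs drawn from [3, num_classes) so that pad/sos/eos are avoided,
    # with sos_id in the first position (eos tokens would be filtered out in forward()).
    targets = torch.randint(low=3, high=num_classes, size=(batch_size, tgt_seq_len))
    targets[:, 0] = decoder.sos_id

    # Teacher-forced pass: returns log probabilities of shape (batch, tgt_seq_len, num_classes).
    logits = decoder(encoder_outputs=encoder_outputs, targets=targets, teacher_forcing_ratio=1.0)
    print(logits.shape)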