# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import torch.nn as nn
import pytorch_lightning as pl
from typing import Dict, Optional, List, Tuple
from omegaconf import DictConfig
from torch import Tensor
from openspeech.criterion import CRITERION_REGISTRY
from openspeech.metrics import WordErrorRate, CharacterErrorRate
from openspeech.optim.scheduler import SCHEDULER_REGISTRY
from openspeech.utils import get_class_name
from openspeech.vocabs.vocab import Vocabulary
class OpenspeechModel(pl.LightningModule):
r"""
Super class of openspeech models.
Note:
Do not use this class directly, use one of the sub classes.
Args:
configs (DictConfig): configuration set.
vocab (Vocabulary): the class of vocabulary
Inputs:
inputs (torch.FloatTensor): A input sequence passed to encoders. Typically for inputs this will be a padded
`FloatTensor` of size ``(batch, seq_length, dimension)``.
input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
Returns:
* y_hats (torch.FloatTensor): Result of model predictions.
"""
def __init__(self, configs: DictConfig, vocab: Vocabulary) -> None:
super(OpenspeechModel, self).__init__()
self.configs = configs
self.num_classes = len(vocab)
self.gradient_clip_val = configs.trainer.gradient_clip_val
self.vocab = vocab
self.current_val_loss = 100.0
self.wer_metric = WordErrorRate(vocab)
self.cer_metric = CharacterErrorRate(vocab)
self.criterion = self.configure_criterion(configs.criterion.criterion_name)
def build_model(self):
raise NotImplementedError
def set_beam_decoder(self, beam_size: int = 3):
raise NotImplementedError
    def log_steps(
self,
stage: str,
wer: float,
cer: float,
loss: Optional[float] = None,
cross_entropy_loss: Optional[float] = None,
ctc_loss: Optional[float] = None,
) -> None:
r"""
        Logs metrics for the given stage.
Args:
stage (str): current stage (train, valid, test)
wer (float): word error rate
cer (float): character error rate
loss (float): loss of model's prediction
cross_entropy_loss (Optional, float): cross entropy loss of model's prediction
ctc_loss (Optional, float): ctc loss of model's prediction
"""
self.log(f"{stage}_wer", wer)
self.log(f"{stage}_cer", cer)
if loss is not None:
self.log(f"{stage}_loss", loss)
if cross_entropy_loss is not None:
self.log(f"{stage}_cross_entropy_loss", cross_entropy_loss)
if ctc_loss is not None:
self.log(f"{stage}_ctc_loss", ctc_loss)
if hasattr(self, "optimizer"):
self.log("current_lr", self.get_lr())
    def forward(self, inputs: torch.FloatTensor, input_lengths: torch.LongTensor) -> Dict[str, Tensor]:
        r"""
        Forward propagate the `inputs` for inference.
        Inputs:
            inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically, this will be a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
            input_lengths (torch.LongTensor): The length of the input tensor. ``(batch)``
Returns:
* outputs (dict): Result of model predictions.
"""
raise NotImplementedError
    def training_step(self, batch: tuple, batch_idx: int):
        r"""
        Forward propagate an `inputs` and `targets` pair for training.
        Inputs:
            batch (tuple): A training batch containing `inputs`, `targets`, `input_lengths`, `target_lengths`
batch_idx (int): The index of batch
Returns:
loss (torch.Tensor): loss for training
"""
raise NotImplementedError
    def validation_step(self, batch: tuple, batch_idx: int):
        r"""
        Forward propagate an `inputs` and `targets` pair for validation.
        Inputs:
            batch (tuple): A validation batch containing `inputs`, `targets`, `input_lengths`, `target_lengths`
            batch_idx (int): The index of batch
        Returns:
            loss (torch.Tensor): loss for validation
"""
raise NotImplementedError
    def validation_epoch_end(self, outputs: List[dict]) -> dict:
self.current_val_loss = torch.stack([output['val_loss'] for output in outputs]).mean()
        # ``validation_step`` in subclasses is expected to return a dict containing ``val_loss``;
        # plateau-based schedulers are stepped manually with its epoch mean.
        if get_class_name(self.scheduler) in ("WarmupReduceLROnPlateauScheduler", "ReduceLROnPlateauScheduler"):
            self.scheduler.step(self.current_val_loss)
return {
'loss': self.current_val_loss,
'log': {'val_loss': self.current_val_loss},
}
    def test_step(self, batch: tuple, batch_idx: int):
        r"""
        Forward propagate an `inputs` and `targets` pair for testing.
        Inputs:
            batch (tuple): A test batch containing `inputs`, `targets`, `input_lengths`, `target_lengths`
            batch_idx (int): The index of batch
        Returns:
            loss (torch.Tensor): loss for testing
"""
raise NotImplementedError
    def get_lr(self):
        r""" Returns the learning rate of the first parameter group of the optimizer. """
        for g in self.optimizer.param_groups:
            return g['lr']
    def set_lr(self, lr):
        r""" Sets the learning rate of every parameter group of the optimizer. """
        for g in self.optimizer.param_groups:
            g['lr'] = lr
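# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library): a minimal concrete
# subclass showing how the abstract hooks above are intended to be overridden.
# The encoder layout, the blank index, and the metric call order below are
# assumptions made for illustration; real openspeech models build their
# modules from ``configs`` and compute the loss with ``self.criterion``.
# ---------------------------------------------------------------------------
class _ExampleCTCModel(OpenspeechModel):
    def __init__(self, configs: DictConfig, vocab: Vocabulary) -> None:
        super(_ExampleCTCModel, self).__init__(configs, vocab)
        # Hypothetical acoustic model: a single LSTM encoder over 80-dim features.
        self.encoder = nn.LSTM(input_size=80, hidden_size=512, batch_first=True)
        self.fc = nn.Linear(512, self.num_classes)

    def forward(self, inputs: torch.FloatTensor, input_lengths: torch.LongTensor) -> Dict[str, Tensor]:
        encoder_outputs, _ = self.encoder(inputs)
        logits = self.fc(encoder_outputs).log_softmax(dim=-1)
        return {"logits": logits, "predictions": logits.max(-1)[1]}

    def training_step(self, batch: tuple, batch_idx: int):
        inputs, targets, input_lengths, target_lengths = batch
        outputs = self(inputs, input_lengths)
        # Real models compute the loss with ``self.criterion``; a plain CTC loss stands in here.
        loss = nn.functional.ctc_loss(
            outputs["logits"].transpose(0, 1),  # (seq_length, batch, num_classes)
            targets,
            input_lengths,
            target_lengths,
            blank=0,  # assumption: blank token at index 0
        )
        wer = self.wer_metric(targets, outputs["predictions"])  # assumed (targets, y_hats) ordering
        cer = self.cer_metric(targets, outputs["predictions"])
        self.log_steps(stage="train", wer=wer, cer=cer, loss=loss, ctc_loss=loss)
        return loss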