add files

烨玮
2025-02-20 12:17:03 +08:00
parent a21dd4555c
commit edd008441b
667 changed files with 473123 additions and 0 deletions

View File

@@ -0,0 +1,334 @@
import random
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs
def build_attention_list(
eprojs: int,
dunits: int,
atype: str = "location",
num_att: int = 1,
num_encs: int = 1,
aheads: int = 4,
adim: int = 320,
awin: int = 5,
aconv_chans: int = 10,
aconv_filts: int = 100,
han_mode: bool = False,
han_type=None,
han_heads: int = 4,
han_dim: int = 320,
han_conv_chans: int = -1,
han_conv_filts: int = 100,
han_win: int = 5,
):
att_list = torch.nn.ModuleList()
if num_encs == 1:
for i in range(num_att):
att = initial_att(
atype,
eprojs,
dunits,
aheads,
adim,
awin,
aconv_chans,
aconv_filts,
)
att_list.append(att)
elif num_encs > 1: # no multi-speaker mode
if han_mode:
att = initial_att(
han_type,
eprojs,
dunits,
han_heads,
han_dim,
han_win,
han_conv_chans,
han_conv_filts,
han_mode=True,
)
return att
else:
att_list = torch.nn.ModuleList()
for idx in range(num_encs):
att = initial_att(
atype[idx],
eprojs,
dunits,
aheads[idx],
adim[idx],
awin[idx],
aconv_chans[idx],
aconv_filts[idx],
)
att_list.append(att)
else:
raise ValueError(
"Number of encoders needs to be more than one. {}".format(num_encs)
)
return att_list
class RNNDecoder(AbsDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
rnn_type: str = "lstm",
num_layers: int = 1,
hidden_size: int = 320,
sampling_probability: float = 0.0,
dropout: float = 0.0,
context_residual: bool = False,
replace_sos: bool = False,
num_encs: int = 1,
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored more
assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")
super().__init__()
eprojs = encoder_output_size
self.dtype = rnn_type
self.dunits = hidden_size
self.dlayers = num_layers
self.context_residual = context_residual
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.odim = vocab_size
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual translation
self.replace_sos = replace_sos
self.embed = torch.nn.Embedding(vocab_size, hidden_size)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(hidden_size, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
if context_residual:
self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
else:
self.output = torch.nn.Linear(hidden_size, vocab_size)
self.att_list = build_attention_list(
eprojs=eprojs, dunits=hidden_size, **att_conf
)
def zero_state(self, hs_pad):
return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for i in range(1, self.dlayers):
z_list[i], c_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]),
(z_prev[i], c_prev[i]),
)
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for i in range(1, self.dlayers):
z_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
)
return z_list, c_list
def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
# to support multiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlens = [hlens]
# attention index for the attention module
# in SPA (speaker parallel attention),
# att_idx is used to select attention module. In other cases, it is 0.
att_idx = min(strm_idx, len(self.att_list) - 1)
# hlens should be a list of lists of integers
hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
# get dim, length info
olength = ys_in_pad.size(1)
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
z_all = []
if self.num_encs == 1:
att_w = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in range(olength):
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att_list[idx](
hs_pad[idx],
hlens[idx],
self.dropout_dec[0](z_list[0]),
att_w_list[idx],
)
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in_pad)
att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
hs_pad_han,
hlens_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs],
)
if i > 0 and random.random() < self.sampling_probability:
z_out = self.output(z_all[-1])
z_out = np.argmax(z_out.detach().cpu(), axis=1)
z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
else:
# utt x (zdim + hdim)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
z_all = torch.stack(z_all, dim=1)
z_all = self.output(z_all)
z_all.masked_fill_(
make_pad_mask(ys_in_lens, z_all, 1),
0,
)
return z_all, ys_in_lens
def init_state(self, x):
# to support multiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
c_list = [self.zero_state(x[0].unsqueeze(0))]
z_list = [self.zero_state(x[0].unsqueeze(0))]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(x[0].unsqueeze(0)))
z_list.append(self.zero_state(x[0].unsqueeze(0)))
# TODO(karita): support strm_index for `asr_mix`
strm_index = 0
att_idx = min(strm_index, len(self.att_list) - 1)
if self.num_encs == 1:
a = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
return dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=a,
workspace=(att_idx, z_list, c_list),
)
def score(self, yseq, state, x):
# to support multiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
att_idx, z_list, c_list = state["workspace"]
vy = yseq[-1].unsqueeze(0)
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
x[0].unsqueeze(0),
[x[0].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"],
)
else:
att_w = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs):
att_c_list[idx], att_w[idx] = self.att_list[idx](
x[idx].unsqueeze(0),
[x[idx].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][idx],
)
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
h_han,
[self.num_encs],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][self.num_encs],
)
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(
ey, z_list, c_list, state["z_prev"], state["c_prev"]
)
if self.context_residual:
logits = self.output(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
)
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
return (
logp,
dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=att_w,
workspace=(att_idx, z_list, c_list),
),
)
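
A minimal usage sketch for the RNNDecoder defined above, not part of the commit: the import path is a guess, the tensors are random stand-ins for real encoder outputs, and all sizes are illustrative.

import torch
from funasr_local.models.decoder.rnn_decoder import RNNDecoder  # hypothetical path

vocab_size, eprojs, batch, t_enc, t_out = 50, 256, 2, 30, 12
decoder = RNNDecoder(vocab_size=vocab_size, encoder_output_size=eprojs, hidden_size=320)

hs_pad = torch.randn(batch, t_enc, eprojs)                 # padded encoder outputs
hlens = torch.tensor([t_enc, t_enc - 5])                   # valid frames per utterance
ys_in_pad = torch.randint(0, vocab_size, (batch, t_out))   # teacher-forced token ids
ys_in_lens = torch.tensor([t_out, t_out - 3])

logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
print(logits.shape)  # torch.Size([2, 12, 50]): (batch, olength, vocab_size)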

View File

@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
Args:
vocab_size: Vocabulary size.
embed_size: Embedding size.
hidden_size: Hidden size.
rnn_type: Decoder layers type.
num_layers: Number of decoder layers.
dropout_rate: Dropout rate for decoder layers.
embed_dropout_rate: Dropout rate for embedding layer.
embed_pad: Embedding padding symbol ID.
"""
def __init__(
self,
vocab_size: int,
embed_size: int = 256,
hidden_size: int = 256,
rnn_type: str = "lstm",
num_layers: int = 1,
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
self.rnn = torch.nn.ModuleList(
[rnn_class(embed_size, hidden_size, 1, batch_first=True)]
)
for _ in range(1, num_layers):
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
self.dropout_rnn = torch.nn.ModuleList(
[torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
)
self.dlayers = num_layers
self.dtype = rnn_type
self.output_size = hidden_size
self.vocab_size = vocab_size
self.device = next(self.parameters()).device
self.score_cache = {}
def forward(
self,
labels: torch.Tensor,
label_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""Encode source label sequences.
Args:
labels: Label ID sequences. (B, L)
label_lens: Label ID sequence lengths. (B,)
states: Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None) or None
Returns:
dec_out: Decoder output sequences. (B, U, D_dec)
"""
if states is None:
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
def rnn_forward(
self,
x: torch.Tensor,
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Encode source label sequences.
Args:
x: RNN input sequences. (B, D_emb)
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
Returns:
x: RNN output sequences. (B, D_dec)
(h_next, c_next): Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None)
"""
h_prev, c_prev = state
h_next, c_next = self.init_state(x.size(0))
for layer in range(self.dlayers):
if self.dtype == "lstm":
x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
layer
](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
else:
x, h_next[layer : layer + 1] = self.rnn[layer](
x, hx=h_prev[layer : layer + 1]
)
x = self.dropout_rnn[layer](x)
return x, (h_next, c_next)
def score(
self,
label: torch.Tensor,
label_sequence: List[int],
dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypothesis.
Args:
label: Previous label. (1, 1)
label_sequence: Current label sequence.
dec_state: Previous decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
"""
str_labels = "_".join(map(str, label_sequence))
if str_labels in self.score_cache:
dec_out, dec_state = self.score_cache[str_labels]
else:
dec_embed = self.embed(label)
dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
self.score_cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state
def batch_score(
self,
hyps: List[Hypothesis],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
# build a (B, 1) tensor of the last emitted labels on the decoder's device
labels = torch.tensor([[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device)
dec_embed = self.embed(labels)
states = self.create_batch_states([h.dec_state for h in hyps])
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out.squeeze(1), states
def set_device(self, device: torch.device) -> None:
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self, batch_size: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
h_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
if self.dtype == "lstm":
c_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
return (h_n, c_n)
return (h_n, None)
def select_state(
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
idx: State ID to extract.
Returns:
: Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
"""
return (
states[0][:, idx : idx + 1, :],
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
)
def create_batch_states(
self,
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Create decoder hidden states.
Args:
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
Returns:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
torch.cat([s[1] for s in new_states], dim=1)
if self.dtype == "lstm"
else None,
)
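
A short sketch of driving the RNNTDecoder above, again not part of the commit; the module path is a guess and the sizes are arbitrary. It runs one batched prediction-network pass and one cached single-step score.

import torch
from funasr_local.models.decoder.rnnt_decoder import RNNTDecoder  # hypothetical path

dec = RNNTDecoder(vocab_size=50, embed_size=64, hidden_size=64, num_layers=2)

# Batched pass over padded label sequences (B, U); label_lens is accepted for API
# compatibility but not used inside forward().
labels = torch.randint(1, 50, (3, 7))
label_lens = torch.tensor([7, 5, 4])
dec_out = dec(labels, label_lens)

# Single-step scoring of one hypothesis, reusing the internal score cache.
state = dec.init_state(batch_size=1)
step_out, state = dec.score(labels[:1, :1], [int(labels[0, 0])], state)
print(dec_out.shape, step_out.shape)  # torch.Size([3, 7, 64]) torch.Size([1, 64])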

View File

@@ -0,0 +1,19 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
from funasr_local.modules.scorers.scorer_interface import ScorerInterface
class AbsDecoder(torch.nn.Module, ScorerInterface, ABC):
@abstractmethod
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
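
AbsDecoder only fixes the forward signature and inherits the scorer interface, so a concrete decoder just has to implement forward (and may override score). A toy subclass as a sketch, not part of FunASR, assuming the AbsDecoder class above is in scope:

import torch

class LinearDecoder(AbsDecoder):
    """Projects encoder states straight to vocabulary logits (no autoregression)."""

    def __init__(self, vocab_size: int, encoder_output_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(encoder_output_size, vocab_size)

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens):
        # Ignores the target-side inputs and returns per-frame logits plus lengths.
        return self.proj(hs_pad), hlens

dec = LinearDecoder(vocab_size=10, encoder_output_size=8)
logits, lens = dec(torch.randn(2, 5, 8), torch.tensor([5, 3]), None, None)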

View File

@@ -0,0 +1,776 @@
from typing import List
from typing import Tuple
import logging
import torch
import torch.nn as nn
import numpy as np
from funasr_local.modules.streaming_utils import utils as myutils
from funasr_local.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types
from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.modules.repeat import repeat
from funasr_local.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
class ContextualDecoderLayer(nn.Module):
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(ContextualDecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
if self_attn is not None:
self.norm2 = LayerNorm(size)
if src_attn is not None:
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
# tgt = self.dropout(tgt)
if isinstance(tgt, tuple):
tgt, _ = tgt
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
if self.training:
cache = None
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x = residual + self.dropout(x)
x_self_attn = x
residual = x
if self.normalize_before:
x = self.norm3(x)
x = self.src_attn(x, memory, memory_mask)
x_src_attn = x
x = residual + self.dropout(x)
return x, tgt_mask, x_self_attn, x_src_attn
class ContextualBiasDecoder(nn.Module):
def __init__(
self,
size,
src_attn,
dropout_rate,
normalize_before=True,
):
"""Construct an DecoderLayer object."""
super(ContextualBiasDecoder, self).__init__()
self.size = size
self.src_attn = src_attn
if src_attn is not None:
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
x = tgt
if self.src_attn is not None:
if self.normalize_before:
x = self.norm3(x)
x = self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
class ContextualParaformerDecoder(ParaformerSANMDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
if input_layer == 'none':
self.embed = None
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
# pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
self.att_layer_num = att_layer_num
self.num_blocks = num_blocks
if sanm_shfit is None:
sanm_shfit = (kernel_size - 1) // 2
self.decoders = repeat(
att_layer_num - 1,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.dropout = nn.Dropout(dropout_rate)
self.bias_decoder = ContextualBiasDecoder(
size=attention_dim,
src_attn=MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
),
dropout_rate=dropout_rate,
normalize_before=True,
)
self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
self.last_decoder = ContextualDecoderLayer(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
)
if num_blocks - att_layer_num <= 0:
self.decoders2 = None
else:
self.decoders2 = repeat(
num_blocks - att_layer_num,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.decoders3 = repeat(
1,
lambda lnum: DecoderLayerSANM(
attention_dim,
None,
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
contextual_info: torch.Tensor,
return_hidden: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.decoders(
x, tgt_mask, memory, memory_mask
)
_, _, x_self_attn, x_src_attn = self.last_decoder(
x, tgt_mask, memory, memory_mask
)
# contextual paraformer related
contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
if self.bias_output is not None:
x = torch.cat([x_src_attn, cx], dim=2)
x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2*D -> D
x = x_self_attn + self.dropout(x)
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(
x, tgt_mask, memory, memory_mask
)
x, tgt_mask, memory, memory_mask, _ = self.decoders3(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
olens = tgt_mask.sum(1)
if self.output_layer is not None and return_hidden is False:
x = self.output_layer(x)
return x, olens
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
map_dict_local = {
## decoder
# ffn
"{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
# fsmn
"{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 2, 0),
}, # (256,1,31),(1,31,256,1)
# src att
"{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# dnn
"{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
# embed_concat_ffn
"{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.after_norm.bias".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# in embed
"{}.embed.0.weight".format(tensor_name_prefix_torch):
{"name": "{}/w_embs".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (4235,256),(4235,256)
# out layer
"{}.output_layer.weight".format(tensor_name_prefix_torch):
{"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
"squeeze": [None, None],
"transpose": [(1, 0), None],
}, # (4235,256),(256,4235)
"{}.output_layer.bias".format(tensor_name_prefix_torch):
{"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
"seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
"squeeze": [None, None],
"transpose": [None, None],
}, # (4235,),(4235,)
## clas decoder
# src att
"{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
"{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (1024,256),(1,256,1024)
"{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (1024,),(1024,)
"{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
"squeeze": 0,
"transpose": (1, 0),
}, # (256,256),(1,256,256)
"{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
# dnn
"{}.bias_output.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
"squeeze": None,
"transpose": (2, 1, 0),
}, # (1024,256),(1,256,1024)
}
return map_dict_local
def convert_tf2torch(self,
var_dict_tf,
var_dict_torch,
):
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
decoder_layeridx_sets = set()
for name in sorted(var_dict_torch.keys(), reverse=False):
names = name.split('.')
if names[0] == self.tf2torch_tensor_name_prefix_torch:
if names[1] == "decoders":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 0
layeridx += layeridx_bias
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "last_decoder":
layeridx = 15
name_q = name.replace("last_decoder", "decoders.layeridx")
layeridx_bias = 0
layeridx += layeridx_bias
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "decoders2":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
name_q = name_q.replace("decoders2", "decoders")
layeridx_bias = len(decoder_layeridx_sets)
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "decoders3":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "bias_decoder":
name_q = name
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
name_tf = map_dict[name]["name"]
if isinstance(name_tf, list):
idx_list = 0
if name_tf[idx_list] in var_dict_tf.keys():
pass
else:
idx_list = 1
data_tf = var_dict_tf[name_tf[idx_list]]
if map_dict[name]["squeeze"][idx_list] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
if map_dict[name]["transpose"][idx_list] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
name_tf[idx_list],
var_dict_tf[name_tf[
idx_list]].shape))
else:
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
if map_dict[name]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
elif names[1] == "after_norm":
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
elif names[1] == "embed_concat_ffn":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
if name_q in map_dict.keys():
name_v = map_dict[name_q]["name"]
name_tf = name_v.replace("layeridx", "{}".format(layeridx))
data_tf = var_dict_tf[name_tf]
if map_dict[name_q]["squeeze"] is not None:
data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
if map_dict[name_q]["transpose"] is not None:
data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
var_dict_torch[
name].size(),
data_tf.size())
var_dict_torch_update[name] = data_tf
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
return var_dict_torch_update
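
gen_tf2torch_map_dict and convert_tf2torch above map each torch parameter name (with a layeridx placeholder) to a TF variable name plus an optional squeeze axis and transpose permutation. The standalone sketch below applies one such entry to a dummy array to show the shape handling; the entry mirrors the feed_forward.w_1.weight mapping, but the data is made up.

import numpy as np
import torch

# TF stores the FFN kernel as a conv1d weight of shape (1, 256, 1024);
# torch's Linear expects (out_features, in_features) = (1024, 256).
entry = {"squeeze": 0, "transpose": (1, 0)}
data_tf = np.random.randn(1, 256, 1024).astype(np.float32)

if entry["squeeze"] is not None:
    data_tf = np.squeeze(data_tf, axis=entry["squeeze"])   # (256, 1024)
if entry["transpose"] is not None:
    data_tf = np.transpose(data_tf, entry["transpose"])    # (1024, 256)
weight = torch.from_numpy(data_tf).float()
print(weight.shape)  # torch.Size([1024, 256])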

View File

@@ -0,0 +1,334 @@
import random
import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs
def build_attention_list(
eprojs: int,
dunits: int,
atype: str = "location",
num_att: int = 1,
num_encs: int = 1,
aheads: int = 4,
adim: int = 320,
awin: int = 5,
aconv_chans: int = 10,
aconv_filts: int = 100,
han_mode: bool = False,
han_type=None,
han_heads: int = 4,
han_dim: int = 320,
han_conv_chans: int = -1,
han_conv_filts: int = 100,
han_win: int = 5,
):
att_list = torch.nn.ModuleList()
if num_encs == 1:
for i in range(num_att):
att = initial_att(
atype,
eprojs,
dunits,
aheads,
adim,
awin,
aconv_chans,
aconv_filts,
)
att_list.append(att)
elif num_encs > 1: # no multi-speaker mode
if han_mode:
att = initial_att(
han_type,
eprojs,
dunits,
han_heads,
han_dim,
han_win,
han_conv_chans,
han_conv_filts,
han_mode=True,
)
return att
else:
att_list = torch.nn.ModuleList()
for idx in range(num_encs):
att = initial_att(
atype[idx],
eprojs,
dunits,
aheads[idx],
adim[idx],
awin[idx],
aconv_chans[idx],
aconv_filts[idx],
)
att_list.append(att)
else:
raise ValueError(
"Number of encoders needs to be more than one. {}".format(num_encs)
)
return att_list
class RNNDecoder(AbsDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
rnn_type: str = "lstm",
num_layers: int = 1,
hidden_size: int = 320,
sampling_probability: float = 0.0,
dropout: float = 0.0,
context_residual: bool = False,
replace_sos: bool = False,
num_encs: int = 1,
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored more more more
assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")
super().__init__()
eprojs = encoder_output_size
self.dtype = rnn_type
self.dunits = hidden_size
self.dlayers = num_layers
self.context_residual = context_residual
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.odim = vocab_size
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual translation
self.replace_sos = replace_sos
self.embed = torch.nn.Embedding(vocab_size, hidden_size)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(hidden_size, hidden_size)
if self.dtype == "lstm"
else torch.nn.GRUCell(hidden_size, hidden_size)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
if context_residual:
self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
else:
self.output = torch.nn.Linear(hidden_size, vocab_size)
self.att_list = build_attention_list(
eprojs=eprojs, dunits=hidden_size, **att_conf
)
def zero_state(self, hs_pad):
return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for i in range(1, self.dlayers):
z_list[i], c_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]),
(z_prev[i], c_prev[i]),
)
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for i in range(1, self.dlayers):
z_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
)
return z_list, c_list
def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
# to support mutiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlens = [hlens]
# attention index for the attention module
# in SPA (speaker parallel attention),
# att_idx is used to select attention module. In other cases, it is 0.
att_idx = min(strm_idx, len(self.att_list) - 1)
# hlens should be list of list of integer
hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
# get dim, length info
olength = ys_in_pad.size(1)
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
z_all = []
if self.num_encs == 1:
att_w = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in range(olength):
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att_list[idx](
hs_pad[idx],
hlens[idx],
self.dropout_dec[0](z_list[0]),
att_w_list[idx],
)
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in_pad)
att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
hs_pad_han,
hlens_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs],
)
if i > 0 and random.random() < self.sampling_probability:
z_out = self.output(z_all[-1])
z_out = np.argmax(z_out.detach().cpu(), axis=1)
z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
else:
# utt x (zdim + hdim)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
z_all = torch.stack(z_all, dim=1)
z_all = self.output(z_all)
z_all.masked_fill_(
make_pad_mask(ys_in_lens, z_all, 1),
0,
)
return z_all, ys_in_lens
def init_state(self, x):
# to support mutiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
c_list = [self.zero_state(x[0].unsqueeze(0))]
z_list = [self.zero_state(x[0].unsqueeze(0))]
for _ in range(1, self.dlayers):
c_list.append(self.zero_state(x[0].unsqueeze(0)))
z_list.append(self.zero_state(x[0].unsqueeze(0)))
# TODO(karita): support strm_index for `asr_mix`
strm_index = 0
att_idx = min(strm_index, len(self.att_list) - 1)
if self.num_encs == 1:
a = None
self.att_list[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
for idx in range(self.num_encs + 1):
# reset pre-computation of h in atts and han
self.att_list[idx].reset()
return dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=a,
workspace=(att_idx, z_list, c_list),
)
def score(self, yseq, state, x):
# to support mutiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
att_idx, z_list, c_list = state["workspace"]
vy = yseq[-1].unsqueeze(0)
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
if self.num_encs == 1:
att_c, att_w = self.att_list[att_idx](
x[0].unsqueeze(0),
[x[0].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"],
)
else:
att_w = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * self.num_encs # atts
for idx in range(self.num_encs):
att_c_list[idx], att_w[idx] = self.att_list[idx](
x[idx].unsqueeze(0),
[x[idx].size(0)],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][idx],
)
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
h_han,
[self.num_encs],
self.dropout_dec[0](state["z_prev"][0]),
state["a_prev"][self.num_encs],
)
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(
ey, z_list, c_list, state["z_prev"], state["c_prev"]
)
if self.context_residual:
logits = self.output(
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
)
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
return (
logp,
dict(
c_prev=c_list[:],
z_prev=z_list[:],
a_prev=att_w,
workspace=(att_idx, z_list, c_list),
),
)

View File

@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug
class RNNTDecoder(torch.nn.Module):
"""RNN decoder module.
Args:
vocab_size: Vocabulary size.
embed_size: Embedding size.
hidden_size: Hidden size..
rnn_type: Decoder layers type.
num_layers: Number of decoder layers.
dropout_rate: Dropout rate for decoder layers.
embed_dropout_rate: Dropout rate for embedding layer.
embed_pad: Embedding padding symbol ID.
"""
def __init__(
self,
vocab_size: int,
embed_size: int = 256,
hidden_size: int = 256,
rnn_type: str = "lstm",
num_layers: int = 1,
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
self.rnn = torch.nn.ModuleList(
[rnn_class(embed_size, hidden_size, 1, batch_first=True)]
)
for _ in range(1, num_layers):
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
self.dropout_rnn = torch.nn.ModuleList(
[torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
)
self.dlayers = num_layers
self.dtype = rnn_type
self.output_size = hidden_size
self.vocab_size = vocab_size
self.device = next(self.parameters()).device
self.score_cache = {}
def forward(
self,
labels: torch.Tensor,
label_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
"""Encode source label sequences.
Args:
labels: Label ID sequences. (B, L)
states: Decoder hidden states.
((N, B, D_dec), (N, B, D_dec) or None) or None
Returns:
dec_out: Decoder output sequences. (B, U, D_dec)
"""
if states is None:
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
def rnn_forward(
self,
x: torch.Tensor,
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Encode source label sequences.
Args:
x: RNN input sequences. (B, D_emb)
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
Returns:
x: RNN output sequences. (B, D_dec)
(h_next, c_next): Decoder hidden states.
(N, B, D_dec), (N, B, D_dec) or None)
"""
h_prev, c_prev = state
h_next, c_next = self.init_state(x.size(0))
for layer in range(self.dlayers):
if self.dtype == "lstm":
x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
layer
](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
else:
x, h_next[layer : layer + 1] = self.rnn[layer](
x, hx=h_prev[layer : layer + 1]
)
x = self.dropout_rnn[layer](x)
return x, (h_next, c_next)
def score(
self,
label: torch.Tensor,
label_sequence: List[int],
dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypothesis.
Args:
label: Previous label. (1, 1)
label_sequence: Current label sequence.
dec_state: Previous decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states.
((N, 1, D_dec), (N, 1, D_dec) or None)
"""
str_labels = "_".join(map(str, label_sequence))
if str_labels in self.score_cache:
dec_out, dec_state = self.score_cache[str_labels]
else:
dec_embed = self.embed(label)
dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
self.score_cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state
def batch_score(
self,
hyps: List[Hypothesis],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
dec_embed = self.embed(labels)
states = self.create_batch_states([h.dec_state for h in hyps])
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out.squeeze(1), states
def set_device(self, device: torch.device) -> None:
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self, batch_size: int
) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
h_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
if self.dtype == "lstm":
c_n = torch.zeros(
self.dlayers,
batch_size,
self.output_size,
device=self.device,
)
return (h_n, c_n)
return (h_n, None)
def select_state(
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
idx: State ID to extract.
Returns:
: Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
"""
return (
states[0][:, idx : idx + 1, :],
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
)
def create_batch_states(
self,
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Create decoder hidden states.
Args:
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
Returns:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
torch.cat([s[1] for s in new_states], dim=1)
if self.dtype == "lstm"
else None,
)
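# Illustrative sketch (hypothetical helper, not part of the original API): select_state()
# and create_batch_states() are inverses over the batch dimension -- the first slices
# hypothesis `idx` out of a batched (h, c) state pair, the second re-concatenates
# per-hypothesis states along dim=1. A minimal round trip, assuming `decoder` is an
# instance of the class above and set_device() has already been called:
def _state_round_trip_sketch(decoder, batch_size: int = 3):
    states = decoder.init_state(batch_size)  # ((N, B, D_dec), (N, B, D_dec) or None)
    per_hyp = [decoder.select_state(states, b) for b in range(batch_size)]
    rebatched = decoder.create_batch_states(per_hyp)
    assert torch.equal(states[0], rebatched[0])  # hidden states survive the round trip
    return rebatched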

File diff suppressed because it is too large

View File

@@ -0,0 +1,37 @@
import torch
from torch.nn import functional as F
from funasr_local.models.decoder.abs_decoder import AbsDecoder
class DenseDecoder(AbsDecoder):
def __init__(
self,
vocab_size,
encoder_output_size,
num_nodes_resnet1: int = 256,
num_nodes_last_layer: int = 256,
batchnorm_momentum: float = 0.5,
):
super(DenseDecoder, self).__init__()
self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
self.output_dense = torch.nn.Linear(num_nodes_last_layer, vocab_size, bias=False)
def forward(self, features):
embeddings = {}
features = self.resnet1_dense(features)
embeddings["resnet1_dense"] = features
features = F.relu(features)
features = self.resnet1_bn(features)
features = self.resnet2_dense(features)
embeddings["resnet2_dense"] = features
features = F.relu(features)
features = self.resnet2_bn(features)
features = self.output_dense(features)
return features, embeddings
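# Illustrative usage sketch (hypothetical sizes, not a definitive recipe): DenseDecoder
# maps fixed-size encoder embeddings to class logits through two Linear -> ReLU ->
# BatchNorm1d blocks and also returns the pre-activation outputs of both dense layers.
def _dense_decoder_usage_sketch():
    decoder = DenseDecoder(vocab_size=10, encoder_output_size=512)
    decoder.eval()  # use running statistics so BatchNorm1d accepts any batch size
    features = torch.randn(4, 512)  # (batch, encoder_output_size)
    logits, embeddings = decoder(features)
    return logits.shape, sorted(embeddings)  # torch.Size([4, 10]), ['resnet1_dense', 'resnet2_dense']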

View File

@@ -0,0 +1,766 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.modules.attention import MultiHeadedAttention
from funasr_local.modules.dynamic_conv import DynamicConvolution
from funasr_local.modules.dynamic_conv2d import DynamicConvolution2D
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.lightconv import LightweightConvolution
from funasr_local.modules.lightconv2d import LightweightConvolution2D
from funasr_local.modules.mask import subsequent_mask
from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr_local.modules.repeat import repeat
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Source-attention module instance
            (attention over the encoder output). `MultiHeadedAttention` instance
            can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
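# Illustrative usage sketch (hypothetical shapes, not part of the original module): a single
# DecoderLayer wired with multi-head self-attention, source attention over the encoder
# memory, and a position-wise feed-forward block, run on random tensors.
def _decoder_layer_usage_sketch():
    size, heads, ffn_units, dropout = 256, 4, 1024, 0.1
    layer = DecoderLayer(
        size,
        MultiHeadedAttention(heads, size, dropout),  # masked self-attention
        MultiHeadedAttention(heads, size, dropout),  # attention over encoder memory
        PositionwiseFeedForward(size, ffn_units, dropout),
        dropout,
    )
    tgt = torch.randn(2, 12, size)  # (batch, maxlen_out, size)
    tgt_mask = subsequent_mask(12, device=tgt.device).unsqueeze(0)  # causal mask, broadcast over batch
    memory = torch.randn(2, 50, size)  # (batch, maxlen_in, size)
    x, *_ = layer(tgt, tgt_mask, memory, None)
    return x.shape  # torch.Size([2, 12, 256])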
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
        # Must be set by the inheriting class
self.decoders = None
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(
memory_mask, (0, padlen), "constant", False
)
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
Args:
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
memory: encoded memory, float32 (batch, maxlen_in, feat)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, token)
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
)
return logp.squeeze(0), state
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
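# Illustrative sketch (hypothetical helper, not part of the original API): score() is the
# single-hypothesis hook used by beam search -- given the prefix token ids and the encoder
# output of one utterance, it returns next-token log-probabilities plus the per-layer
# decoder cache. `decoder` is assumed to be a concrete subclass such as the
# TransformerDecoder defined below.
def _incremental_scoring_sketch(decoder, enc_out, sos_id: int = 1):
    ys = torch.tensor([sos_id], dtype=torch.long, device=enc_out.device)  # prefix so far
    logp, cache = decoder.score(ys, state=None, x=enc_out)  # enc_out: (maxlen_in, feat)
    next_token = int(logp.argmax())  # greedy pick; beam search would keep several
    return next_token, cache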
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
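# Illustrative usage sketch (hypothetical sizes, not a definitive recipe): a teacher-forced
# forward pass of TransformerDecoder on random data.
def _transformer_decoder_forward_sketch():
    decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256, num_blocks=2)
    hs_pad = torch.randn(2, 50, 256)  # encoder output (batch, maxlen_in, feat)
    hlens = torch.tensor([50, 42])  # valid encoder lengths
    ys_in_pad = torch.randint(0, 100, (2, 12))  # target token ids with <sos> prepended
    ys_in_lens = torch.tensor([12, 9])  # valid target lengths
    scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    return scores.shape  # torch.Size([2, 12, 100]) -- token scores before softmax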
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
embeds_id: int = -1,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, self_attention_dropout_rate
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.embeds_id = embeds_id
self.attention_dim = attention_dim
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
            ys_in_pad:
                decoder input embeddings, float32 (batch, maxlen_out, feat);
                unlike the base class, the token embedding layer is bypassed
                and ys_in_pad is fed to the decoder blocks directly
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(
memory_mask, (0, padlen), "constant", False
)
# x = self.embed(tgt)
x = tgt
embeds_outputs = None
for layer_id, decoder in enumerate(self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, memory_mask
)
if layer_id == self.embeds_id:
embeds_outputs = x
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
if embeds_outputs is not None:
return x, olens, embeds_outputs
else:
return x, olens
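# Illustrative usage sketch (hypothetical sizes, not part of the original module):
# ParaformerDecoderSAN consumes continuous decoder inputs directly (self.embed is
# bypassed, see `x = tgt` above), e.g. acoustic embeddings produced by a predictor.
def _paraformer_decoder_san_forward_sketch():
    decoder = ParaformerDecoderSAN(vocab_size=100, encoder_output_size=256, num_blocks=2)
    hs_pad = torch.randn(2, 50, 256)  # encoder output (batch, maxlen_in, feat)
    hlens = torch.tensor([50, 42])
    ys_in_pad = torch.randn(2, 12, 256)  # continuous embeddings, not token ids
    ys_in_lens = torch.tensor([12, 9])
    scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    return scores.shape  # torch.Size([2, 12, 100])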
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
        conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
f"{len(conv_kernel_length)} != {num_blocks}"
)
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
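# Illustrative construction sketch (hypothetical sizes, not a definitive recipe): all four
# convolution-based decoders above require one kernel size per decoder block, and a
# mismatched conv_kernel_length raises ValueError at construction time.
def _dynamic_conv_decoder_construction_sketch():
    decoder = DynamicConvolutionTransformerDecoder(
        vocab_size=100,
        encoder_output_size=256,
        num_blocks=3,
        conv_kernel_length=(11, 15, 31),  # len(...) must equal num_blocks
    )
    return decoder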