mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 09:59:18 +08:00
add files
This commit is contained in:
334
funasr_local/models/decoder/rnn_decoder.py
Normal file
@@ -0,0 +1,334 @@
import random

import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs


def build_attention_list(
    eprojs: int,
    dunits: int,
    atype: str = "location",
    num_att: int = 1,
    num_encs: int = 1,
    aheads: int = 4,
    adim: int = 320,
    awin: int = 5,
    aconv_chans: int = 10,
    aconv_filts: int = 100,
    han_mode: bool = False,
    han_type=None,
    han_heads: int = 4,
    han_dim: int = 320,
    han_conv_chans: int = -1,
    han_conv_filts: int = 100,
    han_win: int = 5,
):

    att_list = torch.nn.ModuleList()
    if num_encs == 1:
        for i in range(num_att):
            att = initial_att(
                atype,
                eprojs,
                dunits,
                aheads,
                adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            att = initial_att(
                han_type,
                eprojs,
                dunits,
                han_heads,
                han_dim,
                han_win,
                han_conv_chans,
                han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            att_list = torch.nn.ModuleList()
            for idx in range(num_encs):
                att = initial_att(
                    atype[idx],
                    eprojs,
                    dunits,
                    aheads[idx],
                    adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        raise ValueError(
            "Number of encoders needs to be more than one. {}".format(num_encs)
        )
    return att_list


class RNNDecoder(AbsDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        sampling_probability: float = 0.0,
        dropout: float = 0.0,
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        att_conf: dict = get_default_kwargs(build_attention_list),
    ):
        # FIXME(kamo): The parts of num_spk should be refactored more more more
        assert check_argument_types()
        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        super().__init__()
        eprojs = encoder_output_size
        self.dtype = rnn_type
        self.dunits = hidden_size
        self.dlayers = num_layers
        self.context_residual = context_residual
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.odim = vocab_size
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # for multilingual translation
        self.replace_sos = replace_sos

        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [
                torch.nn.LSTMCell(hidden_size, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size, hidden_size)
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
            # NOTE: dropout is applied only for the vertical connections
            # see https://arxiv.org/pdf/1409.2329.pdf

        if context_residual:
            self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        self.att_list = build_attention_list(
            eprojs=eprojs, dunits=hidden_size, **att_conf
        )

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]),
                    (z_prev[i], c_prev[i]),
                )
        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        # to support multiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att_list) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        # get dim, length info
        olength = ys_in_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att_list[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att_list[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in_pad)
                att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                # utt x (zdim + hdim)
                ey = torch.cat((eys[:, i, :], att_c), dim=1)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1)
        z_all = self.output(z_all)
        z_all.masked_fill_(
            make_pad_mask(ys_in_lens, z_all, 1),
            0,
        )
        return z_all, ys_in_lens

    def init_state(self, x):
        # to support multiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att_list) - 1)
        if self.num_encs == 1:
            a = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
        # to support multiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att_list[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att_list[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )
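A minimal usage sketch for the attention-based RNNDecoder above. The batch size, vocabulary size, lengths, and feature dimension are illustrative assumptions, not values taken from this repository; only RNNDecoder itself comes from the file.

# Hypothetical smoke test for RNNDecoder (shapes chosen for illustration only).
import torch

vocab_size, eprojs, batch, T, U = 50, 320, 2, 30, 7
decoder = RNNDecoder(vocab_size=vocab_size, encoder_output_size=eprojs)

hs_pad = torch.randn(batch, T, eprojs)                 # padded encoder outputs
hlens = torch.tensor([T, T - 5])                       # true encoder lengths
ys_in_pad = torch.randint(0, vocab_size, (batch, U))   # <sos>-prefixed target ids
ys_in_lens = torch.tensor([U, U - 2])

logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
print(logits.shape)  # expected: (batch, U, vocab_size)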
@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""

from typing import List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug


class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type.
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct an RNNTDecoder object."""
        super().__init__()

        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        self.score_cache = {}

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)

        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        dec_out, states = self.rnn_forward(dec_embed, states)
        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create decoder hidden states.

        Args:
            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
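A short sketch of how the transducer prediction network above is driven. The dimensions are illustrative assumptions, and the decoder is called directly here rather than through the real beam-search class that normally owns these states.

# Hypothetical example: encode label prefixes once, then reuse per-hypothesis states.
import torch

dec = RNNTDecoder(vocab_size=40, embed_size=256, hidden_size=256, num_layers=2)

labels = torch.randint(1, 40, (3, 5))             # (B, U) label prefixes
dec_out = dec(labels, label_lens=torch.tensor([5, 5, 5]))
print(dec_out.shape)                              # expected: (3, 5, 256)

# State handling used by score()/batch_score() during search:
states = dec.init_state(batch_size=3)             # ((N, 3, D), (N, 3, D)) for LSTM
single = dec.select_state(states, idx=1)          # ((N, 1, D), (N, 1, D))
rebuilt = dec.create_batch_states([single, single, single])
print(rebuilt[0].shape)                           # expected: (2, 3, 256)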
0
funasr_local/models/decoder/__init__.py
Normal file
19
funasr_local/models/decoder/abs_decoder.py
Normal file
@@ -0,0 +1,19 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch

from funasr_local.modules.scorers.scorer_interface import ScorerInterface


class AbsDecoder(torch.nn.Module, ScorerInterface, ABC):
    @abstractmethod
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
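AbsDecoder only fixes the training-time interface that concrete decoders implement. A sketch of a hypothetical subclass (not part of the repository) makes the contract concrete:

# Hypothetical minimal subclass: maps embedded targets plus a pooled encoder context to logits.
import torch
from typing import Tuple


class LinearDecoder(AbsDecoder):
    def __init__(self, vocab_size: int, encoder_output_size: int):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, encoder_output_size)
        self.output = torch.nn.Linear(encoder_output_size, vocab_size)

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens) -> Tuple[torch.Tensor, torch.Tensor]:
        # Use only the mean of the encoder memory; real decoders attend over hs_pad.
        context = hs_pad.mean(dim=1, keepdim=True)
        logits = self.output(self.embed(ys_in_pad) + context)
        return logits, ys_in_lens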
776
funasr_local/models/decoder/contextual_decoder.py
Normal file
@@ -0,0 +1,776 @@
from typing import List
from typing import Tuple
import logging
import torch
import torch.nn as nn
import numpy as np

from funasr_local.modules.streaming_utils import utils as myutils
from funasr_local.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types

from funasr_local.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr_local.modules.embedding import PositionalEncoding
from funasr_local.modules.layer_norm import LayerNorm
from funasr_local.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr_local.modules.repeat import repeat
from funasr_local.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder


class ContextualDecoderLayer(nn.Module):
    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a ContextualDecoderLayer object."""
        super(ContextualDecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        # tgt = self.dropout(tgt)
        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
        if self.training:
            cache = None
        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
        x = residual + self.dropout(x)
        x_self_attn = x

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = self.src_attn(x, memory, memory_mask)
        x_src_attn = x

        x = residual + self.dropout(x)
        return x, tgt_mask, x_self_attn, x_src_attn


class ContextualBiasDecoder(nn.Module):
    def __init__(
        self,
        size,
        src_attn,
        dropout_rate,
        normalize_before=True,
    ):
        """Construct a ContextualBiasDecoder object."""
        super(ContextualBiasDecoder, self).__init__()
        self.size = size
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        x = tgt
        if self.src_attn is not None:
            if self.normalize_before:
                x = self.norm3(x)
            x = self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache


class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        if input_layer == 'none':
            self.embed = None
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num - 1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.bias_decoder = ContextualBiasDecoder(
            size=attention_dim,
            src_attn=MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            dropout_rate=dropout_rate,
            normalize_before=True,
        )
        self.bias_output = torch.nn.Conv1d(attention_dim * 2, attention_dim, 1, bias=False)
        self.last_decoder = ContextualDecoderLayer(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )

        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out) if input_layer == "embed";
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True
            olens: (batch,)
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related
        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)

        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )

        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(x)
        return x, olens

    def gen_tf2torch_map_dict(self):

        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf

        def _torch(name):
            # Full torch parameter name under the configured prefix.
            return "{}.{}".format(tensor_name_prefix_torch, name)

        def _tf(name, squeeze=None, transpose=None):
            # Mapping entry pointing at a TF variable under the configured prefix.
            return {
                "name": "{}/{}".format(tensor_name_prefix_tf, name),
                "squeeze": squeeze,
                "transpose": transpose,
            }

        map_dict_local = {
            ## decoder
            # ffn
            _torch("decoders.layeridx.norm1.weight"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma"),  # (256,),(256,)
            _torch("decoders.layeridx.norm1.bias"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta"),  # (256,),(256,)
            _torch("decoders.layeridx.feed_forward.w_1.weight"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel", squeeze=0, transpose=(1, 0)),  # (1024,256),(1,256,1024)
            _torch("decoders.layeridx.feed_forward.w_1.bias"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias"),  # (1024,),(1024,)
            _torch("decoders.layeridx.feed_forward.norm.weight"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma"),  # (1024,),(1024,)
            _torch("decoders.layeridx.feed_forward.norm.bias"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta"),  # (1024,),(1024,)
            _torch("decoders.layeridx.feed_forward.w_2.weight"): _tf("decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel", squeeze=0, transpose=(1, 0)),  # (256,1024),(1,1024,256)
            # fsmn
            _torch("decoders.layeridx.norm2.weight"): _tf("decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma"),  # (256,),(256,)
            _torch("decoders.layeridx.norm2.bias"): _tf("decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta"),  # (256,),(256,)
            _torch("decoders.layeridx.self_attn.fsmn_block.weight"): _tf("decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w", squeeze=0, transpose=(1, 2, 0)),  # (256,1,31),(1,31,256,1)
            # src att
            _torch("decoders.layeridx.norm3.weight"): _tf("decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma"),  # (256,),(256,)
            _torch("decoders.layeridx.norm3.bias"): _tf("decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta"),  # (256,),(256,)
            _torch("decoders.layeridx.src_attn.linear_q.weight"): _tf("decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel", squeeze=0, transpose=(1, 0)),  # (256,256),(1,256,256)
            _torch("decoders.layeridx.src_attn.linear_q.bias"): _tf("decoder_fsmn_layer_layeridx/multi_head/conv1d/bias"),  # (256,),(256,)
            _torch("decoders.layeridx.src_attn.linear_k_v.weight"): _tf("decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel", squeeze=0, transpose=(1, 0)),  # (1024,256),(1,256,1024)
            _torch("decoders.layeridx.src_attn.linear_k_v.bias"): _tf("decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias"),  # (1024,),(1024,)
            _torch("decoders.layeridx.src_attn.linear_out.weight"): _tf("decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel", squeeze=0, transpose=(1, 0)),  # (256,256),(1,256,256)
            _torch("decoders.layeridx.src_attn.linear_out.bias"): _tf("decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias"),  # (256,),(256,)
            # dnn
            _torch("decoders3.layeridx.norm1.weight"): _tf("decoder_dnn_layer_layeridx/LayerNorm/gamma"),  # (256,),(256,)
            _torch("decoders3.layeridx.norm1.bias"): _tf("decoder_dnn_layer_layeridx/LayerNorm/beta"),  # (256,),(256,)
            _torch("decoders3.layeridx.feed_forward.w_1.weight"): _tf("decoder_dnn_layer_layeridx/conv1d/kernel", squeeze=0, transpose=(1, 0)),  # (1024,256),(1,256,1024)
            _torch("decoders3.layeridx.feed_forward.w_1.bias"): _tf("decoder_dnn_layer_layeridx/conv1d/bias"),  # (1024,),(1024,)
            _torch("decoders3.layeridx.feed_forward.norm.weight"): _tf("decoder_dnn_layer_layeridx/LayerNorm_1/gamma"),  # (1024,),(1024,)
            _torch("decoders3.layeridx.feed_forward.norm.bias"): _tf("decoder_dnn_layer_layeridx/LayerNorm_1/beta"),  # (1024,),(1024,)
            _torch("decoders3.layeridx.feed_forward.w_2.weight"): _tf("decoder_dnn_layer_layeridx/conv1d_1/kernel", squeeze=0, transpose=(1, 0)),  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            _torch("embed_concat_ffn.layeridx.norm1.weight"): _tf("cif_concat/LayerNorm/gamma"),  # (256,),(256,)
            _torch("embed_concat_ffn.layeridx.norm1.bias"): _tf("cif_concat/LayerNorm/beta"),  # (256,),(256,)
            _torch("embed_concat_ffn.layeridx.feed_forward.w_1.weight"): _tf("cif_concat/conv1d/kernel", squeeze=0, transpose=(1, 0)),  # (1024,256),(1,256,1024)
            _torch("embed_concat_ffn.layeridx.feed_forward.w_1.bias"): _tf("cif_concat/conv1d/bias"),  # (1024,),(1024,)
            _torch("embed_concat_ffn.layeridx.feed_forward.norm.weight"): _tf("cif_concat/LayerNorm_1/gamma"),  # (1024,),(1024,)
            _torch("embed_concat_ffn.layeridx.feed_forward.norm.bias"): _tf("cif_concat/LayerNorm_1/beta"),  # (1024,),(1024,)
            _torch("embed_concat_ffn.layeridx.feed_forward.w_2.weight"): _tf("cif_concat/conv1d_1/kernel", squeeze=0, transpose=(1, 0)),  # (256,1024),(1,1024,256)
            # out norm
            _torch("after_norm.weight"): _tf("LayerNorm/gamma"),  # (256,),(256,)
            _torch("after_norm.bias"): _tf("LayerNorm/beta"),  # (256,),(256,)
            # in embed
            _torch("embed.0.weight"): _tf("w_embs"),  # (4235,256),(4235,256)
            # out layer
            _torch("output_layer.weight"): {
                "name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                "squeeze": [None, None],
                "transpose": [(1, 0), None],
            },  # (4235,256),(256,4235)
            _torch("output_layer.bias"): {
                "name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                         "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                "squeeze": [None, None],
                "transpose": [None, None],
            },  # (4235,),(4235,)

            ## clas decoder
            # src att
            _torch("bias_decoder.norm3.weight"): _tf("decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma"),  # (256,),(256,)
            _torch("bias_decoder.norm3.bias"): _tf("decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta"),  # (256,),(256,)
            _torch("bias_decoder.src_attn.linear_q.weight"): _tf("decoder_fsmn_layer_15/multi_head_1/conv1d/kernel", squeeze=0, transpose=(1, 0)),  # (256,256),(1,256,256)
            _torch("bias_decoder.src_attn.linear_q.bias"): _tf("decoder_fsmn_layer_15/multi_head_1/conv1d/bias"),  # (256,),(256,)
            _torch("bias_decoder.src_attn.linear_k_v.weight"): _tf("decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel", squeeze=0, transpose=(1, 0)),  # (1024,256),(1,256,1024)
            _torch("bias_decoder.src_attn.linear_k_v.bias"): _tf("decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias"),  # (1024,),(1024,)
            _torch("bias_decoder.src_attn.linear_out.weight"): _tf("decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel", squeeze=0, transpose=(1, 0)),  # (256,256),(1,256,256)
            _torch("bias_decoder.src_attn.linear_out.bias"): _tf("decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias"),  # (256,),(256,)
            # dnn
            _torch("bias_output.weight"): _tf("decoder_fsmn_layer_15/conv1d/kernel", transpose=(2, 1, 0)),  # (1024,256),(1,256,1024)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()

        def _load(name, name_tf, squeeze=None, transpose=None):
            # Shared squeeze/transpose/convert/shape-check/log step for every mapped tensor.
            data_tf = var_dict_tf[name_tf]
            if squeeze is not None:
                data_tf = np.squeeze(data_tf, axis=squeeze)
            if transpose is not None:
                data_tf = np.transpose(data_tf, transpose)
            data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
            assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                name, name_tf, var_dict_torch[name].size(), data_tf.size())
            var_dict_torch_update[name] = data_tf
            logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

        def _load_mapped(name, name_q, layeridx):
            # Look the template key up in map_dict and substitute the real layer index.
            if name_q in map_dict.keys():
                entry = map_dict[name_q]
                name_tf = entry["name"].replace("layeridx", "{}".format(layeridx))
                _load(name, name_tf, entry["squeeze"], entry["transpose"])

        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] != self.tf2torch_tensor_name_prefix_torch:
                continue
            if names[1] == "decoders":
                layeridx = int(names[2])
                name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                decoder_layeridx_sets.add(layeridx)
                _load_mapped(name, name_q, layeridx)
            elif names[1] == "last_decoder":
                # last_decoder reuses the standard decoder-layer mapping, pinned to TF layer 15.
                layeridx = 15
                name_q = name.replace("last_decoder", "decoders.layeridx")
                decoder_layeridx_sets.add(layeridx)
                _load_mapped(name, name_q, layeridx)
            elif names[1] == "decoders2":
                layeridx = int(names[2])
                name_q = name.replace(".{}.".format(layeridx), ".layeridx.").replace("decoders2", "decoders")
                # decoders2 continues the TF layer numbering after the attention layers.
                layeridx += len(decoder_layeridx_sets)
                _load_mapped(name, name_q, layeridx)
            elif names[1] in ("decoders3", "embed_concat_ffn"):
                layeridx = int(names[2])
                name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                _load_mapped(name, name_q, layeridx)
            elif names[1] == "bias_decoder":
                if name in map_dict.keys():
                    entry = map_dict[name]
                    _load(name, entry["name"], entry["squeeze"], entry["transpose"])
            elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
                entry = map_dict[name]
                name_tf = entry["name"]
                if isinstance(name_tf, list):
                    # Prefer the first candidate TF name; fall back to the tied-embedding variant.
                    idx_list = 0 if name_tf[0] in var_dict_tf.keys() else 1
                    _load(name, name_tf[idx_list], entry["squeeze"][idx_list], entry["transpose"][idx_list])
                else:
                    _load(name, name_tf, entry["squeeze"], entry["transpose"])
            elif names[1] == "after_norm":
                _load(name, map_dict[name]["name"])

        return var_dict_torch_update
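A rough sketch of how the contextual (hotword-biasing) decoder above is called. All shapes and the random contextual_info tensor are illustrative assumptions; in the full model the hotword embeddings and the acoustic target embeddings come from other components.

# Hypothetical call pattern for ContextualParaformerDecoder.forward (shapes illustrative).
import torch

batch, T, U, d_model, vocab, n_hotwords = 2, 50, 8, 256, 4235, 10

decoder = ContextualParaformerDecoder(
    vocab_size=vocab, encoder_output_size=d_model, att_layer_num=6, num_blocks=6
)

hs_pad = torch.randn(batch, T, d_model)                    # encoder memory
hlens = torch.tensor([T, T - 10])
ys_in_pad = torch.randn(batch, U, d_model)                 # CIF-style acoustic embeddings used as queries
ys_in_lens = torch.tensor([U, U - 3])
contextual_info = torch.randn(batch, n_hotwords, d_model)  # hotword embeddings (assumed)

scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens, contextual_info)
print(scores.shape)  # expected: (batch, U, vocab)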
|
||||
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
|
||||
) # utt x (zdim + hdim)
|
||||
else:
|
||||
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
|
||||
|
||||
z_all = torch.stack(z_all, dim=1)
|
||||
z_all = self.output(z_all)
|
||||
z_all.masked_fill_(
|
||||
make_pad_mask(ys_in_lens, z_all, 1),
|
||||
0,
|
||||
)
|
||||
return z_all, ys_in_lens
|
||||
|
||||
def init_state(self, x):
|
||||
# to support mutiple encoder asr mode, in single encoder mode,
|
||||
# convert torch.Tensor to List of torch.Tensor
|
||||
if self.num_encs == 1:
|
||||
x = [x]
|
||||
|
||||
c_list = [self.zero_state(x[0].unsqueeze(0))]
|
||||
z_list = [self.zero_state(x[0].unsqueeze(0))]
|
||||
for _ in range(1, self.dlayers):
|
||||
c_list.append(self.zero_state(x[0].unsqueeze(0)))
|
||||
z_list.append(self.zero_state(x[0].unsqueeze(0)))
|
||||
# TODO(karita): support strm_index for `asr_mix`
|
||||
strm_index = 0
|
||||
att_idx = min(strm_index, len(self.att_list) - 1)
|
||||
if self.num_encs == 1:
|
||||
a = None
|
||||
self.att_list[att_idx].reset() # reset pre-computation of h
|
||||
else:
|
||||
a = [None] * (self.num_encs + 1) # atts + han
|
||||
for idx in range(self.num_encs + 1):
|
||||
# reset pre-computation of h in atts and han
|
||||
self.att_list[idx].reset()
|
||||
return dict(
|
||||
c_prev=c_list[:],
|
||||
z_prev=z_list[:],
|
||||
a_prev=a,
|
||||
workspace=(att_idx, z_list, c_list),
|
||||
)
|
||||
|
||||
def score(self, yseq, state, x):
|
||||
# to support mutiple encoder asr mode, in single encoder mode,
|
||||
# convert torch.Tensor to List of torch.Tensor
|
||||
if self.num_encs == 1:
|
||||
x = [x]
|
||||
|
||||
att_idx, z_list, c_list = state["workspace"]
|
||||
vy = yseq[-1].unsqueeze(0)
|
||||
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
|
||||
if self.num_encs == 1:
|
||||
att_c, att_w = self.att_list[att_idx](
|
||||
x[0].unsqueeze(0),
|
||||
[x[0].size(0)],
|
||||
self.dropout_dec[0](state["z_prev"][0]),
|
||||
state["a_prev"],
|
||||
)
|
||||
else:
|
||||
att_w = [None] * (self.num_encs + 1) # atts + han
|
||||
att_c_list = [None] * self.num_encs # atts
|
||||
for idx in range(self.num_encs):
|
||||
att_c_list[idx], att_w[idx] = self.att_list[idx](
|
||||
x[idx].unsqueeze(0),
|
||||
[x[idx].size(0)],
|
||||
self.dropout_dec[0](state["z_prev"][0]),
|
||||
state["a_prev"][idx],
|
||||
)
|
||||
h_han = torch.stack(att_c_list, dim=1)
|
||||
att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
|
||||
h_han,
|
||||
[self.num_encs],
|
||||
self.dropout_dec[0](state["z_prev"][0]),
|
||||
state["a_prev"][self.num_encs],
|
||||
)
|
||||
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
|
||||
z_list, c_list = self.rnn_forward(
|
||||
ey, z_list, c_list, state["z_prev"], state["c_prev"]
|
||||
)
|
||||
if self.context_residual:
|
||||
logits = self.output(
|
||||
torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
|
||||
)
|
||||
else:
|
||||
logits = self.output(self.dropout_dec[-1](z_list[-1]))
|
||||
logp = F.log_softmax(logits, dim=1).squeeze(0)
|
||||
return (
|
||||
logp,
|
||||
dict(
|
||||
c_prev=c_list[:],
|
||||
z_prev=z_list[:],
|
||||
a_prev=att_w,
|
||||
workspace=(att_idx, z_list, c_list),
|
||||
),
|
||||
)
|
||||
258
funasr_local/models/decoder/rnnt_decoder.py
Normal file
258
funasr_local/models/decoder/rnnt_decoder.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""RNN decoder definition for Transducer models."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
|
||||
from funasr_local.models.specaug.specaug import SpecAug
|
||||
|
||||
class RNNTDecoder(torch.nn.Module):
|
||||
"""RNN decoder module.
|
||||
|
||||
Args:
|
||||
vocab_size: Vocabulary size.
|
||||
embed_size: Embedding size.
|
||||
hidden_size: Hidden size..
|
||||
rnn_type: Decoder layers type.
|
||||
num_layers: Number of decoder layers.
|
||||
dropout_rate: Dropout rate for decoder layers.
|
||||
embed_dropout_rate: Dropout rate for embedding layer.
|
||||
embed_pad: Embedding padding symbol ID.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embed_size: int = 256,
|
||||
hidden_size: int = 256,
|
||||
rnn_type: str = "lstm",
|
||||
num_layers: int = 1,
|
||||
dropout_rate: float = 0.0,
|
||||
embed_dropout_rate: float = 0.0,
|
||||
embed_pad: int = 0,
|
||||
) -> None:
|
||||
"""Construct a RNNDecoder object."""
|
||||
super().__init__()
|
||||
|
||||
assert check_argument_types()
|
||||
|
||||
if rnn_type not in ("lstm", "gru"):
|
||||
raise ValueError(f"Not supported: rnn_type={rnn_type}")
|
||||
|
||||
self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
|
||||
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
|
||||
|
||||
rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
|
||||
|
||||
self.rnn = torch.nn.ModuleList(
|
||||
[rnn_class(embed_size, hidden_size, 1, batch_first=True)]
|
||||
)
|
||||
|
||||
for _ in range(1, num_layers):
|
||||
self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
|
||||
|
||||
self.dropout_rnn = torch.nn.ModuleList(
|
||||
[torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
|
||||
)
|
||||
|
||||
self.dlayers = num_layers
|
||||
self.dtype = rnn_type
|
||||
|
||||
self.output_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.device = next(self.parameters()).device
|
||||
self.score_cache = {}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
labels: torch.Tensor,
|
||||
label_lens: torch.Tensor,
|
||||
states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Encode source label sequences.
|
||||
|
||||
Args:
|
||||
labels: Label ID sequences. (B, L)
|
||||
states: Decoder hidden states.
|
||||
((N, B, D_dec), (N, B, D_dec) or None) or None
|
||||
|
||||
Returns:
|
||||
dec_out: Decoder output sequences. (B, U, D_dec)
|
||||
|
||||
"""
|
||||
if states is None:
|
||||
states = self.init_state(labels.size(0))
|
||||
|
||||
dec_embed = self.dropout_embed(self.embed(labels))
|
||||
dec_out, states = self.rnn_forward(dec_embed, states)
|
||||
return dec_out
|
||||
|
||||
def rnn_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
|
||||
"""Encode source label sequences.
|
||||
|
||||
Args:
|
||||
x: RNN input sequences. (B, D_emb)
|
||||
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
|
||||
|
||||
Returns:
|
||||
x: RNN output sequences. (B, D_dec)
|
||||
(h_next, c_next): Decoder hidden states.
|
||||
(N, B, D_dec), (N, B, D_dec) or None)
|
||||
|
||||
"""
|
||||
h_prev, c_prev = state
|
||||
h_next, c_next = self.init_state(x.size(0))
|
||||
|
||||
for layer in range(self.dlayers):
|
||||
if self.dtype == "lstm":
|
||||
x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
|
||||
layer
|
||||
](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
|
||||
else:
|
||||
x, h_next[layer : layer + 1] = self.rnn[layer](
|
||||
x, hx=h_prev[layer : layer + 1]
|
||||
)
|
||||
|
||||
x = self.dropout_rnn[layer](x)
|
||||
|
||||
return x, (h_next, c_next)
|
||||
|
||||
def score(
|
||||
self,
|
||||
label: torch.Tensor,
|
||||
label_sequence: List[int],
|
||||
dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
|
||||
"""One-step forward hypothesis.
|
||||
|
||||
Args:
|
||||
label: Previous label. (1, 1)
|
||||
label_sequence: Current label sequence.
|
||||
dec_state: Previous decoder hidden states.
|
||||
((N, 1, D_dec), (N, 1, D_dec) or None)
|
||||
|
||||
Returns:
|
||||
dec_out: Decoder output sequence. (1, D_dec)
|
||||
dec_state: Decoder hidden states.
|
||||
((N, 1, D_dec), (N, 1, D_dec) or None)
|
||||
|
||||
"""
|
||||
str_labels = "_".join(map(str, label_sequence))
|
||||
|
||||
if str_labels in self.score_cache:
|
||||
dec_out, dec_state = self.score_cache[str_labels]
|
||||
else:
|
||||
dec_embed = self.embed(label)
|
||||
dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
|
||||
|
||||
self.score_cache[str_labels] = (dec_out, dec_state)
|
||||
|
||||
return dec_out[0], dec_state
|
||||
|
||||
def batch_score(
|
||||
self,
|
||||
hyps: List[Hypothesis],
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
|
||||
"""One-step forward hypotheses.
|
||||
|
||||
Args:
|
||||
hyps: Hypotheses.
|
||||
|
||||
Returns:
|
||||
dec_out: Decoder output sequences. (B, D_dec)
|
||||
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
|
||||
|
||||
"""
|
||||
labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
|
||||
dec_embed = self.embed(labels)
|
||||
|
||||
states = self.create_batch_states([h.dec_state for h in hyps])
|
||||
dec_out, states = self.rnn_forward(dec_embed, states)
|
||||
|
||||
return dec_out.squeeze(1), states
|
||||
|
||||
def set_device(self, device: torch.device) -> None:
|
||||
"""Set GPU device to use.
|
||||
|
||||
Args:
|
||||
device: Device ID.
|
||||
|
||||
"""
|
||||
self.device = device
|
||||
|
||||
def init_state(
|
||||
self, batch_size: int
|
||||
) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
|
||||
"""Initialize decoder states.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size.
|
||||
|
||||
Returns:
|
||||
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
|
||||
|
||||
"""
|
||||
h_n = torch.zeros(
|
||||
self.dlayers,
|
||||
batch_size,
|
||||
self.output_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.dtype == "lstm":
|
||||
c_n = torch.zeros(
|
||||
self.dlayers,
|
||||
batch_size,
|
||||
self.output_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
return (h_n, c_n)
|
||||
|
||||
return (h_n, None)
|
||||
|
||||
def select_state(
|
||||
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Get specified ID state from decoder hidden states.
|
||||
|
||||
Args:
|
||||
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
|
||||
idx: State ID to extract.
|
||||
|
||||
Returns:
|
||||
: Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
|
||||
|
||||
"""
|
||||
return (
|
||||
states[0][:, idx : idx + 1, :],
|
||||
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
|
||||
)
|
||||
|
||||
def create_batch_states(
|
||||
self,
|
||||
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Create decoder hidden states.
|
||||
|
||||
Args:
|
||||
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
|
||||
|
||||
Returns:
|
||||
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
|
||||
|
||||
"""
|
||||
return (
|
||||
torch.cat([s[0] for s in new_states], dim=1),
|
||||
torch.cat([s[1] for s in new_states], dim=1)
|
||||
if self.dtype == "lstm"
|
||||
else None,
|
||||
)
|
||||
1484
funasr_local/models/decoder/sanm_decoder.py
Normal file
1484
funasr_local/models/decoder/sanm_decoder.py
Normal file
File diff suppressed because it is too large
Load Diff
37
funasr_local/models/decoder/sv_decoder.py
Normal file
37
funasr_local/models/decoder/sv_decoder.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from funasr_local.models.decoder.abs_decoder import AbsDecoder
|
||||
|
||||
|
||||
class DenseDecoder(AbsDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size,
|
||||
encoder_output_size,
|
||||
num_nodes_resnet1: int = 256,
|
||||
num_nodes_last_layer: int = 256,
|
||||
batchnorm_momentum: float = 0.5,
|
||||
):
|
||||
super(DenseDecoder, self).__init__()
|
||||
self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
|
||||
self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
|
||||
|
||||
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
|
||||
self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
|
||||
|
||||
self.output_dense = torch.nn.Linear(num_nodes_last_layer, vocab_size, bias=False)
|
||||
|
||||
def forward(self, features):
|
||||
embeddings = {}
|
||||
features = self.resnet1_dense(features)
|
||||
embeddings["resnet1_dense"] = features
|
||||
features = F.relu(features)
|
||||
features = self.resnet1_bn(features)
|
||||
|
||||
features = self.resnet2_dense(features)
|
||||
embeddings["resnet2_dense"] = features
|
||||
features = F.relu(features)
|
||||
features = self.resnet2_bn(features)
|
||||
|
||||
features = self.output_dense(features)
|
||||
return features, embeddings
|
||||
766
funasr_local/models/decoder/transformer_decoder.py
Normal file
766
funasr_local/models/decoder/transformer_decoder.py
Normal file
@@ -0,0 +1,766 @@
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Decoder definition."""
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr_local.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr_local.modules.attention import MultiHeadedAttention
|
||||
from funasr_local.modules.dynamic_conv import DynamicConvolution
|
||||
from funasr_local.modules.dynamic_conv2d import DynamicConvolution2D
|
||||
from funasr_local.modules.embedding import PositionalEncoding
|
||||
from funasr_local.modules.layer_norm import LayerNorm
|
||||
from funasr_local.modules.lightconv import LightweightConvolution
|
||||
from funasr_local.modules.lightconv2d import LightweightConvolution2D
|
||||
from funasr_local.modules.mask import subsequent_mask
|
||||
from funasr_local.modules.nets_utils import make_pad_mask
|
||||
from funasr_local.modules.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr_local.modules.repeat import repeat
|
||||
from funasr_local.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""Single decoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` instance can be used as the argument.
|
||||
src_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` instance can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
src_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an DecoderLayer object."""
|
||||
super(DecoderLayer, self).__init__()
|
||||
self.size = size
|
||||
self.self_attn = self_attn
|
||||
self.src_attn = src_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.norm3 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear1 = nn.Linear(size + size, size)
|
||||
self.concat_linear2 = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
|
||||
"""Compute decoded features.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
||||
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
|
||||
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
|
||||
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
|
||||
cache (List[torch.Tensor]): List of cached tensors.
|
||||
Each tensor shape should be (#batch, maxlen_out - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor(#batch, maxlen_out, size).
|
||||
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
||||
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
||||
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
||||
|
||||
"""
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
if cache is None:
|
||||
tgt_q = tgt
|
||||
tgt_q_mask = tgt_mask
|
||||
else:
|
||||
# compute only the last frame query keeping dim: max_time_out -> 1
|
||||
assert cache.shape == (
|
||||
tgt.shape[0],
|
||||
tgt.shape[1] - 1,
|
||||
self.size,
|
||||
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
|
||||
tgt_q = tgt[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
tgt_q_mask = None
|
||||
if tgt_mask is not None:
|
||||
tgt_q_mask = tgt_mask[:, -1:, :]
|
||||
|
||||
if self.concat_after:
|
||||
tgt_concat = torch.cat(
|
||||
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
|
||||
)
|
||||
x = residual + self.concat_linear1(tgt_concat)
|
||||
else:
|
||||
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat(
|
||||
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
|
||||
)
|
||||
x = residual + self.concat_linear2(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
x = residual + self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, tgt_mask, memory, memory_mask
|
||||
|
||||
|
||||
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
|
||||
"""Base class of Transfomer decoder module.
|
||||
|
||||
Args:
|
||||
vocab_size: output dim
|
||||
encoder_output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the number of units of position-wise feed forward
|
||||
num_blocks: the number of decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
self_attention_dropout_rate: dropout rate for attention
|
||||
input_layer: input layer type
|
||||
use_output_layer: whether to use output layer
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before: whether to use layer_norm before the first block
|
||||
concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied.
|
||||
i.e. x -> x + att(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
if input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(vocab_size, attention_dim),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(vocab_size, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
if use_output_layer:
|
||||
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
|
||||
else:
|
||||
self.output_layer = None
|
||||
|
||||
# Must set by the inheritance
|
||||
self.decoders = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
tgt_mask = tgt_mask & m
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
|
||||
memory.device
|
||||
)
|
||||
# Padding for Longformer
|
||||
if memory_mask.shape[-1] != memory.shape[1]:
|
||||
padlen = memory.shape[1] - memory_mask.shape[-1]
|
||||
memory_mask = torch.nn.functional.pad(
|
||||
memory_mask, (0, padlen), "constant", False
|
||||
)
|
||||
|
||||
x = self.embed(tgt)
|
||||
x, tgt_mask, memory, memory_mask = self.decoders(
|
||||
x, tgt_mask, memory, memory_mask
|
||||
)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.output_layer is not None:
|
||||
x = self.output_layer(x)
|
||||
|
||||
olens = tgt_mask.sum(1)
|
||||
return x, olens
|
||||
|
||||
def forward_one_step(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
cache: List[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Forward one step.
|
||||
|
||||
Args:
|
||||
tgt: input token ids, int64 (batch, maxlen_out)
|
||||
tgt_mask: input token mask, (batch, maxlen_out)
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
||||
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
cache: cached output list of (batch, max_time_out-1, size)
|
||||
Returns:
|
||||
y, cache: NN output value and cache per `self.decoders`.
|
||||
y.shape` is (batch, maxlen_out, token)
|
||||
"""
|
||||
x = self.embed(tgt)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.decoders)
|
||||
new_cache = []
|
||||
for c, decoder in zip(cache, self.decoders):
|
||||
x, tgt_mask, memory, memory_mask = decoder(
|
||||
x, tgt_mask, memory, None, cache=c
|
||||
)
|
||||
new_cache.append(x)
|
||||
|
||||
if self.normalize_before:
|
||||
y = self.after_norm(x[:, -1])
|
||||
else:
|
||||
y = x[:, -1]
|
||||
if self.output_layer is not None:
|
||||
y = torch.log_softmax(self.output_layer(y), dim=-1)
|
||||
|
||||
return y, new_cache
|
||||
|
||||
def score(self, ys, state, x):
|
||||
"""Score."""
|
||||
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
|
||||
logp, state = self.forward_one_step(
|
||||
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
|
||||
)
|
||||
return logp.squeeze(0), state
|
||||
|
||||
def batch_score(
|
||||
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, List[Any]]:
|
||||
"""Score new token batch.
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
# merge states
|
||||
n_batch = len(ys)
|
||||
n_layers = len(self.decoders)
|
||||
if states[0] is None:
|
||||
batch_state = None
|
||||
else:
|
||||
# transpose state of [batch, layer] into [layer, batch]
|
||||
batch_state = [
|
||||
torch.stack([states[b][i] for b in range(n_batch)])
|
||||
for i in range(n_layers)
|
||||
]
|
||||
|
||||
# batch decoding
|
||||
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
|
||||
logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
|
||||
|
||||
# transpose state of [layer, batch] into [batch, layer]
|
||||
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
|
||||
return logp, state_list
|
||||
|
||||
|
||||
class TransformerDecoder(BaseTransformerDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, self_attention_dropout_rate
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ParaformerDecoderSAN(BaseTransformerDecoder):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
embeds_id: int = -1,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, self_attention_dropout_rate
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
self.embeds_id = embeds_id
|
||||
self.attention_dim = attention_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
|
||||
memory.device
|
||||
)
|
||||
# Padding for Longformer
|
||||
if memory_mask.shape[-1] != memory.shape[1]:
|
||||
padlen = memory.shape[1] - memory_mask.shape[-1]
|
||||
memory_mask = torch.nn.functional.pad(
|
||||
memory_mask, (0, padlen), "constant", False
|
||||
)
|
||||
|
||||
# x = self.embed(tgt)
|
||||
x = tgt
|
||||
embeds_outputs = None
|
||||
for layer_id, decoder in enumerate(self.decoders):
|
||||
x, tgt_mask, memory, memory_mask = decoder(
|
||||
x, tgt_mask, memory, memory_mask
|
||||
)
|
||||
if layer_id == self.embeds_id:
|
||||
embeds_outputs = x
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.output_layer is not None:
|
||||
x = self.output_layer(x)
|
||||
|
||||
olens = tgt_mask.sum(1)
|
||||
if embeds_outputs is not None:
|
||||
return x, olens, embeds_outputs
|
||||
else:
|
||||
return x, olens
|
||||
|
||||
|
||||
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
"conv_kernel_length must have equal number of values to num_blocks: "
|
||||
f"{len(conv_kernel_length)} != {num_blocks}"
|
||||
)
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
LightweightConvolution(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
"conv_kernel_length must have equal number of values to num_blocks: "
|
||||
f"{len(conv_kernel_length)} != {num_blocks}"
|
||||
)
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
LightweightConvolution2D(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
"conv_kernel_length must have equal number of values to num_blocks: "
|
||||
f"{len(conv_kernel_length)} != {num_blocks}"
|
||||
)
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
DynamicConvolution(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
"conv_kernel_length must have equal number of values to num_blocks: "
|
||||
f"{len(conv_kernel_length)} != {num_blocks}"
|
||||
)
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
DynamicConvolution2D(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(
|
||||
attention_heads, attention_dim, src_attention_dropout_rate
|
||||
),
|
||||
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
Reference in New Issue
Block a user