mirror of
https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
@@ -0,0 +1,334 @@
import random

import numpy as np
import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr_local.modules.nets_utils import make_pad_mask
from funasr_local.modules.nets_utils import to_device
from funasr_local.modules.rnn.attentions import initial_att
from funasr_local.models.decoder.abs_decoder import AbsDecoder
from funasr_local.utils.get_default_kwargs import get_default_kwargs

def build_attention_list(
    eprojs: int,
    dunits: int,
    atype: str = "location",
    num_att: int = 1,
    num_encs: int = 1,
    aheads: int = 4,
    adim: int = 320,
    awin: int = 5,
    aconv_chans: int = 10,
    aconv_filts: int = 100,
    han_mode: bool = False,
    han_type=None,
    han_heads: int = 4,
    han_dim: int = 320,
    han_conv_chans: int = -1,
    han_conv_filts: int = 100,
    han_win: int = 5,
):
    att_list = torch.nn.ModuleList()
    if num_encs == 1:
        for i in range(num_att):
            att = initial_att(
                atype,
                eprojs,
                dunits,
                aheads,
                adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            # hierarchical attention network (HAN) over the encoder streams
            att = initial_att(
                han_type,
                eprojs,
                dunits,
                han_heads,
                han_dim,
                han_win,
                han_conv_chans,
                han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            att_list = torch.nn.ModuleList()
            for idx in range(num_encs):
                att = initial_att(
                    atype[idx],
                    eprojs,
                    dunits,
                    aheads[idx],
                    adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        # reached only when num_encs < 1
        raise ValueError(
            "Number of encoders must be one or more: {}".format(num_encs)
        )
    return att_list

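
# A minimal, hedged sketch (not part of the original file) of how
# build_attention_list is used in single-encoder mode; it assumes
# funasr_local's initial_att is importable and relies on the defaults above.
def _demo_build_attention_list():
    att_list = build_attention_list(eprojs=256, dunits=320, atype="location")
    assert len(att_list) == 1  # num_att defaults to 1 when num_encs == 1
    return att_list

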
class RNNDecoder(AbsDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        sampling_probability: float = 0.0,
        dropout: float = 0.0,
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        att_conf: dict = get_default_kwargs(build_attention_list),
    ):
        # FIXME(kamo): The parts of num_spk should be refactored further
        assert check_argument_types()
        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        super().__init__()
        eprojs = encoder_output_size
        self.dtype = rnn_type
        self.dunits = hidden_size
        self.dlayers = num_layers
        self.context_residual = context_residual
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.odim = vocab_size
        self.sampling_probability = sampling_probability
        self.dropout = dropout
        self.num_encs = num_encs

        # for multilingual translation
        self.replace_sos = replace_sos

        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [
                torch.nn.LSTMCell(hidden_size, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size, hidden_size)
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
            # NOTE: dropout is applied only for the vertical connections
            # see https://arxiv.org/pdf/1409.2329.pdf

        if context_residual:
            self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size)
        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        self.att_list = build_attention_list(
            eprojs=eprojs, dunits=hidden_size, **att_conf
        )

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for i in range(1, self.dlayers):
                z_list[i], c_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]),
                    (z_prev[i], c_prev[i]),
                )
        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        # to support multiple encoder ASR mode; in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att_list) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        # get dim, length info
        olength = ys_in_pad.size(1)

        # initialization
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att_list[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att_list[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in_pad)
                att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                # scheduled sampling: feed back the model's own previous prediction
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                # teacher forcing: feed the ground-truth previous token
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        z_all = torch.stack(z_all, dim=1)
        z_all = self.output(z_all)
        z_all.masked_fill_(
            make_pad_mask(ys_in_lens, z_all, 1),
            0,
        )
        return z_all, ys_in_lens

    def init_state(self, x):
        # to support multiple encoder ASR mode; in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        c_list = [self.zero_state(x[0].unsqueeze(0))]
        z_list = [self.zero_state(x[0].unsqueeze(0))]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(x[0].unsqueeze(0)))
            z_list.append(self.zero_state(x[0].unsqueeze(0)))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att_list) - 1)
        if self.num_encs == 1:
            a = None
            self.att_list[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            for idx in range(self.num_encs + 1):
                # reset pre-computation of h in atts and han
                self.att_list[idx].reset()
        return dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=a,
            workspace=(att_idx, z_list, c_list),
        )

    def score(self, yseq, state, x):
        # to support multiple encoder ASR mode; in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            x = [x]

        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        if self.num_encs == 1:
            att_c, att_w = self.att_list[att_idx](
                x[0].unsqueeze(0),
                [x[0].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"],
            )
        else:
            att_w = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * self.num_encs  # atts
            for idx in range(self.num_encs):
                att_c_list[idx], att_w[idx] = self.att_list[idx](
                    x[idx].unsqueeze(0),
                    [x[idx].size(0)],
                    self.dropout_dec[0](state["z_prev"][0]),
                    state["a_prev"][idx],
                )
            h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w[self.num_encs] = self.att_list[self.num_encs](
                h_han,
                [self.num_encs],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return (
            logp,
            dict(
                c_prev=c_list[:],
                z_prev=z_list[:],
                a_prev=att_w,
                workspace=(att_idx, z_list, c_list),
            ),
        )

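
# A hedged usage sketch (not part of the original file): exercises forward()
# with teacher forcing, then one score() step as a beam-search decoder would.
# Shapes follow the comments above; funasr_local must be importable, and the
# sizes below are illustrative only.
def _demo_rnn_decoder():
    decoder = RNNDecoder(vocab_size=100, encoder_output_size=64, hidden_size=32)

    # teacher-forced training pass: (B, T, eprojs) encoder output
    hs_pad = torch.randn(2, 50, 64)
    hlens = torch.tensor([50, 42])
    ys_in_pad = torch.randint(0, 100, (2, 10))  # (B, L) label inputs
    ys_in_lens = torch.tensor([10, 8])
    z_all, _ = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)  # (B, L, vocab)

    # one incremental decoding step for a single utterance
    x = hs_pad[0]  # (T, eprojs)
    state = decoder.init_state(x)
    logp, state = decoder.score(torch.tensor([decoder.sos]), state, x)
    return z_all, logp  # logp: (vocab_size,)
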
@@ -0,0 +1,258 @@
"""RNN decoder definition for Transducer models."""

from typing import List, Optional, Tuple

import torch
from typeguard import check_argument_types

from funasr_local.modules.beam_search.beam_search_transducer import Hypothesis
from funasr_local.models.specaug.specaug import SpecAug

class RNNTDecoder(torch.nn.Module):
    """RNN decoder module.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type.
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct an RNNTDecoder object."""
        super().__init__()

        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        self.score_cache = {}

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label ID sequence lengths. (B,)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)

        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        dec_out, states = self.rnn_forward(dec_embed, states)
        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        # the legacy torch.LongTensor constructor does not accept a device
        # keyword, so build the tensor with torch.tensor instead
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create decoder hidden states.

        Args:
            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )

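
# A hedged usage sketch (not part of the original file): runs forward() on a
# dummy batch, then a cached one-step score() as a transducer beam search
# would. Shapes follow the docstrings above; the sizes are illustrative only.
def _demo_rnnt_decoder():
    decoder = RNNTDecoder(vocab_size=50, embed_size=32, hidden_size=32, num_layers=2)

    labels = torch.randint(1, 50, (4, 7))  # (B, L) label IDs, 0 is padding
    label_lens = torch.full((4,), 7)       # (B,) lengths (unused by forward here)
    dec_out = decoder(labels, label_lens)  # (4, 7, 32) == (B, L, D_dec)

    # one-step scoring: results are memoized in decoder.score_cache by prefix
    state = decoder.init_state(batch_size=1)
    label = torch.tensor([[3]])            # previous label, shape (1, 1)
    step_out, state = decoder.score(label, [3], state)  # step_out: (1, D_dec)
    return dec_out, step_out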