Author: lyuxiang.lx
Date:   2024-12-12 15:43:17 +08:00
Parent: 0bf706c26f
Commit: 2345ce6be2

4 changed files with 19 additions and 9 deletions


@@ -157,6 +157,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                  vocab_size: int = 4096,
                  input_frame_rate: int = 50,
                  only_mask_loss: bool = True,
+                 token_mel_ratio: int = 2,
+                 pre_lookahead_len: int = 3,
                  encoder: torch.nn.Module = None,
                  decoder: torch.nn.Module = None,
                  decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
@@ -181,6 +183,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
         self.decoder = decoder
         self.only_mask_loss = only_mask_loss
+        self.token_mel_ratio = token_mel_ratio
+        self.pre_lookahead_len = pre_lookahead_len
 
     @torch.inference_mode()
     def inference(self,
@@ -206,7 +210,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         if finalize is False:
-            h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
+            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
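
The change stops reaching into encoder internals (`self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride`) and reads the two values from the module's own constructor arguments instead; with the new defaults the trimmed amount is unchanged (3 tokens × 2 frames per token = 6 frames). The slice itself handles streaming: on a non-final chunk, the last `pre_lookahead_len` speech tokens are lookahead context whose mel frames depend on tokens not yet seen, so the corresponding `pre_lookahead_len * token_mel_ratio` frames are dropped and regenerated on a later call. Below is a minimal runnable sketch of just that trim; only the two defaults and the slicing expression come from the diff, all tensor shapes are hypothetical:

import torch

# Defaults introduced by this commit.
token_mel_ratio = 2      # mel frames produced per speech token
pre_lookahead_len = 3    # future tokens the encoder peeks at

# Stand-in for the encoder output of one streaming chunk: (batch, frames, channels).
# 50 tokens upsampled by token_mel_ratio -> 100 frames (illustrative numbers).
h = torch.randn(1, 50 * token_mel_ratio, 512)

finalize = False  # mid-stream chunk; the final chunk keeps every frame
if finalize is False:
    # Drop the frames generated from lookahead tokens; they are
    # recomputed once the next chunk of tokens arrives.
    h = h[:, :-pre_lookahead_len * token_mel_ratio]

print(h.shape)  # torch.Size([1, 94, 512]) -> 100 - 3*2 frames trimmed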