mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update
This commit is contained in:
@@ -157,6 +157,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
+ token_mel_ratio: int = 2,
|
||||
+ pre_lookahead_len: int = 3,
|
||||
encoder: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
@@ -181,6 +183,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
||||
self.decoder = decoder
|
||||
self.only_mask_loss = only_mask_loss
|
||||
+ self.token_mel_ratio = token_mel_ratio
|
||||
+ self.pre_lookahead_len = pre_lookahead_len
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
@@ -206,7 +210,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
if finalize is False:
|
||||
- h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
|
||||
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user