mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix vocoder train
This commit is contained in:
@@ -91,7 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
# NOTE: this line should not be needed; h should already have gone through the length_regulator and have the same shape as feat
|
||||
# NOTE: this interpolation should be unnecessary, since feat and h should already have the same shape
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
@@ -117,7 +117,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
# concat speech token and prompt speech token
|
||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
|
||||
Reference in New Issue
Block a user