This commit is contained in:
lyuxiang.lx
2025-02-06 16:07:13 +08:00
parent 24f796a2b1
commit 2a3e033ee1
17 changed files with 187 additions and 135 deletions

View File

@@ -112,10 +112,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
prompt_feat_len,
embedding,
flow_cache):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
@@ -146,7 +142,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
flow_cache=flow_cache
cache=flow_cache
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
@@ -249,10 +245,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
embedding,
cache,
finalize):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)