use spk_embedding when sft

This commit is contained in:
lyuxiang.lx
2024-07-10 17:49:32 +08:00
parent a723ea375e
commit 0fd15bb12b
5 changed files with 8 additions and 2 deletions

View File

@@ -60,7 +60,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['utt_embedding'].to(device)
embedding = batch['embedding'].to(device)
# xvec projection
embedding = F.normalize(embedding, dim=1)