From 0fd15bb12b9f79bbdb86a496b920311130f1710c Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Wed, 10 Jul 2024 17:49:32 +0800
Subject: [PATCH] use spk_embedding when sft

---
 cosyvoice/flow/flow.py                                      | 2 +-
 cosyvoice/llm/llm.py                                        | 2 +-
 cosyvoice/utils/executor.py                                 | 4 ++++
 examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml | 1 +
 examples/libritts/cosyvoice/conf/cosyvoice.yaml             | 1 +
 5 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index d0dbcd0..009160a 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -60,7 +60,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 05c22ef..3b418c5 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -97,7 +97,7 @@ class TransformerLM(torch.nn.Module):
         text_token_len = batch['text_token_len'].to(device)
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
         lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
diff --git a/cosyvoice/utils/executor.py b/cosyvoice/utils/executor.py
index c12e52d..f7dfb0e 100644
--- a/cosyvoice/utils/executor.py
+++ b/cosyvoice/utils/executor.py
@@ -52,6 +52,10 @@ class Executor:
                 info_dict["batch_idx"] = batch_idx
                 if cosyvoice_join(group_join, info_dict):
                     break
+                if info_dict["use_spk_embedding"] is True:
+                    batch_dict["embedding"] = batch_dict["spk_embedding"]
+                else:
+                    batch_dict["embedding"] = batch_dict["utt_embedding"]
 
                 # Disable gradient synchronizations across DDP processes.
                 # Within this context, gradients will be accumulated on module
diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml
index 10206e6..b67b528 100644
--- a/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml
+++ b/examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml
@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr
     scheduler_conf:
         warmup_steps: 25000
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2
diff --git a/examples/libritts/cosyvoice/conf/cosyvoice.yaml b/examples/libritts/cosyvoice/conf/cosyvoice.yaml
index c791c76..588086c 100644
--- a/examples/libritts/cosyvoice/conf/cosyvoice.yaml
+++ b/examples/libritts/cosyvoice/conf/cosyvoice.yaml
@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2
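
Illustration only, not part of the patch: a minimal standalone sketch of what the executor change does, namely copying either the speaker-level or the utterance-level x-vector into the shared 'embedding' key that flow.py and llm.py now read. The helper name select_embedding and the dummy 192-dim tensors are assumptions for demonstration, not part of the CosyVoice API.

import torch

def select_embedding(batch_dict: dict, use_spk_embedding: bool) -> dict:
    # Route the configured x-vector into the common 'embedding' key,
    # mirroring the lines added to cosyvoice/utils/executor.py.
    if use_spk_embedding:
        # sft: use the speaker-level embedding
        batch_dict["embedding"] = batch_dict["spk_embedding"]
    else:
        # pre-training: use the per-utterance embedding
        batch_dict["embedding"] = batch_dict["utt_embedding"]
    return batch_dict

# Dummy batch with illustrative 192-dim x-vectors for 2 utterances
batch = {
    "utt_embedding": torch.randn(2, 192),
    "spk_embedding": torch.randn(2, 192),
}
batch = select_embedding(batch, use_spk_embedding=True)  # as during sft
print(batch["embedding"].shape)  # torch.Size([2, 192])

With use_spk_embedding set to True in train_conf (as the yaml comments suggest for sft), every utterance from the same speaker is conditioned on the same embedding; with False, each utterance keeps its own.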