mirror of https://github.com/FunAudioLLM/CosyVoice.git
use spk_embedding when sft
@@ -60,7 +60,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
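For orientation, here is a minimal standalone sketch of what the flow model does with the selected embedding in the context lines above: it is L2-normalized and then projected into the model's channel space. The class, the spk_embed_affine_layer name, and the 192/512 dimensions are illustrative assumptions, not the actual MaskedDiffWithXvec implementation.

import torch
import torch.nn.functional as F


class XvecConditioner(torch.nn.Module):
    """Illustrative stand-in for the x-vector conditioning path shown in the diff."""

    def __init__(self, spk_embed_dim: int = 192, output_size: int = 512):
        super().__init__()
        # assumed linear projection from x-vector space to model channels
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)

    def forward(self, batch: dict, device: torch.device) -> torch.Tensor:
        # after this commit the model only reads batch['embedding']; the caller
        # decides whether it holds utt_embedding or spk_embedding
        embedding = batch['embedding'].to(device)
        embedding = F.normalize(embedding, dim=1)  # unit-norm, as in the context lines
        return self.spk_embed_affine_layer(embedding)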
@@ -97,7 +97,7 @@ class TransformerLM(torch.nn.Module):
         text_token_len = batch['text_token_len'].to(device)
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
-        embedding = batch['utt_embedding'].to(device)
+        embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
         lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
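The lm_target line kept as context above is easier to read with a toy example: the first 2 + text_token_len positions are filled with IGNORE_ID so no loss is computed on them, the speech tokens become the prediction targets, and a terminator id (assumed here to equal speech_token_size) is appended. The concrete values below are made up for illustration.

import torch

IGNORE_ID = -1
speech_token_size = 4096  # assumed codebook size; only the terminator id matters here

text_token_len = torch.tensor([3])              # one utterance with 3 text tokens
speech_token = torch.tensor([[11, 12, 13, 0]])  # padded speech tokens
speech_token_len = torch.tensor([3])

lm_target = [
    torch.tensor(
        [IGNORE_ID] * (2 + text_token_len[i])
        + speech_token[i, :speech_token_len[i]].tolist()
        + [speech_token_size]
    )
    for i in range(speech_token.size(0))
]
print(lm_target[0])  # tensor([  -1,   -1,   -1,   -1,   -1,   11,   12,   13, 4096])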
@@ -52,6 +52,10 @@ class Executor:
             info_dict["batch_idx"] = batch_idx
             if cosyvoice_join(group_join, info_dict):
                 break
+            if info_dict["use_spk_embedding"] is True:
+                batch_dict["embedding"] = batch_dict["spk_embedding"]
+            else:
+                batch_dict["embedding"] = batch_dict["utt_embedding"]
 
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
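The four added lines are the core of the commit: at each training step the executor decides which x-vector the models see under the shared 'embedding' key. Below is a self-contained sketch of that selection with made-up tensor shapes; the real batch_dict comes from the data pipeline.

import torch


def select_embedding(batch_dict: dict, info_dict: dict) -> dict:
    # speaker-level x-vector during sft, utterance-level x-vector during pretraining
    if info_dict["use_spk_embedding"] is True:
        batch_dict["embedding"] = batch_dict["spk_embedding"]
    else:
        batch_dict["embedding"] = batch_dict["utt_embedding"]
    return batch_dict


batch = {
    "utt_embedding": torch.randn(2, 192),
    "spk_embedding": torch.randn(2, 192),
}
batch = select_embedding(batch, {"use_spk_embedding": True})
assert torch.equal(batch["embedding"], batch["spk_embedding"])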
@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr
     scheduler_conf:
         warmup_steps: 25000
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2
@@ -190,6 +190,7 @@ train_conf:
     scheduler: warmuplr # change to constantlr during sft
     scheduler_conf:
         warmup_steps: 2500
+    use_spk_embedding: False # change to True during sft
     max_epoch: 200
     grad_clip: 5
     accum_grad: 2
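Both config hunks add the same train_conf key, kept at False for pretraining and flipped to True for sft. A rough sketch of reading the flag back out of the YAML follows; CosyVoice's actual entry point loads its configs with hyperpyyaml, and the copy into info_dict is assumed here, so treat the parsing below as illustrative only.

import yaml

conf = yaml.safe_load("""
train_conf:
    scheduler: warmuplr
    scheduler_conf:
        warmup_steps: 25000
    use_spk_embedding: False   # change to True during sft
    max_epoch: 200
""")

# assumed plumbing: train_conf is copied into the info_dict the Executor receives
info_dict = dict(conf["train_conf"])
print(info_dict["use_spk_embedding"])   # False for pretraining, True for sft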