Merge pull request #1232 from boji123/bj_dev_feat_len_pad

a better solution for mismatch of speech feat len and speech token len when trainning
This commit is contained in:
Xiang Lyu
2025-04-30 09:41:50 +08:00
committed by GitHub
6 changed files with 59 additions and 13 deletions

View File

@@ -74,6 +74,9 @@ class CosyVoice:
self.frontend.spk2info[zero_shot_spk_id] = model_input
return True
def save_spkinfo(self):
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id)
@@ -99,9 +102,9 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -174,10 +177,10 @@ class CosyVoice2(CosyVoice):
def inference_instruct(self, *args, **kwargs):
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

View File

@@ -178,8 +178,8 @@ class CosyVoiceFrontEnd:
model_input['text_len'] = tts_text_token_len
return model_input
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
@@ -196,8 +196,8 @@ class CosyVoiceFrontEnd:
model_input['prompt_text_len'] = instruct_text_token_len
return model_input
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input

View File

@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
def compute_fbank(data,
feat_extractor,
token_mel_ratio=2,
mode='train'):
""" Extract fbank
@@ -174,8 +175,14 @@ def compute_fbank(data,
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
# trim to align speech_token and speech_feat
token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0])
feat = feat[:token_mel_ratio * token_len]
sample["speech_token"] = sample["speech_token"][:token_len]
sample['speech_feat'] = feat
yield sample

View File

@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
mask = (~make_pad_mask(feat_len)).to(h)
# NOTE this is unnecessary, feat/h already same shape
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
h = self.encoder_proj(h)
# get conditions
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5: