mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Merge pull request #1232 from boji123/bj_dev_feat_len_pad
a better solution for mismatch of speech feat len and speech token len when trainning
This commit is contained in:
@@ -74,6 +74,9 @@ class CosyVoice:
|
||||
self.frontend.spk2info[zero_shot_spk_id] = model_input
|
||||
return True
|
||||
|
||||
def save_spkinfo(self):
|
||||
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
|
||||
|
||||
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_sft(i, spk_id)
|
||||
@@ -99,9 +102,9 @@ class CosyVoice:
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
||||
def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
@@ -174,10 +177,10 @@ class CosyVoice2(CosyVoice):
|
||||
def inference_instruct(self, *args, **kwargs):
|
||||
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
|
||||
|
||||
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
||||
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
||||
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
|
||||
@@ -178,8 +178,8 @@ class CosyVoiceFrontEnd:
|
||||
model_input['text_len'] = tts_text_token_len
|
||||
return model_input
|
||||
|
||||
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
||||
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
||||
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
|
||||
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
|
||||
# in cross lingual mode, we remove prompt in llm
|
||||
del model_input['prompt_text']
|
||||
del model_input['prompt_text_len']
|
||||
@@ -196,8 +196,8 @@ class CosyVoiceFrontEnd:
|
||||
model_input['prompt_text_len'] = instruct_text_token_len
|
||||
return model_input
|
||||
|
||||
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
|
||||
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
|
||||
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
|
||||
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
|
||||
del model_input['llm_prompt_speech_token']
|
||||
del model_input['llm_prompt_speech_token_len']
|
||||
return model_input
|
||||
|
||||
@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
|
||||
|
||||
def compute_fbank(data,
|
||||
feat_extractor,
|
||||
token_mel_ratio=2,
|
||||
mode='train'):
|
||||
""" Extract fbank
|
||||
|
||||
@@ -174,8 +175,14 @@ def compute_fbank(data,
|
||||
assert 'utt' in sample
|
||||
assert 'text_token' in sample
|
||||
waveform = sample['speech']
|
||||
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||
sample['speech_feat'] = mat
|
||||
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||
|
||||
# trim to align speech_token and speech_feat
|
||||
token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0])
|
||||
feat = feat[:token_mel_ratio * token_len]
|
||||
sample["speech_token"] = sample["speech_token"][:token_len]
|
||||
|
||||
sample['speech_feat'] = feat
|
||||
yield sample
|
||||
|
||||
|
||||
|
||||
@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
# NOTE this is unnecessary, feat/h already same shape
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
|
||||
Reference in New Issue
Block a user