fix bistream bug

lyuxiang.lx
2025-12-12 10:41:25 +00:00
parent b02d7e61f7
commit ca3b054a52
6 changed files with 37 additions and 34 deletions

View File

@@ -122,12 +122,12 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len

     def text_normalize(self, text, split=True, text_frontend=True):
-        # NOTE skip text_frontend when ssml symbol in text
-        if '<|' in text and '|>' in text:
-            text_frontend = False
         if isinstance(text, Generator):
             logging.info('get tts_text generator, will skip text_normalize!')
             return [text]
+        # NOTE skip text_frontend when ssml symbol in text
+        if '<|' in text and '|>' in text:
+            text_frontend = False
         if text_frontend is False or text == '':
             return [text] if split is True else text
         text = text.strip()
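
The reorder above is what fixes bistream input: a membership test on a generator silently consumes it, so running the ssml check before the Generator early-return would drain the incoming text stream. A minimal sketch of that behaviour (hypothetical strings, not part of the commit):

def tts_text():
    # hypothetical streamed text chunks
    yield '你好,'
    yield '欢迎使用语音合成。'

text = tts_text()
# running the ssml check before the Generator early-return drains the stream:
print('<|' in text)   # False, but both chunks have now been consumed
print(list(text))     # [] -> nothing left for downstream synthesis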

View File

@@ -413,18 +413,18 @@ class CosyVoice3Model(CosyVoice2Model):
                                          embedding=embedding.to(self.device),
                                          streaming=stream,
                                          finalize=finalize)
         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append mel cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel = self.hift_cache_dict[uuid]['mel']
             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
             self.hift_cache_dict[uuid]['mel'] = tts_mel
         else:
             self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
         if speed != 1.0:
             assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
             tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
         tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
         tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
         self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
         return tts_speech
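
In the chunked path above, the cached mel keeps growing and is re-vocoded on every call, while 'speech_offset' trims away samples that earlier chunks already emitted. A rough sketch of that bookkeeping (hypothetical frame counts and upsample ratio, not taken from the code):

cache = {'mel_frames': 0, 'speech_offset': 0}
upsample_ratio = 480                          # assumed mel-frame -> waveform-sample ratio
for new_frames in (50, 50, 50):               # three streaming chunks
    cache['mel_frames'] += new_frames         # mel cache keeps growing, like hift_cache_dict[uuid]['mel']
    full_len = cache['mel_frames'] * upsample_ratio
    chunk_len = full_len - cache['speech_offset']   # drop samples already emitted
    cache['speech_offset'] += chunk_len
    print(chunk_len)                          # 24000 each time: only the new audio is yielded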

View File

@@ -155,11 +155,13 @@ class SineGen(torch.nn.Module):
     @torch.no_grad()
     def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, dim=1, length)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
         """
-        :param f0: [B, 1, sample_len], Hz
-        :return: [B, 1, sample_len]
-        """
+        f0 = f0.transpose(1, 2)

         F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
         for i in range(self.harmonic_num + 1):
             F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
@@ -184,7 +186,7 @@ class SineGen(torch.nn.Module):
         # first: set the unvoiced part to 0 by uv
         # then: additive noise
         sine_waves = sine_waves * uv + noise
-        return sine_waves, uv, noise
+        return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise


 class SineGen2(torch.nn.Module):
@@ -221,7 +223,7 @@ class SineGen2(torch.nn.Module):
         if causal is True:
             self.rand_ini = torch.rand(1, 9)
             self.rand_ini[:, 0] = 0
-            self.sine_waves = torch.rand(1, 60 * 16000, 9)
+            self.sine_waves = torch.rand(1, 300 * 24000, 9)

     def _f02uv(self, f0):
         # generate uv signal
@@ -351,7 +353,7 @@ class SourceModuleHnNSF(torch.nn.Module):
         self.l_tanh = torch.nn.Tanh()
         self.causal = causal
         if causal is True:
-            self.uv = torch.rand(1, 60 * 24000, 1)
+            self.uv = torch.rand(1, 300 * 24000, 1)

     def forward(self, x):
         """

View File

@@ -17,6 +17,7 @@ import random
 import time
 import threading
 from typing import Dict, Optional, Callable, List, Generator
+import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -216,7 +217,7 @@ class TransformerLM(torch.nn.Module):
                                                    att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                   device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
             if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
@@ -544,7 +545,7 @@ class Qwen2LM(TransformerLM):
         cache = None
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
-        next_fill_index = -1
+        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
         for this_text in text:
             text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
             # prompt_speech_token_emb not empty, try append to lm_input
@@ -582,7 +583,7 @@ class Qwen2LM(TransformerLM):
                     top_ids = self.fill_token
                     next_fill_index += (self.mix_ratio[1] + 1)
                 else:
-                    top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                    top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
                 if top_ids == self.fill_token:
                     next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                     logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
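
With this change the first fill_token is scheduled so that prompt and generated speech tokens together land on the next multiple of mix_ratio[1], instead of starting from the sentinel -1. A worked example of the new initialization (assuming the default mix_ratio = [5, 15] and a hypothetical prompt length):

mix_ratio = [5, 15]                 # assumed text:speech token ratio
prompt_speech_token_len = 38        # hypothetical prompt length
next_fill_index = (int(prompt_speech_token_len / mix_ratio[1]) + 1) * mix_ratio[1] - prompt_speech_token_len
print(next_fill_index)              # 7 -> 38 prompt tokens + 7 generated tokens = 45 = 3 * 15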

View File

@@ -15,15 +15,15 @@ def cosyvoice_example():
         torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

     cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
-    # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
+    # zero_shot usage
     for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
         torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-    # cross_lingual usage
+    # cross_lingual usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
     for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
                                                             './asset/cross_lingual_prompt.wav')):
         torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
     # vc usage
-    for i, j in enumerate(cosyvoice.inference_vc('./asset/zero_shot_prompt.wav', './asset/cross_lingual_prompt.wav')):
+    for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
         torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

     cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
@@ -65,7 +65,7 @@ def cosyvoice2_example():
         yield '让我心中充满了甜蜜的快乐,'
         yield '笑容如花儿般绽放。'
     for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
-        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+        torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)


 def cosyvoice3_example():
@@ -97,8 +97,8 @@ def cosyvoice3_example():
 def main():
-    cosyvoice_example()
-    cosyvoice2_example()
+    # cosyvoice_example()
+    # cosyvoice2_example()
     cosyvoice3_example()

View File

@@ -31,7 +31,7 @@ def cosyvoice3_example():
 def main():
-    cosyvoice2_example()
+    # cosyvoice2_example()
     cosyvoice3_example()