fix bistream bug

Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-04 09:29:25 +08:00.
@@ -122,12 +122,12 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len

     def text_normalize(self, text, split=True, text_frontend=True):
-        # NOTE skip text_frontend when ssml symbol in text
-        if '<|' in text and '|>' in text:
-            text_frontend = False
         if isinstance(text, Generator):
             logging.info('get tts_text generator, will skip text_normalize!')
             return [text]
+        # NOTE skip text_frontend when ssml symbol in text
+        if '<|' in text and '|>' in text:
+            text_frontend = False
         if text_frontend is False or text == '':
             return [text] if split is True else text
         text = text.strip()
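Why the reorder matters (a minimal standalone sketch, not code from the repo): in the old order, '<|' in text ran before the Generator check. Python evaluates the in operator on a generator by iterating it, so a streaming tts_text generator was silently drained before inference ever saw it, which is the bistream bug this commit fixes.

    # hypothetical repro: a membership test consumes the generator
    def chunks():
        yield '收到好友从远方寄来的生日礼物,'
        yield '笑容如花儿般绽放。'

    g = chunks()
    print('<|' in g)   # False, but the test iterated g to the end
    print(list(g))     # [] ... nothing left for the tts pipeline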
@@ -413,18 +413,18 @@ class CosyVoice3Model(CosyVoice2Model):
                                                embedding=embedding.to(self.device),
                                                streaming=stream,
                                                finalize=finalize)
         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append mel cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel = self.hift_cache_dict[uuid]['mel']
             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
             self.hift_cache_dict[uuid]['mel'] = tts_mel
         else:
             self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
         if speed != 1.0:
             assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
             tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
         tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
         tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
         self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
         return tts_speech
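The hunk above shows the streaming bookkeeping in context: new mel frames from the flow decoder are appended to a per-uuid mel cache, the whole cache is re-vocoded, and speech_offset records how many output samples were already emitted, so only the unseen tail is returned. A minimal standalone sketch of that pattern, with hypothetical names and a fake vocoder:

    import torch

    cache = {'mel': None, 'speech_offset': 0}

    def vocode(mel):
        # stand-in for self.hift.inference: 256 output samples per mel frame
        return mel.mean(dim=1).repeat_interleave(256, dim=1)

    def stream_chunk(new_mel):
        cache['mel'] = new_mel if cache['mel'] is None else torch.concat([cache['mel'], new_mel], dim=2)
        speech = vocode(cache['mel'])
        out = speech[:, cache['speech_offset']:]   # only the samples not yet emitted
        cache['speech_offset'] += out.shape[1]
        return out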
@@ -155,11 +155,13 @@ class SineGen(torch.nn.Module):

     @torch.no_grad()
     def forward(self, f0):
-        """
-        :param f0: [B, 1, sample_len], Hz
-        :return: [B, 1, sample_len]
-        """
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+        f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        f0 = f0.transpose(1, 2)
         F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
         for i in range(self.harmonic_num + 1):
             F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
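The added transpose pairs with the changed return in the next hunk: SineGen now takes f0 as (batch, length, 1) and hands back (batch, length, dim), while the harmonic math stays in (batch, dim, length) internally. A quick standalone shape check, with assumed values of 24 kHz and 7 harmonics:

    import torch

    f0 = torch.zeros(1, 16000, 1)          # (B, length, 1), Hz
    x = f0.transpose(1, 2)                 # (B, 1, length) for internal math
    F_mat = torch.zeros(1, 7 + 1, x.size(-1))
    for i in range(7 + 1):
        F_mat[:, i: i + 1, :] = x * (i + 1) / 24000
    print(F_mat.transpose(1, 2).shape)     # (1, 16000, 8) = (B, length, dim)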
@@ -184,7 +186,7 @@ class SineGen(torch.nn.Module):
         # first: set the unvoiced part to 0 by uv
         # then: additive noise
         sine_waves = sine_waves * uv + noise
-        return sine_waves, uv, noise
+        return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise


 class SineGen2(torch.nn.Module):
@@ -221,7 +223,7 @@ class SineGen2(torch.nn.Module):
         if causal is True:
             self.rand_ini = torch.rand(1, 9)
             self.rand_ini[:, 0] = 0
-            self.sine_waves = torch.rand(1, 60 * 16000, 9)
+            self.sine_waves = torch.rand(1, 300 * 24000, 9)

     def _f02uv(self, f0):
         # generate uv signal
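This buffer is pre-generated random signal for causal (streaming) inference, and its length reads as max-duration-in-seconds times sampling rate. On that reading, which is an assumption about intent rather than something the commit states:

    old = 60 * 16000     #   960,000 samples: 60 s at 16 kHz
    new = 300 * 24000    # 7,200,000 samples: 300 s at 24 kHz

i.e. the cache now matches a 24 kHz output rate and covers streams up to five minutes instead of one. The SourceModuleHnNSF change below resizes its uv cache to the same 300 * 24000 length.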
@@ -351,7 +353,7 @@ class SourceModuleHnNSF(torch.nn.Module):
         self.l_tanh = torch.nn.Tanh()
         self.causal = causal
         if causal is True:
-            self.uv = torch.rand(1, 60 * 24000, 1)
+            self.uv = torch.rand(1, 300 * 24000, 1)

     def forward(self, x):
         """
@@ -17,6 +17,7 @@ import random
 import time
 import threading
 from typing import Dict, Optional, Callable, List, Generator
+import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -216,7 +217,7 @@ class TransformerLM(torch.nn.Module):
                                             att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                            device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
             if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
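Dropping .item() here, and again in Qwen2LM below, suggests sampling_ids now returns a plain Python int itself, which would also explain the new numpy import. A hedged sketch of such a sampling_ids, not the repo's implementation:

    import numpy as np
    import torch

    def sampling_ids(logp: torch.Tensor, out_tokens: list, sampling: int, ignore_eos: bool) -> int:
        # top-k sampling over log-probs; ignore_eos handling omitted in this sketch
        probs = logp.exp().cpu().numpy().flatten()
        top = probs.argsort()[::-1][:sampling]
        p = probs[top] / probs[top].sum()
        return int(np.random.choice(top, p=p))   # plain int, so no .item() at call sites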
@@ -544,7 +545,7 @@ class Qwen2LM(TransformerLM):
         cache = None
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
-        next_fill_index = -1
+        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
         for this_text in text:
             text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
             # prompt_speech_token_emb not empty, try append to lm_input
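The old sentinel next_fill_index = -1 only scheduled a fill token after one had already been sampled; the new formula schedules the first one up front, at the point where the prompt's partially filled speech block completes. A worked example, with mix_ratio assumed to be [5, 15] (the text:speech interleave ratio used elsewhere in CosyVoice2) and a hypothetical prompt length:

    mix_ratio = [5, 15]
    prompt_len = 38      # prompt_speech_token.shape[1]
    next_fill_index = (int(prompt_len / mix_ratio[1]) + 1) * mix_ratio[1] - prompt_len
    print(next_fill_index)   # 7: speech tokens still owed before the next text fill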
@@ -582,7 +583,7 @@ class Qwen2LM(TransformerLM):
                 top_ids = self.fill_token
                 next_fill_index += (self.mix_ratio[1] + 1)
             else:
-                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
             if top_ids == self.fill_token:
                 next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                 logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
example.py
@@ -15,15 +15,15 @@ def cosyvoice_example():
         torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

     cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
-    # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
+    # zero_shot usage
     for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
         torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-    # cross_lingual usage
+    # cross_lingual usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
     for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
                                                             './asset/cross_lingual_prompt.wav')):
         torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
     # vc usage
-    for i, j in enumerate(cosyvoice.inference_vc('./asset/zero_shot_prompt.wav', './asset/cross_lingual_prompt.wav')):
+    for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
         torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

     cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
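The swapped arguments in the vc call line up with an inference_vc(source_speech, prompt_speech) parameter order (an assumption about the signature, which the diff does not show): the cross-lingual clip is now the content being converted and the zero-shot clip supplies the target voice.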
@@ -65,7 +65,7 @@ def cosyvoice2_example():
         yield '让我心中充满了甜蜜的快乐,'
         yield '笑容如花儿般绽放。'
     for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
-        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+        torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)


 def cosyvoice3_example():
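Renaming the output to zero_shot_bistream_{}.wav separates this run, the bidirectional-streaming case where tts_text arrives as a generator, from the plain zero-shot outputs saved earlier in the file, which would otherwise be overwritten when both examples run.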
@@ -97,8 +97,8 @@ def cosyvoice3_example():


 def main():
-    cosyvoice_example()
-    cosyvoice2_example()
+    # cosyvoice_example()
+    # cosyvoice2_example()
     cosyvoice3_example()

@@ -31,7 +31,7 @@ def cosyvoice3_example():


 def main():
-    cosyvoice2_example()
+    # cosyvoice2_example()
     cosyvoice3_example()
