mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add llm bistream
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
import time
|
||||
from typing import Generator
|
||||
from tqdm import tqdm
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from modelscope import snapshot_download
|
||||
@@ -76,7 +77,7 @@ class CosyVoice:
|
||||
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
|
||||
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
if len(i) < 0.5 * len(prompt_text):
|
||||
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
||||
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
||||
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
||||
start_time = time.time()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
import json
|
||||
import onnxruntime
|
||||
import torch
|
||||
@@ -31,6 +32,7 @@ except ImportError:
|
||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||
from tn.english.normalizer import Normalizer as EnNormalizer
|
||||
use_ttsfrd = False
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
||||
|
||||
|
||||
@@ -71,10 +73,21 @@ class CosyVoiceFrontEnd:
|
||||
self.inflect_parser = inflect.engine()
|
||||
|
||||
def _extract_text_token(self, text):
|
||||
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
||||
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
||||
return text_token, text_token_len
|
||||
if isinstance(text, Generator):
|
||||
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
||||
# NOTE add a dummy text_token_len for compatibility
|
||||
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
||||
else:
|
||||
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
||||
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
||||
return text_token, text_token_len
|
||||
|
||||
def _extract_text_token_generator(self, text_generator):
|
||||
for text in text_generator:
|
||||
text_token, _ = self._extract_text_token(text)
|
||||
for i in range(text_token.shape[1]):
|
||||
yield text_token[:, i: i + 1]
|
||||
|
||||
def _extract_speech_token(self, speech):
|
||||
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
||||
@@ -106,6 +119,9 @@ class CosyVoiceFrontEnd:
|
||||
return speech_feat, speech_feat_len
|
||||
|
||||
def text_normalize(self, text, split=True, text_frontend=True):
|
||||
if isinstance(text, Generator):
|
||||
logging.info('get tts_text generator, will skip text_normalize!')
|
||||
return [text]
|
||||
if text_frontend is False:
|
||||
return [text] if split is True else text
|
||||
text = text.strip()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Generator
|
||||
import torch
|
||||
import numpy as np
|
||||
import threading
|
||||
@@ -99,14 +100,24 @@ class CosyVoiceModel:
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
with self.llm_context:
|
||||
for i in self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
if isinstance(text, Generator):
|
||||
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
|
||||
for i in self.llm.inference_bistream(text=text,
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
else:
|
||||
for i in self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
self.llm_end_dict[uuid] = True
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
||||
|
||||
Reference in New Issue
Block a user