add llm bistream

lyuxiang.lx
2025-01-23 10:12:06 +08:00
parent 0b75c3a03f
commit 07e477519b
10 changed files with 163 additions and 39 deletions


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+from typing import Generator
 import json
 import onnxruntime
 import torch
@@ -31,6 +32,7 @@ except ImportError:
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -71,10 +73,21 @@ class CosyVoiceFrontEnd:
         self.inflect_parser = inflect.engine()
 
     def _extract_text_token(self, text):
-        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
-        return text_token, text_token_len
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
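
The hunk above lets tts_text be a Python generator: _extract_text_token detects it, returns the token stream plus a dummy length tensor, and _extract_text_token_generator re-tokenizes every yielded chunk and emits tokens one at a time. A minimal usage sketch, not part of the commit; here `frontend` stands for a hypothetical, already-constructed CosyVoiceFrontEnd instance and the text chunks are made up:

    # Sketch only: feed a generator of text chunks instead of a plain string.
    def stream_text():
        yield 'Hello, '
        yield 'this is a streamed request.'

    token_iter, dummy_len = frontend._extract_text_token(stream_text())
    # dummy_len is the placeholder tensor([0]); the true token count is unknown up front.
    for token in token_iter:
        print(token.shape)  # each item is a [1, 1] int32 tensor, i.e. one text token at a time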
@@ -106,6 +119,9 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
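
With the corresponding change to text_normalize, a generator input skips normalization and splitting entirely and is handed back wrapped in a one-element list, so callers can still iterate over "segments" as in the non-streaming path. A minimal sketch, again assuming a hypothetical `frontend` instance:

    # Sketch only: text_normalize passes a streaming input through untouched.
    text_stream = (chunk for chunk in ['Hello, ', 'this is a streamed sentence.'])
    segments = frontend.text_normalize(text_stream, split=True, text_frontend=True)
    assert len(segments) == 1 and segments[0] is text_stream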