mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add llm bistream
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
import json
|
||||
import onnxruntime
|
||||
import torch
|
||||
@@ -31,6 +32,7 @@ except ImportError:
|
||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||
from tn.english.normalizer import Normalizer as EnNormalizer
|
||||
use_ttsfrd = False
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
||||
|
||||
|
||||
@@ -71,10 +73,21 @@ class CosyVoiceFrontEnd:
|
||||
self.inflect_parser = inflect.engine()
|
||||
|
||||
def _extract_text_token(self, text):
    """Tokenize ``text`` for the LLM, supporting plain and streamed input.

    Args:
        text: Either a string to tokenize in one shot, or a generator that
            yields text pieces (streaming / "bistream" input).

    Returns:
        A ``(text_token, text_token_len)`` pair.

        * For a string: ``text_token`` is an int32 tensor of shape
          ``(1, num_tokens)`` on ``self.device`` and ``text_token_len`` is
          an int32 tensor holding ``num_tokens``.
        * For a generator: ``text_token`` is the generator returned by
          ``self._extract_text_token_generator(text)`` and
          ``text_token_len`` is a placeholder int32 tensor ``[0]``.
    """
    if isinstance(text, Generator):
        logging.info('get tts_text generator, will return _extract_text_token_generator!')
        # NOTE: the length is unknown until the generator is consumed, so a
        # dummy text_token_len is returned to keep the two-value return shape
        # that callers of the non-streaming path already unpack.
        dummy_len = torch.tensor([0], dtype=torch.int32).to(self.device)
        return self._extract_text_token_generator(text), dummy_len

    token_ids = self.tokenizer.encode(text, allowed_special=self.allowed_special)
    # Wrap in an outer list so the tensor carries a leading batch dim of 1.
    text_token = torch.tensor([token_ids], dtype=torch.int32).to(self.device)
    text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
    return text_token, text_token_len
|
||||
|
||||
def _extract_text_token_generator(self, text_generator):
    """Yield text tokens one at a time from a stream of text pieces.

    Each piece produced by ``text_generator`` is tokenized with
    ``self._extract_text_token`` and then emitted token by token, so a
    consumer can start working before the whole text is available.

    Args:
        text_generator: Iterable yielding text pieces to tokenize.

    Yields:
        Slices ``text_token[:, i:i + 1]`` of the tokenized piece, i.e. one
        token column at a time. A piece that tokenizes to zero tokens
        yields nothing.
    """
    for text in text_generator:
        # The per-piece length is not needed here; only the tokens are used.
        text_token, _ = self._extract_text_token(text)
        for i in range(text_token.shape[1]):
            # Slice (not index) so the batch dimension is preserved.
            yield text_token[:, i: i + 1]
|
||||
|
||||
def _extract_speech_token(self, speech):
|
||||
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
||||
@@ -106,6 +119,9 @@ class CosyVoiceFrontEnd:
|
||||
return speech_feat, speech_feat_len
|
||||
|
||||
def text_normalize(self, text, split=True, text_frontend=True):
|
||||
if isinstance(text, Generator):
|
||||
logging.info('get tts_text generator, will skip text_normalize!')
|
||||
return [text]
|
||||
if text_frontend is False:
|
||||
return [text] if split is True else text
|
||||
text = text.strip()
|
||||
|
||||
Reference in New Issue
Block a user