优先使用ttsfrd,ttsfrd不存在时使用WeTextProcessing

This commit is contained in:
passerbya
2024-07-09 17:25:55 +08:00
parent 39afb98fa1
commit 95b8866f3c

View File

@@ -23,7 +23,12 @@ import os
import inflect import inflect
from tn.chinese.normalizer import Normalizer as ZhNormalizer from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer from tn.english.normalizer import Normalizer as EnNormalizer
# ttsfrd is an optional dependency. Record whether it is importable in the
# module-level flag `use_ttsfrd`; when it is missing, the front end falls back
# to the WeTextProcessing normalizers (ZhNormalizer / EnNormalizer).
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    # Catch only ImportError: a bare `except:` would also swallow
    # KeyboardInterrupt / SystemExit raised while the module is loading.
    print("failed to import ttsfrd, use WeTextProcessing instead")
    use_ttsfrd = False
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
@@ -50,8 +55,17 @@ class CosyVoiceFrontEnd:
self.instruct = instruct self.instruct = instruct
self.allowed_special = allowed_special self.allowed_special = allowed_special
self.inflect_parser = inflect.engine() self.inflect_parser = inflect.engine()
self.zh_tn_model = ZhNormalizer(remove_erhua=False,full_to_half=False) self.use_ttsfrd = use_ttsfrd
self.en_tn_model = EnNormalizer() if self.use_ttsfrd:
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyin')
self.frd.enable_pinyin_mix(True)
self.frd.set_breakmodel_index(1)
else:
self.zh_tn_model = ZhNormalizer(remove_erhua=False,full_to_half=False)
self.en_tn_model = EnNormalizer()
def _extract_text_token(self, text): def _extract_text_token(self, text):
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
@@ -86,7 +100,10 @@ class CosyVoiceFrontEnd:
def text_normalize(self, text, split=True): def text_normalize(self, text, split=True):
text = text.strip() text = text.strip()
if contains_chinese(text): if contains_chinese(text):
text = self.zh_tn_model.normalize(text) if self.use_ttsfrd:
text = self.frd.get_frd_extra_info(text, 'input')
else:
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "") text = text.replace("\n", "")
text = replace_blank(text) text = replace_blank(text)
text = replace_corner_mark(text) text = replace_corner_mark(text)