update export_jit

lyuxiang.lx
2025-12-23 15:23:29 +08:00
parent 59cb2bf16c
commit 7538c6a73d
2 changed files with 26 additions and 22 deletions

View File

@@ -24,9 +24,8 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import AutoModel
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
-from cosyvoice.utils.class_utils import get_model_type

 def get_args():
@@ -61,7 +60,7 @@ def main():
     model = AutoModel(model_dir=args.model_dir)
-    if get_model_type(model.model) == CosyVoiceModel:
+    if isinstance(model.model, CosyVoiceModel):
         # 1. export llm text_encoder
         llm_text_encoder = model.model.llm.text_encoder
         script = get_optimized_script(llm_text_encoder)
@@ -85,7 +84,7 @@ def main():
         script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
-    elif get_model_type(model.model) == CosyVoice2Model:
+    elif isinstance(model.model, CosyVoice2Model):
         # 1. export flow encoder
         flow_encoder = model.model.flow.encoder
         script = get_optimized_script(flow_encoder)
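The checks above now use the built-in isinstance rather than the removed get_model_type helper. For context, get_optimized_script is assumed here to wrap the standard TorchScript export path; a minimal sketch under that assumption (the actual helper may choose different freeze/optimize options):

    import torch

    def get_optimized_script(module: torch.nn.Module) -> torch.jit.ScriptModule:
        # Sketch only: compile the module to TorchScript, freeze its
        # weights/attributes, then apply inference-time graph fusions.
        script = torch.jit.script(module.eval())
        script = torch.jit.freeze(script)
        return torch.jit.optimize_for_inference(script)

Usage then mirrors the diff: get_optimized_script(model.model.flow.encoder.half()).save('{}/flow.encoder.fp16.zip'.format(args.model_dir)).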

View File

@@ -20,18 +20,9 @@ import numpy as np
 import whisper
 from typing import Callable
 import torchaudio.compliance.kaldi as kaldi
-import torchaudio
 import os
 import re
 import inflect
-try:
-    import ttsfrd
-    use_ttsfrd = True
-except ImportError:
-    print("failed to import ttsfrd, use wetext instead")
-    from wetext import Normalizer as ZhNormalizer
-    from wetext import Normalizer as EnNormalizer
-    use_ttsfrd = False
 from cosyvoice.utils.file_utils import logging, load_wav
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -60,17 +51,29 @@ class CosyVoiceFrontEnd:
         else:
             self.spk2info = {}
         self.allowed_special = allowed_special
-        self.use_ttsfrd = use_ttsfrd
-        if self.use_ttsfrd:
-            self.frd = ttsfrd.TtsFrontendEngine()
-            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
-                'failed to initialize ttsfrd resource'
-            self.frd.set_lang_type('pinyinvg')
-        else:
-            self.zh_tn_model = ZhNormalizer(remove_erhua=False)
-            self.en_tn_model = EnNormalizer()
-        self.inflect_parser = inflect.engine()
+        self.inflect_parser = inflect.engine()
+        # NOTE stay compatible when no text frontend tool is available
+        try:
+            import ttsfrd
+            self.frd = ttsfrd.TtsFrontendEngine()
+            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                'failed to initialize ttsfrd resource'
+            self.frd.set_lang_type('pinyinvg')
+            self.text_frontend = 'ttsfrd'
+            logging.info('use ttsfrd frontend')
+        except Exception:
+            try:
+                from wetext import Normalizer as ZhNormalizer
+                from wetext import Normalizer as EnNormalizer
+                self.zh_tn_model = ZhNormalizer(remove_erhua=False)
+                self.en_tn_model = EnNormalizer()
+                self.text_frontend = 'wetext'
+                logging.info('use wetext frontend')
+            except Exception:
+                self.text_frontend = ''
+                logging.info('no text frontend is available')

     def _extract_text_token(self, text):
         if isinstance(text, Generator):
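The constructor now probes for a text frontend at runtime instead of relying on the module-level use_ttsfrd flag: try ttsfrd first, fall back to wetext, and otherwise record an empty self.text_frontend. Stripped of the class, the pattern is (pick_text_frontend is a hypothetical name, not in the repo):

    import logging

    def pick_text_frontend() -> str:
        # Prefer ttsfrd; fall back to wetext; else run without a frontend.
        try:
            import ttsfrd  # noqa: F401
            return 'ttsfrd'
        except Exception:
            try:
                from wetext import Normalizer  # noqa: F401
                return 'wetext'
            except Exception:
                logging.info('no text frontend is available')
                return ''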
@@ -131,12 +134,13 @@ class CosyVoiceFrontEnd:
         if text_frontend is False or text == '':
             return [text] if split is True else text
         text = text.strip()
-        if self.use_ttsfrd:
+        if self.text_frontend == 'ttsfrd':
             texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
             text = ''.join(texts)
         else:
             if contains_chinese(text):
-                text = self.zh_tn_model.normalize(text)
+                if self.text_frontend == 'wetext':
+                    text = self.zh_tn_model.normalize(text)
                 text = text.replace("\n", "")
                 text = replace_blank(text)
                 text = replace_corner_mark(text)
@@ -147,7 +151,8 @@ class CosyVoiceFrontEnd:
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
             else:
-                text = self.en_tn_model.normalize(text)
+                if self.text_frontend == 'wetext':
+                    text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
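With the new guards, zh/en normalization runs only when the wetext frontend was actually imported; under ttsfrd the text is normalized by do_voicegen_frd instead, and with no frontend the raw text goes straight to splitting. A quick wetext sketch matching the constructor above (the exact normalized output is approximate):

    from wetext import Normalizer as ZhNormalizer

    zh_tn_model = ZhNormalizer(remove_erhua=False)
    # Text normalization spells out digits and symbols for TTS,
    # e.g. '共找到137条结果' becomes roughly '共找到一百三十七条结果'.
    print(zh_tn_model.normalize('共找到137条结果'))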