use automodel

This commit is contained in:
lyuxiang.lx
2025-12-09 15:15:05 +00:00
parent 56d9876037
commit 0c65d3c7ab
8 changed files with 56 additions and 88 deletions

View File

@@ -23,8 +23,10 @@ import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
def get_args():
@@ -57,15 +59,17 @@ def main():
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
if not isinstance(model, CosyVoice2):
if get_model_type(model.model) == CosyVoiceModel:
# 1. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
elif get_model_type(model.model) == CosyVoice2Model:
# 1. export llm text_encoder
llm_text_encoder = model.model.llm.text_encoder
script = get_optimized_script(llm_text_encoder)
@@ -90,13 +94,7 @@ def main():
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
else:
# 3. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
raise ValueError('unsupported model type')
if __name__ == '__main__':