mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
use automodel
This commit is contained in:
@@ -23,8 +23,10 @@ import torch
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.class_utils import get_model_type
|
||||
|
||||
|
||||
def get_args():
|
||||
@@ -57,15 +59,17 @@ def main():
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
|
||||
try:
|
||||
model = CosyVoice(args.model_dir)
|
||||
except Exception:
|
||||
try:
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
model = AutoModel(model_dir=args.model_dir)
|
||||
|
||||
if not isinstance(model, CosyVoice2):
|
||||
if get_model_type(model.model) == CosyVoiceModel:
|
||||
# 1. export flow encoder
|
||||
flow_encoder = model.model.flow.encoder
|
||||
script = get_optimized_script(flow_encoder)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(flow_encoder.half())
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export flow_encoder')
|
||||
elif get_model_type(model.model) == CosyVoice2Model:
|
||||
# 1. export llm text_encoder
|
||||
llm_text_encoder = model.model.llm.text_encoder
|
||||
script = get_optimized_script(llm_text_encoder)
|
||||
@@ -90,13 +94,7 @@ def main():
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export flow_encoder')
|
||||
else:
|
||||
# 3. export flow encoder
|
||||
flow_encoder = model.model.flow.encoder
|
||||
script = get_optimized_script(flow_encoder)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(flow_encoder.half())
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export flow_encoder')
|
||||
raise ValueError('unsupported model type')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -27,7 +27,7 @@ from tqdm import tqdm
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
|
||||
|
||||
@@ -58,16 +58,7 @@ def main():
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
try:
|
||||
model = CosyVoice(args.model_dir)
|
||||
except Exception:
|
||||
try:
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except Exception:
|
||||
try:
|
||||
model = CosyVoice3(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
model = AutoModel(model_dir=args.model_dir)
|
||||
|
||||
# 1. export flow decoder estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
|
||||
Reference in New Issue
Block a user