mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 09:59:23 +08:00
export onnx
This commit is contained in:
@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
|
||||
|
||||
class CosyVoice:
|
||||
|
||||
def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False):
|
||||
def __init__(self, model_dir, load_jit=True, load_trt=False, load_onnx=True, use_fp16=False):
|
||||
instruct = True if '-Instruct' in model_dir else False
|
||||
self.model_dir = model_dir
|
||||
if not os.path.exists(model_dir):
|
||||
@@ -39,13 +39,16 @@ class CosyVoice:
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
|
||||
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
||||
'{}/llm.llm.fp16.zip'.format(model_dir))
|
||||
|
||||
if load_trt:
|
||||
self.model.load_trt(model_dir, use_fp16)
|
||||
# if load_trt:
|
||||
# self.model.load_trt(model_dir, use_fp16)
|
||||
|
||||
if load_onnx:
|
||||
self.model.load_onnx(model_dir, use_fp16)
|
||||
|
||||
del configs
|
||||
|
||||
|
||||
@@ -19,6 +19,13 @@ import time
|
||||
from contextlib import nullcontext
|
||||
import uuid
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
# try:
|
||||
# import tensorrt as trt
|
||||
# except ImportError:
|
||||
# ...
|
||||
|
||||
class CosyVoiceModel:
|
||||
|
||||
@@ -66,21 +73,40 @@ class CosyVoiceModel:
|
||||
llm_llm = torch.jit.load(llm_llm_model)
|
||||
self.llm.llm = llm_llm
|
||||
|
||||
def load_trt(self, model_dir, use_fp16):
|
||||
import tensorrt as trt
|
||||
trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
|
||||
trt_file_path = os.path.join(model_dir, trt_file_name)
|
||||
if not os.path.isfile(trt_file_path):
|
||||
raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
|
||||
# def load_trt(self, model_dir, use_fp16):
|
||||
# trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
|
||||
# trt_file_path = os.path.join(model_dir, trt_file_name)
|
||||
# if not os.path.isfile(trt_file_path):
|
||||
# raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
|
||||
|
||||
trt.init_libnvinfer_plugins(None, "")
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
runtime = trt.Runtime(logger)
|
||||
with open(trt_file_path, 'rb') as f:
|
||||
serialized_engine = f.read()
|
||||
engine = runtime.deserialize_cuda_engine(serialized_engine)
|
||||
self.flow.decoder.estimator_context = engine.create_execution_context()
|
||||
# trt.init_libnvinfer_plugins(None, "")
|
||||
# logger = trt.Logger(trt.Logger.WARNING)
|
||||
# runtime = trt.Runtime(logger)
|
||||
# with open(trt_file_path, 'rb') as f:
|
||||
# serialized_engine = f.read()
|
||||
# engine = runtime.deserialize_cuda_engine(serialized_engine)
|
||||
# self.flow.decoder.estimator_context = engine.create_execution_context()
|
||||
# self.flow.decoder.estimator = None
|
||||
|
||||
def load_onnx(self, model_dir, use_fp16):
|
||||
onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx'
|
||||
onnx_file_path = os.path.join(model_dir, onnx_file_name)
|
||||
if not os.path.isfile(onnx_file_path):
|
||||
raise f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file"
|
||||
|
||||
providers = ['CUDAExecutionProvider']
|
||||
sess_options = ort.SessionOptions()
|
||||
|
||||
# Add TensorRT Execution Provider
|
||||
providers = [
|
||||
'CUDAExecutionProvider'
|
||||
]
|
||||
|
||||
# Load the ONNX model
|
||||
self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers)
|
||||
# self.flow.decoder.estimator_context = None
|
||||
self.flow.decoder.estimator = None
|
||||
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
with self.llm_context:
|
||||
|
||||
Reference in New Issue
Block a user