export onnx

This commit is contained in:
禾息
2024-09-03 11:06:24 +08:00
parent 18599be8d5
commit fadb22086f
5 changed files with 318 additions and 164 deletions

View File

@@ -19,6 +19,13 @@ import time
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
import numpy as np
import onnxruntime as ort
# try:
# import tensorrt as trt
# except ImportError:
# ...
class CosyVoiceModel:
@@ -66,21 +73,40 @@ class CosyVoiceModel:
llm_llm = torch.jit.load(llm_llm_model)
self.llm.llm = llm_llm
def load_trt(self, model_dir, use_fp16):
import tensorrt as trt
trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
trt_file_path = os.path.join(model_dir, trt_file_name)
if not os.path.isfile(trt_file_path):
raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
# def load_trt(self, model_dir, use_fp16):
# trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
# trt_file_path = os.path.join(model_dir, trt_file_name)
# if not os.path.isfile(trt_file_path):
# raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
trt.init_libnvinfer_plugins(None, "")
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(trt_file_path, 'rb') as f:
serialized_engine = f.read()
engine = runtime.deserialize_cuda_engine(serialized_engine)
self.flow.decoder.estimator_context = engine.create_execution_context()
# trt.init_libnvinfer_plugins(None, "")
# logger = trt.Logger(trt.Logger.WARNING)
# runtime = trt.Runtime(logger)
# with open(trt_file_path, 'rb') as f:
# serialized_engine = f.read()
# engine = runtime.deserialize_cuda_engine(serialized_engine)
# self.flow.decoder.estimator_context = engine.create_execution_context()
# self.flow.decoder.estimator = None
def load_onnx(self, model_dir, use_fp16):
onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx'
onnx_file_path = os.path.join(model_dir, onnx_file_name)
if not os.path.isfile(onnx_file_path):
raise f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file"
providers = ['CUDAExecutionProvider']
sess_options = ort.SessionOptions()
# Add TensorRT Execution Provider
providers = [
'CUDAExecutionProvider'
]
# Load the ONNX model
self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers)
# self.flow.decoder.estimator_context = None
self.flow.decoder.estimator = None
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
with self.llm_context: