add onnx export

This commit is contained in:
lyuxiang.lx
2024-09-04 18:15:33 +08:00
parent d8197de4cc
commit 2ce724045b
6 changed files with 105 additions and 280 deletions

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import numpy as np
import threading
@@ -20,7 +19,6 @@ from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
import numpy as np
import onnxruntime as ort
class CosyVoiceModel:
@@ -62,47 +60,22 @@ class CosyVoiceModel:
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
self.hift.to(self.device).eval()
def load_jit(self, llm_text_encoder_model, llm_llm_model):
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
    """Replace three sub-modules with TorchScript archives loaded from disk.

    Args:
        llm_text_encoder_model: TorchScript archive assigned to
            ``self.llm.text_encoder``.
        llm_llm_model: TorchScript archive assigned to ``self.llm.llm``.
        flow_encoder_model: TorchScript archive assigned to
            ``self.flow.encoder``.
    """
    # Load straight onto self.device, the same way the hift checkpoint is
    # loaded in this class. Without map_location an archive saved from a GPU
    # cannot be restored on a host that lacks that device.
    llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
    self.llm.text_encoder = llm_text_encoder
    llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
    self.llm.llm = llm_llm
    flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
    self.flow.encoder = flow_encoder
# def load_trt(self, model_dir, use_fp16):
# import tensorrt as trt
# trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
# trt_file_path = os.path.join(model_dir, trt_file_name)
# if not os.path.isfile(trt_file_path):
# raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
# trt.init_libnvinfer_plugins(None, "")
# logger = trt.Logger(trt.Logger.WARNING)
# runtime = trt.Runtime(logger)
# with open(trt_file_path, 'rb') as f:
# serialized_engine = f.read()
# engine = runtime.deserialize_cuda_engine(serialized_engine)
# self.flow.decoder.estimator_context = engine.create_execution_context()
# self.flow.decoder.estimator = None
def load_onnx(self, model_dir, use_fp16):
    """Create an ONNX Runtime session for the flow decoder estimator.

    The file ``estimator_fp16.onnx`` (when ``use_fp16`` is true) or
    ``estimator_fp32.onnx`` is looked up inside ``model_dir``. The session is
    stored on ``self.flow.decoder.session`` and
    ``self.flow.decoder.estimator`` is set to ``None``.

    Args:
        model_dir: directory containing the exported ONNX estimator.
        use_fp16: selects the fp16 file name instead of the fp32 one.

    Raises:
        FileNotFoundError: if the selected ONNX file does not exist.
    """
    onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx'
    onnx_file_path = os.path.join(model_dir, onnx_file_name)
    if not os.path.isfile(onnx_file_path):
        # Raising a plain string is itself a TypeError in Python 3, which
        # would hide this message; raise a real exception instead.
        raise FileNotFoundError(
            f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file"
        )
    sess_options = ort.SessionOptions()
    # Only the CUDA execution provider is requested.
    providers = ['CUDAExecutionProvider']
    # Load the ONNX model
    self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers)
    self.flow.decoder.estimator = None
def load_onnx(self, flow_decoder_estimator_model):
    """Swap the flow decoder estimator for an ONNX Runtime inference session.

    Args:
        flow_decoder_estimator_model: ONNX model handed to
            ``onnxruntime.InferenceSession``.
    """
    import onnxruntime

    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.intra_op_num_threads = 1

    # Use the CUDA provider when torch reports a usable GPU, otherwise the CPU one.
    if torch.cuda.is_available():
        execution_provider = 'CUDAExecutionProvider'
    else:
        execution_provider = 'CPUExecutionProvider'

    # The current estimator attribute is removed before the session is bound
    # under the same name.
    del self.flow.decoder.estimator
    self.flow.decoder.estimator = onnxruntime.InferenceSession(
        flow_decoder_estimator_model,
        sess_options=session_options,
        providers=[execution_provider],
    )
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
with self.llm_context:
@@ -207,4 +180,5 @@ class CosyVoiceModel:
self.llm_end_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()