mirror of https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
add onnx export
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import torch
 import numpy as np
 import threading
@@ -20,7 +19,6 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
 import numpy as np
-import onnxruntime as ort
 
 class CosyVoiceModel:
 
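The module-level "import onnxruntime as ort" can go because the rewritten load_onnx below imports onnxruntime lazily, keeping it an optional dependency for users who never load an exported estimator. A minimal sketch of that pattern, with a hypothetical helper name and provider choice:

    def load_backend(model_path):
        # Deferred import: onnxruntime is only needed when this path runs,
        # so a missing package fails here, not at module import time.
        import onnxruntime
        return onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])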
@@ -62,47 +60,22 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def load_jit(self, llm_text_encoder_model, llm_llm_model):
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model)
+        self.flow.encoder = flow_encoder
 
-    # def load_trt(self, model_dir, use_fp16):
-    #     import tensorrt as trt
-    #     trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
-    #     trt_file_path = os.path.join(model_dir, trt_file_name)
-    #     if not os.path.isfile(trt_file_path):
-    #         raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
-
-    #     trt.init_libnvinfer_plugins(None, "")
-    #     logger = trt.Logger(trt.Logger.WARNING)
-    #     runtime = trt.Runtime(logger)
-    #     with open(trt_file_path, 'rb') as f:
-    #         serialized_engine = f.read()
-    #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-    #     self.flow.decoder.estimator_context = engine.create_execution_context()
-    #     self.flow.decoder.estimator = None
-
-    def load_onnx(self, model_dir, use_fp16):
-        onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx'
-        onnx_file_path = os.path.join(model_dir, onnx_file_name)
-        if not os.path.isfile(onnx_file_path):
-            raise f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate .onnx file"
-
-        providers = ['CUDAExecutionProvider']
-        sess_options = ort.SessionOptions()
-
-        # Add TensorRT Execution Provider
-        providers = [
-            'CUDAExecutionProvider'
-        ]
-
-        # Load the ONNX model
-        self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers)
-        # self.flow.decoder.estimator_context = None
-        self.flow.decoder.estimator = None
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
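The old loader took a model directory plus an fp16 flag, carried a dead TensorRT block in comments, and even had invalid error handling (raising an f-string is a TypeError in Python 3, since only exceptions can be raised). The replacement takes the .onnx path directly, frees the PyTorch estimator with del, and rebinds self.flow.decoder.estimator to an onnxruntime InferenceSession, falling back to the CPU provider when CUDA is absent. A minimal sketch of driving such a session; the file name is hypothetical, all inputs are assumed float, and dynamic axes (reported as None or a string) are pinned to 1:

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession('estimator_fp32.onnx',
                                           providers=['CPUExecutionProvider'])
    # Build zero-filled feeds from the inputs the model itself declares.
    feeds = {
        i.name: np.zeros([d if isinstance(d, int) else 1 for d in i.shape], dtype=np.float32)
        for i in session.get_inputs()
    }
    outputs = session.run(None, feeds)  # one numpy array per declared output

Rebinding under the same estimator attribute keeps a single field for the decoder's backend; callers presumably dispatch on whether that object is a torch module or an onnxruntime session.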
@@ -207,4 +180,5 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-            torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
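Guarding the synchronize call makes teardown safe on CPU-only hosts, where an unconditional torch.cuda.synchronize() raises because no CUDA device can be initialized; with the estimator now able to run on onnxruntime's CPUExecutionProvider, that case is reachable. The guarded form as a reusable helper, a minimal sketch:

    import torch

    def sync_if_cuda():
        # Blocks until all queued CUDA kernels finish; skipped on CPU-only
        # hosts, where calling torch.cuda.synchronize() would raise instead.
        if torch.cuda.is_available():
            torch.cuda.synchronize()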