support online onnx to trt conversion

huzetao.hzt
2025-01-07 17:20:06 +08:00
parent 5d12ced727
commit b6a1116d15
4 changed files with 146 additions and 26 deletions


@@ -19,6 +19,7 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
+from cosyvoice.trt.estimator_trt import EstimatorTRT
 
 
 class CosyVoiceModel:
@@ -81,14 +82,9 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def load_trt(self, flow_decoder_estimator_model):
+    def load_trt(self, flow_decoder_estimator_model, fp16):
         del self.flow.decoder.estimator
-        import tensorrt as trt
-        with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        if self.flow.decoder.estimator_engine is None:
-            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.estimator = EstimatorTRT(flow_decoder_estimator_model, self.device, fp16)
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
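
Note: the new EstimatorTRT class lives in cosyvoice/trt/estimator_trt.py, one of the four changed files, which is not shown in this excerpt. As a rough orientation only, the sketch below shows what an online ONNX-to-TensorRT conversion of this kind typically looks like; the class name, structure, and handling of inputs are assumptions, not the actual implementation from this commit.

# Hedged sketch of an "online ONNX to TRT" wrapper; NOT the real estimator_trt.py.
import tensorrt as trt


class EstimatorTRTSketch:
    def __init__(self, model_path, device, fp16):
        self.device = device
        logger = trt.Logger(trt.Logger.INFO)
        if model_path.endswith('.onnx'):
            # Online conversion: parse the ONNX graph and build a serialized engine in-process.
            builder = trt.Builder(logger)
            network = builder.create_network(
                1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
            parser = trt.OnnxParser(network, logger)
            with open(model_path, 'rb') as f:
                if not parser.parse(f.read()):
                    raise ValueError('failed to parse onnx {}'.format(model_path))
            config = builder.create_builder_config()
            if fp16:
                config.set_flag(trt.BuilderFlag.FP16)
            # A real build with dynamic input shapes would also need an optimization profile here.
            serialized_engine = builder.build_serialized_network(network, config)
        else:
            # Prebuilt plan file: just read the serialized engine from disk, as the old code did.
            with open(model_path, 'rb') as f:
                serialized_engine = f.read()
        self.engine = trt.Runtime(logger).deserialize_cuda_engine(serialized_engine)
        if self.engine is None:
            raise ValueError('failed to load trt {}'.format(model_path))
        self.context = self.engine.create_execution_context()

In practice such a wrapper would usually also cache the built engine to disk so the conversion only runs on the first load, and would expose a forward/call method that binds the estimator's input and output tensors to the execution context.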