add flow decoder tensorrt infer

zhoubofan.zbf
2024-08-29 23:35:07 +08:00
parent 1d881df8b2
commit 5f21aef786
5 changed files with 149 additions and 19 deletions

cosyvoice/cli/cosyvoice.py

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 class CosyVoice:
-    def __init__(self, model_dir, load_jit=True, load_trt=True):
+    def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -43,8 +43,7 @@ class CosyVoice:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                 '{}/llm.llm.fp16.zip'.format(model_dir))
         if load_trt:
-            # TODO
-            self.model.load_trt()
+            self.model.load_trt(model_dir, use_fp16)
         del configs

     def list_avaliable_spks(self):
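
Taken together, the two hunks above thread the new use_fp16 flag from the public constructor down to model.load_trt. A minimal usage sketch, not part of this commit: the model directory path is a placeholder and the import path assumes the class lives in cosyvoice/cli/cosyvoice.py as shown above; only the load_trt and use_fp16 arguments come from the diff.

from cosyvoice.cli.cosyvoice import CosyVoice

# load_trt=True deserializes the flow decoder estimator engine at startup;
# use_fp16=True picks estimator_fp16.plan over estimator_fp32.plan (see model.py below).
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M', load_jit=True, load_trt=True, use_fp16=True)
print(cosyvoice.list_avaliable_spks())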

cosyvoice/cli/model.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading
@@ -19,6 +20,10 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
+try:
+    import tensorrt as trt
+except ImportError:
+    ...
 class CosyVoiceModel:
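
The tensorrt import above is wrapped in try/except so the dependency stays optional, but because the handler body is just an ellipsis, the name trt is left undefined when the package is missing and a later load_trt call fails with a NameError. A hedged sketch, not in the commit, of a guard callers could add (require_trt is a hypothetical helper):

try:
    import tensorrt as trt
except ImportError:
    trt = None  # keep the name bound so availability can be checked

def require_trt():
    # hypothetical helper: fail early with an actionable message instead of a NameError
    if trt is None:
        raise RuntimeError('tensorrt is not installed; install it or construct CosyVoice with load_trt=False')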
@@ -66,10 +71,20 @@ class CosyVoiceModel:
             llm_llm = torch.jit.load(llm_llm_model)
             self.llm.llm = llm_llm
-    def load_trt(self):
-        # TODO prepare whatever the TRT inference needs here
-        self.flow.decoder.estimator = xxx
-        self.flow.decoder.session = xxx
+    def load_trt(self, model_dir, use_fp16):
+        trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
+        trt_file_path = os.path.join(model_dir, trt_file_name)
+        if not os.path.isfile(trt_file_path):
+            raise FileNotFoundError(f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate the .plan file")
+        trt.init_libnvinfer_plugins(None, "")
+        logger = trt.Logger(trt.Logger.WARNING)
+        runtime = trt.Runtime(logger)
+        with open(trt_file_path, 'rb') as f:
+            serialized_engine = f.read()
+        engine = runtime.deserialize_cuda_engine(serialized_engine)
+        self.flow.decoder.estimator_context = engine.create_execution_context()
+        self.flow.decoder.estimator_engine = engine
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
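
The new load_trt only deserializes the engine and stores an execution context on the flow decoder; the estimator call itself is not shown in this section. Below is a hedged sketch of how that context could be driven with PyTorch CUDA tensors, assuming the name-based I/O API of TensorRT >= 8.5; the helper name, the dict of inputs keyed by the engine's tensor names, and the shared-dtype assumption are illustrative, not part of the commit.

import torch
import tensorrt as trt

def run_trt_estimator(engine, context, inputs):
    # inputs: dict mapping the engine's input tensor names to contiguous CUDA tensors
    keep_alive = []
    # bind all inputs first so dynamic output shapes can be resolved
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            t = inputs[name].contiguous()
            keep_alive.append(t)  # keep buffers alive until execution finishes
            context.set_input_shape(name, tuple(t.shape))
            context.set_tensor_address(name, t.data_ptr())
    # allocate and bind outputs from the now-resolved shapes
    any_input = next(iter(inputs.values()))
    outputs = {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            shape = tuple(context.get_tensor_shape(name))
            out = torch.empty(shape, dtype=any_input.dtype, device=any_input.device)
            outputs[name] = out
            context.set_tensor_address(name, out.data_ptr())
    # launch on the current torch CUDA stream and wait for completion
    stream = torch.cuda.current_stream()
    context.execute_async_v3(stream.cuda_stream)
    stream.synchronize()
    return outputs

The referenced bin/export_trt.py is not included in this section; a common way to produce the expected .plan files is to export the estimator to ONNX and build the engine with trtexec --onnx=estimator.onnx --saveEngine=estimator_fp16.plan --fp16 (omit --fp16 for the fp32 plan).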