add trt script TODO

This commit is contained in:
lyuxiang.lx
2024-08-29 10:44:04 +08:00
parent 8b097f7625
commit f1e374a9bb
5 changed files with 31 additions and 7 deletions

View File

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
class CosyVoice:
def __init__(self, model_dir, load_script=True):
def __init__(self, model_dir, load_jit=True, load_trt=True):
instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
if not os.path.exists(model_dir):
@@ -39,9 +39,12 @@ class CosyVoice:
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_script:
self.model.load_script('{}/llm.text_encoder.fp16.zip'.format(model_dir),
if load_jit:
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
'{}/llm.llm.fp16.zip'.format(model_dir))
if load_trt:
# TODO
self.model.load_trt()
del configs
def list_avaliable_spks(self):

View File

@@ -53,12 +53,17 @@ class CosyVoiceModel:
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
self.hift.to(self.device).eval()
def load_script(self, llm_text_encoder_model, llm_llm_model):
def load_jit(self, llm_text_encoder_model, llm_llm_model):
llm_text_encoder = torch.jit.load(llm_text_encoder_model)
self.llm.text_encoder = llm_text_encoder
llm_llm = torch.jit.load(llm_llm_model)
self.llm.llm = llm_llm
def load_trt(self):
# TODO 你需要的TRT推理的准备
self.flow.decoder.estimator = xxx
self.flow.decoder.session = xxx
def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
with self.llm_context:
for i in self.llm.inference(text=text.to(self.device),