mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
add trt script TODO
This commit is contained in:
@@ -53,12 +53,17 @@ class CosyVoiceModel:
|
||||
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_script(self, llm_text_encoder_model, llm_llm_model):
|
||||
def load_jit(self, llm_text_encoder_model, llm_llm_model):
|
||||
llm_text_encoder = torch.jit.load(llm_text_encoder_model)
|
||||
self.llm.text_encoder = llm_text_encoder
|
||||
llm_llm = torch.jit.load(llm_llm_model)
|
||||
self.llm.llm = llm_llm
|
||||
|
||||
def load_trt(self):
|
||||
# TODO 你需要的TRT推理的准备
|
||||
self.flow.decoder.estimator = xxx
|
||||
self.flow.decoder.session = xxx
|
||||
|
||||
def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
|
||||
with self.llm_context:
|
||||
for i in self.llm.inference(text=text.to(self.device),
|
||||
|
||||
Reference in New Issue
Block a user