update vllm_codec_engine

This commit is contained in:
雾聪
2025-02-25 19:40:30 +08:00
parent 4df0683a37
commit f6a18ee07a
3 changed files with 48 additions and 19 deletions

View File

@@ -66,6 +66,7 @@ class CosyVoiceModel:
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.vllm_codec_engine = None
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -117,7 +118,8 @@ class CosyVoiceModel:
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
embedding=llm_embedding.to(self.device),
vllm_codec_engine=self.vllm_codec_engine):
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
@@ -314,6 +316,7 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.vllm_codec_engine = None
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)