mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
初步合并vllm支持,异步推理的通道处理还存在bug
This commit is contained in:
@@ -409,3 +409,26 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class VllmCosyVoice2Model(CosyVoice2Model):
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool):
|
||||
try:
|
||||
from cosyvoice.llm.llm_vllm import VllmQwen2LM
|
||||
except Exception as e:
|
||||
raise e
|
||||
llm = VllmQwen2LM(model_dir)
|
||||
super().__init__(llm,flow,hift,fp16)
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in
|
||||
torch.load(hift_model, weights_only=True, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
Reference in New Issue
Block a user