add vllm inference

This commit is contained in:
lyuxiang.lx
2025-05-30 07:22:35 +00:00
parent 9f55c5af8f
commit 6dd68b9d5e
6 changed files with 105 additions and 64 deletions

View File

@@ -59,9 +59,6 @@ class CosyVoiceModel:
self.stream_scale_factor = 1
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
for _ in range(trt_concurrent):
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
@@ -69,7 +66,6 @@ class CosyVoiceModel:
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.trt_context_dict = {}
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,7 +94,7 @@ class CosyVoiceModel:
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
@@ -125,7 +121,8 @@ class CosyVoiceModel:
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
embedding=llm_embedding.to(self.device),
uuid=uuid):
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
@@ -180,13 +177,11 @@ class CosyVoiceModel:
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
this_trt_context = self.trt_context_pool.get()
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
self.trt_context_dict[this_uuid] = this_trt_context
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
@@ -240,8 +235,6 @@ class CosyVoiceModel:
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
self.trt_context_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
@@ -273,22 +266,28 @@ class CosyVoice2Model(CosyVoiceModel):
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
for _ in range(trt_concurrent):
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.trt_context_dict = {}
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load_vllm(self, model_dir):
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
from vllm import EngineArgs, LLMEngine
engine_args = EngineArgs(model=model_dir,
skip_tokenizer_init=True,
enable_prompt_embeds=True,
gpu_memory_utilization=0.2)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
del self.llm.llm.model.model.layers
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
@@ -330,11 +329,9 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
this_trt_context = self.trt_context_pool.get()
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.trt_context_dict[this_uuid] = this_trt_context
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
@@ -388,8 +385,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
self.trt_context_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()