Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-05 18:09:24 +08:00
update func inference
@@ -345,21 +345,22 @@ class Qwen2LM(TransformerLM):
         vllm_codec_engine.add_request(request_id,
                                       {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
                                       sampling_params)
-        ## generator
-        out_token_ids = []
         while True:
+            speech_token_break = False
             request_outputs: List[RequestOutput] = vllm_codec_engine.step()
             for request_output in request_outputs:
                 if str(request_output.request_id) != str(request_id):
                     continue
-                if not request_output.finished:
-                    # print(f"Partial request output: {request_output}")
-                    out_token = list(request_output.outputs[0].token_ids)[-1]
-                    yield out_token
-                    out_token_ids.append(out_token)
-                else:
-                    break
-            if not vllm_codec_engine.has_unfinished_requests():
+                # print(f"request output: {request_output}")
+                top_ids = list(request_output.outputs[0].token_ids)[-1]
+                if top_ids == self.speech_token_size:
+                    speech_token_break = True
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                yield top_ids
+
+            if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
                 break
 
     @torch.inference_mode()
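The updated loop no longer keys off `request_output.finished`: the token id equal to `self.speech_token_size` acts as an end-of-speech marker, ids above it are skipped as non-speech specials, and only valid speech token ids are yielded. Below is a minimal, self-contained sketch of that filtering logic against a mocked engine so it runs without vLLM; `FakeEngine`, `stream_speech_tokens`, and the concrete value 6561 for the speech token vocabulary are illustrative assumptions, not part of this commit.

# Minimal sketch of the new streaming loop's token filtering. FakeEngine,
# FakeOutput, and SPEECH_TOKEN_SIZE are hypothetical stand-ins, not CosyVoice
# or vLLM code.

from dataclasses import dataclass
from typing import List

SPEECH_TOKEN_SIZE = 6561  # assumed codec vocabulary size; this id marks end-of-speech


@dataclass
class FakeCompletion:
    token_ids: List[int]


@dataclass
class FakeOutput:
    request_id: str
    outputs: List[FakeCompletion]
    finished: bool = False


class FakeEngine:
    """Emits one new token per step(), mimicking the shape of LLMEngine.step()."""

    def __init__(self, request_id: str, script: List[int]):
        self.request_id = request_id
        self.script = script
        self.emitted: List[int] = []

    def step(self) -> List[FakeOutput]:
        self.emitted.append(self.script[len(self.emitted)])
        finished = len(self.emitted) == len(self.script)
        return [FakeOutput(self.request_id, [FakeCompletion(list(self.emitted))], finished)]

    def has_unfinished_requests(self) -> bool:
        return len(self.emitted) < len(self.script)


def stream_speech_tokens(engine: FakeEngine, request_id: str):
    # Mirrors the updated loop: stop on the end-of-speech id, skip ids above
    # it, and yield everything else as a speech token.
    while True:
        speech_token_break = False
        for request_output in engine.step():
            if str(request_output.request_id) != str(request_id):
                continue
            top_ids = list(request_output.outputs[0].token_ids)[-1]
            if top_ids == SPEECH_TOKEN_SIZE:
                speech_token_break = True  # end-of-speech marker: stop streaming
                break
            if top_ids > SPEECH_TOKEN_SIZE:
                continue  # non-speech special token: drop it
            yield top_ids
        if not engine.has_unfinished_requests() or speech_token_break:
            break


if __name__ == "__main__":
    # 42 and 7 are speech tokens; 9000 is a special id; 6561 ends the stream.
    eng = FakeEngine("req-0", [42, 7, 9000, 6561, 5])
    print(list(stream_speech_tokens(eng, "req-0")))  # -> [42, 7]

Running the example prints [42, 7]: 9000 is dropped as a special id, and 6561 terminates the stream before the trailing 5 is ever generated.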