update func inference

雾聪
2025-03-01 19:10:20 +08:00
parent 9a4aebb0ea
commit 11dbd88947


@@ -345,21 +345,22 @@ class Qwen2LM(TransformerLM):
         vllm_codec_engine.add_request(request_id,
                                       {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
                                       sampling_params)
-        ## generator
-        out_token_ids = []
         while True:
+            speech_token_break = False
             request_outputs: List[RequestOutput] = vllm_codec_engine.step()
             for request_output in request_outputs:
                 if str(request_output.request_id) != str(request_id):
                     continue
-                if not request_output.finished:
-                    # print(f"Partial request output: {request_output}")
-                    out_token = list(request_output.outputs[0].token_ids)[-1]
-                    yield out_token
-                    out_token_ids.append(out_token)
-                else:
+                # print(f"request output: {request_output}")
+                top_ids = list(request_output.outputs[0].token_ids)[-1]
+                if top_ids == self.speech_token_size:
+                    speech_token_break = True
                     break
-            if not vllm_codec_engine.has_unfinished_requests():
+                if top_ids > self.speech_token_size:
+                    continue
+                yield top_ids
+            if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
                 break
 
     @torch.inference_mode()
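In summary, the revised loop reads the newest token id from each engine step, treats an id equal to self.speech_token_size as an end-of-speech marker that terminates generation, skips ids above that value as non-speech control tokens, and yields every valid speech token as it streams. The sketch below is a minimal, self-contained illustration of that stop-token logic, simplified to a single request; FakeEngine and SPEECH_TOKEN_SIZE are hypothetical stand-ins for vllm_codec_engine and self.speech_token_size, not the repository's actual code.

    from typing import Iterator, List

    SPEECH_TOKEN_SIZE = 6561  # assumed codec vocabulary size; this id marks end of speech

    class FakeEngine:
        """Stub engine: each step() reveals one more generated token id."""
        def __init__(self, token_ids: List[int]):
            self.token_ids = token_ids
            self.pos = 0

        def step(self) -> List[int]:
            self.pos += 1
            return self.token_ids[:self.pos]

        def has_unfinished_requests(self) -> bool:
            return self.pos < len(self.token_ids)

    def stream_speech_tokens(engine: FakeEngine) -> Iterator[int]:
        """Mirrors the updated loop: stop at the end-of-speech id, drop special ids."""
        while True:
            speech_token_break = False
            top_ids = engine.step()[-1]        # newest token id from this decode step
            if top_ids == SPEECH_TOKEN_SIZE:   # end-of-speech marker: stop streaming
                speech_token_break = True
            elif top_ids > SPEECH_TOKEN_SIZE:  # ids beyond the marker are control tokens
                pass
            else:
                yield top_ids                  # valid speech token: emit immediately
            if not engine.has_unfinished_requests() or speech_token_break:
                break

    # Tokens after the marker are never emitted:
    assert list(stream_speech_tokens(FakeEngine([5, 9, SPEECH_TOKEN_SIZE, 7]))) == [5, 9]

Checking the token id directly lets the generator terminate on the model's own end-of-speech marker instead of waiting for the engine to flag the request as finished, which is what the removed request_output.finished branch relied on.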