diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py
index 5d6e4db..e5a961c 100644
--- a/cosyvoice/llm/llm.py
+++ b/cosyvoice/llm/llm.py
@@ -345,21 +345,22 @@ class Qwen2LM(TransformerLM):
         vllm_codec_engine.add_request(request_id,
                                       {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
                                       sampling_params)
-        ## generator
-        out_token_ids = []
         while True:
+            speech_token_break = False
             request_outputs: List[RequestOutput] = vllm_codec_engine.step()
             for request_output in request_outputs:
                 if str(request_output.request_id) != str(request_id):
                     continue
-                if not request_output.finished:
-                    # print(f"Partial request output: {request_output}")
-                    out_token = list(request_output.outputs[0].token_ids)[-1]
-                    yield out_token
-                    out_token_ids.append(out_token)
-                else:
+                # print(f"request output: {request_output}")
+                top_ids = list(request_output.outputs[0].token_ids)[-1]
+                if top_ids == self.speech_token_size:
+                    speech_token_break = True
                     break
-            if not vllm_codec_engine.has_unfinished_requests():
+                if top_ids > self.speech_token_size:
+                    continue
+                yield top_ids
+
+            if not vllm_codec_engine.has_unfinished_requests() or speech_token_break:
                 break
 
     @torch.inference_mode()
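
For reference, the revised decode loop can be read as a small generator on its own: step the vLLM engine, treat a token id equal to speech_token_size as the end-of-speech sentinel, skip any id above it (non-speech special tokens), and yield every other id as soon as it arrives. The sketch below is a simplified standalone version under those assumptions; it reuses the real vLLM engine API (step(), has_unfinished_requests(), RequestOutput.outputs[0].token_ids) and the names request_id / speech_token_size from the diff, while stream_speech_tokens, eos_seen and last_id are illustrative names, not the actual Qwen2LM method.

# Minimal sketch of the new streaming loop, not the exact llm.py code.
from typing import Generator, List

from vllm import LLMEngine, RequestOutput


def stream_speech_tokens(engine: LLMEngine,
                         request_id: str,
                         speech_token_size: int) -> Generator[int, None, None]:
    while True:
        eos_seen = False  # the diff calls this speech_token_break
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if str(request_output.request_id) != str(request_id):
                continue
            last_id = list(request_output.outputs[0].token_ids)[-1]
            if last_id == speech_token_size:
                eos_seen = True   # end-of-speech sentinel reached
                break
            if last_id > speech_token_size:
                continue          # skip non-speech special tokens
            yield last_id         # valid speech token, emit immediately
        if not engine.has_unfinished_requests() or eos_seen:
            break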