From ab5b8eb160776244660ee090a01faf1488b73459 Mon Sep 17 00:00:00 2001
From: qihua
Date: Sat, 8 Mar 2025 10:41:49 +0800
Subject: [PATCH] refactor(llm): rework vLLM inference task handling to
 support multiple concurrent tasks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Remove the task queue and the single-task processing limit
- Use asyncio.run_coroutine_threadsafe() to run inference tasks on the background thread
---
 cosyvoice/llm/llm_vllm.py | 47 ++++++++++++++++++-----------------------------
 1 file changed, 18 insertions(+), 29 deletions(-)

diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py
index 1e9bc28..839bf88 100644
--- a/cosyvoice/llm/llm_vllm.py
+++ b/cosyvoice/llm/llm_vllm.py
@@ -86,46 +86,35 @@ class VllmQwen2LM(Qwen2LM):
         self.task_token_id = self.sos_eos_token_id + 1
         self.zero_token_id = self.task_token_id + 1
-        # An async generator cannot be consumed correctly from a synchronous function; even wrapping it in a coroutine crashes vLLM
-        # Use a queue instead and run inference tasks on a background thread
-        self.task_queue = queue.Queue()
+        # vLLM inference must run inside a single, fixed event loop, so start a dedicated background thread for it
         self.loop = asyncio.new_event_loop()
         self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
         self.loop_thread.start()
-        # Run a background coroutine that processes tasks from the task queue
-        # TODO: only one task can run at a time; multi-task support requires changes to inference_processor
-        asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop)
 
     def _run_event_loop(self):
         asyncio.set_event_loop(self.loop)
         self.loop.run_forever()
 
-    async def inference_processor(self, task_queue):
-        while True:
-            try:
-                logging.debug(f"inference_processor")
-                out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get()
-                sampling_params = SamplingParams(**SAMPLING_PARAMS)
-                sampling_params.stop_token_ids = stop_token_ids or [6561]
-                if max_tokens:
-                    sampling_params.max_tokens = max_tokens
-                async for output in self.llm_engine.generate(
-                        {
-                            "prompt_token_ids": prompt_token_ids,
-                        },
-                        sampling_params=sampling_params,
-                        request_id=request_id or f"{time.time()}",
-                ):
-                    out_queue.put((output.outputs[0], output.finished))
-            except Exception as e:
-                logging.error(f"Error in inference_processor: {e}")
+    async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
+        sampling_params = SamplingParams(**SAMPLING_PARAMS)
+        sampling_params.stop_token_ids = stop_token_ids or [6561]
+        if max_tokens:
+            sampling_params.max_tokens = max_tokens
+        async for output in self.llm_engine.generate(
+                {
+                    "prompt_token_ids": prompt_token_ids,
+                },
+                sampling_params=sampling_params,
+                request_id=request_id or f"{time.time()}",
+        ):
+            out_queue.put((output.outputs[0], output.finished))
 
     def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
-        # Bridging sync to async directly crashes vLLM; for now use a queue and run inference on a background thread
-        # Submit the inference task to the queue
         out_queue = queue.Queue()
-        self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens))
-        # Return the results collected in out_queue
+        asyncio.run_coroutine_threadsafe(
+            self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
+        )
+        # Receive the results returned through out_queue
        finished = False
         while not finished:
             (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
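
For context, a minimal self-contained sketch of the pattern this patch switches to (not code from the repository): a daemon thread owns a single asyncio event loop, each request is submitted to that loop with asyncio.run_coroutine_threadsafe(), and outputs stream back to the synchronous caller through a queue.Queue. The names BackgroundInference and fake_generate are illustrative stand-ins; fake_generate plays the role of llm_engine.generate().

import asyncio
import queue
import threading


async def fake_generate(n):
    # Hypothetical stand-in for an async token generator such as AsyncLLMEngine.generate().
    for i in range(n):
        await asyncio.sleep(0.01)
        yield i, i == n - 1  # (token, finished)


class BackgroundInference:
    def __init__(self):
        # One fixed event loop, run forever on a daemon thread.
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self._run_loop, daemon=True).start()

    def _run_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def _async_infer(self, out_queue, n):
        # Runs inside the background loop; pushes each output into the thread-safe queue.
        async for token, finished in fake_generate(n):
            out_queue.put((token, finished))

    def infer(self, n):
        # Synchronous entry point: schedule the coroutine on the background loop,
        # then drain the queue until the generator reports completion. Multiple
        # callers can each submit their own coroutine with their own queue.
        out_queue = queue.Queue()
        asyncio.run_coroutine_threadsafe(self._async_infer(out_queue, n), self.loop)
        finished = False
        while not finished:
            token, finished = out_queue.get()
            yield token


if __name__ == "__main__":
    engine = BackgroundInference()
    print(list(engine.infer(5)))  # -> [0, 1, 2, 3, 4]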