From d4d187bd8c8b96763ac64f6d1a7b4a4b9a2c5392 Mon Sep 17 00:00:00 2001 From: qihua Date: Fri, 7 Mar 2025 23:53:50 +0800 Subject: [PATCH] =?UTF-8?q?refactor(llm):=20=E9=87=8D=E6=9E=84=20VLLM=20?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增基于队列和线程的异步推理机制 - 优化同步推理接口,使用新机制实现 --- cosyvoice/llm/llm_vllm.py | 109 +++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py index c43c53a..61b1090 100644 --- a/cosyvoice/llm/llm_vllm.py +++ b/cosyvoice/llm/llm_vllm.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import contextlib import time +import queue +import asyncio +import threading from typing import List, Generator, AsyncGenerator import torch from cosyvoice.utils.file_utils import logging @@ -41,6 +42,7 @@ ENGINE_ARGS = { "max_num_seqs": 256, "disable_log_requests": True, "disable_log_stats": True, + "dtype": "float16" } from vllm.sampling_params import RequestOutputKind @@ -84,13 +86,42 @@ class VllmQwen2LM(Qwen2LM): self.task_token_id = self.sos_eos_token_id + 1 self.zero_token_id = self.task_token_id + 1 + # 不能直接在同步函数正确的使用 异步的生成器函数,即使使用协程也会对vllm造成崩溃 + # 使用 queue 的方式,后台线程运行推理任务 + self.task_queue = queue.Queue() + self.loop = asyncio.new_event_loop() + self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True) + self.loop_thread.start() + # 运行后台协程,用于处理任务队列中的任务 + # TODO: 目前只能单任务运行,多任务运行需要对 inference_processor 进行修改 + asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop) + + def _run_event_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + async def inference_processor(self, task_queue): + while True: + try: + print(f"inference_processor") + out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get() + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + async for output in self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ): + out_queue.put((output.outputs[0], output.finished)) + except Exception as e: + logging.error(f"Error in inference_processor: {e}") + async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ -> AsyncGenerator[CompletionOutput, None]: - assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" - invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) - assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" - # logging.debug('prompt_token_ids:', prompt_token_ids) - # TODO: 增加上下文控制,取消请求时 sampling_params = SamplingParams(**SAMPLING_PARAMS) sampling_params.stop_token_ids = stop_token_ids or [6561] if max_tokens: @@ -104,49 +135,16 @@ class VllmQwen2LM(Qwen2LM): ): yield output.outputs[0] - - def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ - -> Generator[CompletionOutput, None, None]: - assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" - invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) - assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" - # logging.debug('prompt_token_ids:', prompt_token_ids) - # TODO: 增加上下文控制,取消请求时 - sampling_params = SamplingParams(**SAMPLING_PARAMS) - sampling_params.stop_token_ids = stop_token_ids or [6561] - if max_tokens: - sampling_params.max_tokens = max_tokens - - # 创建独立事件循环 - loop = asyncio.new_event_loop() - try: - asyncio.set_event_loop(loop) - # 初始化异步生成器 - async_gen = self.llm_engine.generate( - { - "prompt_token_ids": prompt_token_ids, - }, - sampling_params=sampling_params, - request_id=request_id or f"{time.time()}", - ) - while True: - try: - # 同步获取异步结果 - output = loop.run_until_complete(async_gen.__anext__()) - yield output.outputs[0] - except StopAsyncIteration: - break - except GeneratorExit: - if async_gen is not None: - loop.run_until_complete(async_gen.aclose()) - raise - finally: - # 资源清理 - print("资源清理...") - if async_gen is not None: - loop.run_until_complete(async_gen.aclose()) - loop.close() - print("资源清理成功") + def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None): + # 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务 + # 提交推理任务到队列中 + out_queue = queue.Queue() + self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens)) + # 将 out_queue 的结果返回 + finished = False + while not finished: + (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get() + yield output def inference( self, @@ -194,6 +192,9 @@ class VllmQwen2LM(Qwen2LM): max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, ) -> Generator[torch.Tensor, None, None]: + prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) + prompt_speech_token = tensor_to_list(prompt_speech_token) + last_tokens = [] prompt_token_ids = [self.sos_eos_token_id] text_tokens_cache = prompt_text @@ -202,18 +203,18 @@ class VllmQwen2LM(Qwen2LM): # text need tokens assert isinstance(this_text, list), "text need token ids List[int]." text_tokens_cache += this_text - while len(llm_prompt_speech_token) != 0: + while len(prompt_speech_token) != 0: if len(text_tokens_cache) >= self.mix_ratio[0]: text_input_token = text_tokens_cache[:self.mix_ratio[0]] - speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]] + speech_input_token = prompt_speech_token[:self.mix_ratio[1]] prompt_token_ids += text_input_token + speech_input_token # reset the last cache text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] - llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:] + prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:] else: logging.info('not enough text token to decode, wait for more') break - if len(llm_prompt_speech_token) == 0: + if len(prompt_speech_token) == 0: if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: logging.info('get fill token, need to append more text token') if len(text_tokens_cache) >= self.mix_ratio[0]: