refactor(llm): rework VLLM inference

- Add an asynchronous inference mechanism based on a task queue and a background thread (a standalone sketch of the pattern is included below)
- Reimplement the synchronous inference interface on top of the new mechanism
Author: qihua
Date: 2025-03-07 23:53:50 +08:00
parent 90b666ea20
commit d4d187bd8c
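
For orientation, the pattern this commit adopts is: keep a dedicated asyncio event loop on a daemon thread, feed it work through a task queue, and let the synchronous generator drain a per-request output queue. Below is a minimal, self-contained sketch of that bridge; it is not part of the commit, and the names (QueueBridge, generate_sync) as well as the dummy _fake_generate coroutine are stand-ins for vLLM's AsyncLLMEngine.generate.

import asyncio
import queue
import threading
from typing import AsyncGenerator, Generator


class QueueBridge:
    """Bridge an async token generator onto a synchronous generator interface."""

    def __init__(self):
        self.task_queue: queue.Queue = queue.Queue()
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self.loop_thread.start()
        # One background coroutine drains the task queue (one request at a time).
        asyncio.run_coroutine_threadsafe(self._processor(), self.loop)

    def _run_event_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def _fake_generate(self, n: int) -> AsyncGenerator[tuple, None]:
        # Stand-in for an async engine call; the real engine reports `finished` itself.
        for i in range(n):
            await asyncio.sleep(0.01)
            yield i, i == n - 1

    async def _processor(self):
        while True:
            try:
                # The blocking get() stalls the loop thread, which is tolerable only
                # because this loop runs nothing but this single consumer coroutine.
                out_queue, n = self.task_queue.get()
                async for token, finished in self._fake_generate(n):
                    out_queue.put((token, finished))
            except Exception as exc:
                print(f"processor error: {exc}")

    def generate_sync(self, n: int) -> Generator[int, None, None]:
        # Submit the request, then relay results until the finished flag arrives.
        out_queue: queue.Queue = queue.Queue()
        self.task_queue.put((out_queue, n))
        finished = False
        while not finished:
            token, finished = out_queue.get()
            yield token


if __name__ == "__main__":
    print(list(QueueBridge().generate_sync(5)))  # -> [0, 1, 2, 3, 4]

Running the sketch prints [0, 1, 2, 3, 4]. The single-consumer limitation of this design is the same one the commit flags in its TODO: inference_processor handles one request at a time.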


@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import asyncio
-import contextlib
 import time
+import queue
+import asyncio
+import threading
 from typing import List, Generator, AsyncGenerator
 import torch
 from cosyvoice.utils.file_utils import logging
@@ -41,6 +42,7 @@ ENGINE_ARGS = {
"max_num_seqs": 256, "max_num_seqs": 256,
"disable_log_requests": True, "disable_log_requests": True,
"disable_log_stats": True, "disable_log_stats": True,
"dtype": "float16"
} }
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
@@ -84,13 +86,42 @@ class VllmQwen2LM(Qwen2LM):
         self.task_token_id = self.sos_eos_token_id + 1
         self.zero_token_id = self.task_token_id + 1
+        # An async generator cannot be consumed correctly from a synchronous function; even going through a coroutine crashes vLLM.
+        # Use a queue instead: a background thread runs the inference tasks.
+        self.task_queue = queue.Queue()
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
+        self.loop_thread.start()
+        # Start the background coroutine that processes tasks from the task queue.
+        # TODO: only a single task can run at a time; multi-task execution requires changes to inference_processor.
+        asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop)
+
+    def _run_event_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    async def inference_processor(self, task_queue):
+        while True:
+            try:
+                print(f"inference_processor")
+                out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get()
+                sampling_params = SamplingParams(**SAMPLING_PARAMS)
+                sampling_params.stop_token_ids = stop_token_ids or [6561]
+                if max_tokens:
+                    sampling_params.max_tokens = max_tokens
+                async for output in self.llm_engine.generate(
+                    {
+                        "prompt_token_ids": prompt_token_ids,
+                    },
+                    sampling_params=sampling_params,
+                    request_id=request_id or f"{time.time()}",
+                ):
+                    out_queue.put((output.outputs[0], output.finished))
+            except Exception as e:
+                logging.error(f"Error in inference_processor: {e}")
+
     async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
             -> AsyncGenerator[CompletionOutput, None]:
-        assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
-        invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
-        assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
-        # logging.debug('prompt_token_ids:', prompt_token_ids)
-        # TODO: add context management so in-flight requests can be cancelled
         sampling_params = SamplingParams(**SAMPLING_PARAMS)
         sampling_params.stop_token_ids = stop_token_ids or [6561]
         if max_tokens:
@@ -104,49 +135,16 @@ class VllmQwen2LM(Qwen2LM):
         ):
             yield output.outputs[0]

-    def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-            -> Generator[CompletionOutput, None, None]:
-        assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
-        invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
-        assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
-        # logging.debug('prompt_token_ids:', prompt_token_ids)
-        # TODO: add context management so in-flight requests can be cancelled
-        sampling_params = SamplingParams(**SAMPLING_PARAMS)
-        sampling_params.stop_token_ids = stop_token_ids or [6561]
-        if max_tokens:
-            sampling_params.max_tokens = max_tokens
-        # create a dedicated event loop
-        loop = asyncio.new_event_loop()
-        try:
-            asyncio.set_event_loop(loop)
-            # initialize the async generator
-            async_gen = self.llm_engine.generate(
-                {
-                    "prompt_token_ids": prompt_token_ids,
-                },
-                sampling_params=sampling_params,
-                request_id=request_id or f"{time.time()}",
-            )
-            while True:
-                try:
-                    # synchronously fetch the next async result
-                    output = loop.run_until_complete(async_gen.__anext__())
-                    yield output.outputs[0]
-                except StopAsyncIteration:
-                    break
-                except GeneratorExit:
-                    if async_gen is not None:
-                        loop.run_until_complete(async_gen.aclose())
-                    raise
-        finally:
-            # resource cleanup
-            print("cleaning up resources...")
-            if async_gen is not None:
-                loop.run_until_complete(async_gen.aclose())
-            loop.close()
-            print("resources cleaned up")
+    def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
+        # Driving the async generator from sync code crashes vLLM; for now, use the queue and run inference on the background thread.
+        # Submit the inference task to the queue.
+        out_queue = queue.Queue()
+        self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens))
+        # Relay the results from out_queue.
+        finished = False
+        while not finished:
+            (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
+            yield output

     def inference(
             self,
@@ -194,6 +192,9 @@ class VllmQwen2LM(Qwen2LM):
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
     ) -> Generator[torch.Tensor, None, None]:
+        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
+        prompt_speech_token = tensor_to_list(prompt_speech_token)
         last_tokens = []
         prompt_token_ids = [self.sos_eos_token_id]
         text_tokens_cache = prompt_text
@@ -202,18 +203,18 @@ class VllmQwen2LM(Qwen2LM):
             # text need tokens
             assert isinstance(this_text, list), "text need token ids List[int]."
             text_tokens_cache += this_text
-            while len(llm_prompt_speech_token) != 0:
+            while len(prompt_speech_token) != 0:
                 if len(text_tokens_cache) >= self.mix_ratio[0]:
                     text_input_token = text_tokens_cache[:self.mix_ratio[0]]
-                    speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
+                    speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
                     prompt_token_ids += text_input_token + speech_input_token
                     # reset the last cache
                     text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
-                    llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
+                    prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                 else:
                     logging.info('not enough text token to decode, wait for more')
                     break
-            if len(llm_prompt_speech_token) == 0:
+            if len(prompt_speech_token) == 0:
                 if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
                     logging.info('get fill token, need to append more text token')
                     if len(text_tokens_cache) >= self.mix_ratio[0]: