refactor(llm): refactor the VLLM inference approach

- Add an asynchronous inference mechanism based on a task queue and a background thread
- Reimplement the synchronous inference interface on top of the new mechanism
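The core of the change is a task queue drained by a coroutine running on a dedicated event-loop thread, with results streamed back through per-request queues. Below is a minimal, self-contained sketch of that pattern; the names (AsyncWorker, make_agen, demo) are illustrative only and not part of the CosyVoice/vLLM code.

import asyncio
import queue
import threading

class AsyncWorker:
    """Run async-generator jobs on a background event loop, fed through a queue."""

    def __init__(self):
        self.task_queue = queue.Queue()
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self._run_loop, daemon=True).start()
        # Schedule the single consumer coroutine onto the background loop.
        asyncio.run_coroutine_threadsafe(self._consume(), self.loop)

    def _run_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def _consume(self):
        while True:
            # The blocking get() stalls the loop, which is fine: this is its only coroutine.
            out_queue, make_agen = self.task_queue.get()
            async for item in make_agen():
                out_queue.put((item, False))
            out_queue.put((None, True))  # sentinel: job finished

    def run(self, make_agen):
        """Synchronous generator over the results of an async-generator factory."""
        out_queue = queue.Queue()
        self.task_queue.put((out_queue, make_agen))
        while True:
            item, finished = out_queue.get()
            if finished:
                return
            yield item

async def demo():
    for i in range(3):
        await asyncio.sleep(0.1)
        yield i

worker = AsyncWorker()
print(list(worker.run(demo)))  # [0, 1, 2]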
Author: qihua
Date: 2025-03-07 23:53:50 +08:00
parent 90b666ea20
commit d4d187bd8c


@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import time
import queue
import threading
from typing import List, Generator, AsyncGenerator
import torch
from cosyvoice.utils.file_utils import logging
@@ -41,6 +42,7 @@ ENGINE_ARGS = {
"max_num_seqs": 256,
"disable_log_requests": True,
"disable_log_stats": True,
"dtype": "float16"
}
from vllm.sampling_params import RequestOutputKind
@@ -84,13 +86,42 @@ class VllmQwen2LM(Qwen2LM):
self.task_token_id = self.sos_eos_token_id + 1
self.zero_token_id = self.task_token_id + 1
# An async generator cannot be consumed correctly from a synchronous function; even wrapping it in a coroutine crashes vLLM
# Work around this with a queue: a background thread runs the inference tasks
self.task_queue = queue.Queue()
self.loop = asyncio.new_event_loop()
self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
self.loop_thread.start()
# Start the background coroutine that processes tasks from the task queue
# TODO: only one task runs at a time; supporting concurrent requests requires changes to inference_processor
asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop)
def _run_event_loop(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
async def inference_processor(self, task_queue):
while True:
try:
print(f"inference_processor")
out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get()
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
sampling_params.max_tokens = max_tokens
async for output in self.llm_engine.generate(
{
"prompt_token_ids": prompt_token_ids,
},
sampling_params=sampling_params,
request_id=request_id or f"{time.time()}",
):
out_queue.put((output.outputs[0], output.finished))
except Exception as e:
logging.error(f"Error in inference_processor: {e}")
async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> AsyncGenerator[CompletionOutput, None]:
assert isinstance(prompt_token_ids, list), "prompt_token_ids should be List[int]"
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
# logging.debug('prompt_token_ids:', prompt_token_ids)
# TODO: add context management so a request can be cancelled
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
@@ -104,49 +135,16 @@ class VllmQwen2LM(Qwen2LM):
):
yield output.outputs[0]
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> Generator[CompletionOutput, None, None]:
assert isinstance(prompt_token_ids, list), "prompt_token_ids should be List[int]"
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
# logging.debug('prompt_token_ids:', prompt_token_ids)
# TODO: add context management so a request can be cancelled
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
sampling_params.max_tokens = max_tokens
# Create a dedicated event loop
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
# Initialize the async generator
async_gen = self.llm_engine.generate(
{
"prompt_token_ids": prompt_token_ids,
},
sampling_params=sampling_params,
request_id=request_id or f"{time.time()}",
)
while True:
try:
# Synchronously fetch the next async result
output = loop.run_until_complete(async_gen.__anext__())
yield output.outputs[0]
except StopAsyncIteration:
break
except GeneratorExit:
if async_gen is not None:
loop.run_until_complete(async_gen.aclose())
raise
finally:
# Clean up resources
print("Cleaning up resources...")
if async_gen is not None:
loop.run_until_complete(async_gen.aclose())
loop.close()
print("资源清理成功")
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> Generator[CompletionOutput, None, None]:
# Converting async to sync directly crashes vLLM, so the task is pushed onto a queue and run by the background inference thread
# Submit the inference task to the queue
out_queue = queue.Queue()
self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens))
# Relay results from out_queue until the request finishes
finished = False
while not finished:
output, finished = out_queue.get()
yield output
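# Example usage of the synchronous interface (illustrative sketch; `lm` is an
# initialized VllmQwen2LM, `prompt_token_ids` a List[int], handle_tokens a placeholder):
#     for completion in lm.llm_inference(prompt_token_ids, request_id="demo"):
#         handle_tokens(completion.token_ids)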
def inference(
self,
@@ -194,6 +192,9 @@ class VllmQwen2LM(Qwen2LM):
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
prompt_speech_token = tensor_to_list(prompt_speech_token)
last_tokens = []
prompt_token_ids = [self.sos_eos_token_id]
text_tokens_cache = prompt_text
@@ -202,18 +203,18 @@ class VllmQwen2LM(Qwen2LM):
# text need tokens
assert isinstance(this_text, list), "text need token ids List[int]."
text_tokens_cache += this_text
while len(llm_prompt_speech_token) != 0:
while len(prompt_speech_token) != 0:
if len(text_tokens_cache) >= self.mix_ratio[0]:
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
prompt_token_ids += text_input_token + speech_input_token
# reset the last cache
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
else:
logging.info('not enough text token to decode, wait for more')
break
if len(llm_prompt_speech_token) == 0:
if len(prompt_speech_token) == 0:
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
logging.info('get fill token, need to append more text token')
if len(text_tokens_cache) >= self.mix_ratio[0]: