mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
refactor(llm): 重构 VLLM 推理方式
- 新增基于队列和线程的异步推理机制 - 优化同步推理接口,使用新机制实现
This commit is contained in:
@@ -11,9 +11,10 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
import time
|
import time
|
||||||
|
import queue
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
from typing import List, Generator, AsyncGenerator
|
from typing import List, Generator, AsyncGenerator
|
||||||
import torch
|
import torch
|
||||||
from cosyvoice.utils.file_utils import logging
|
from cosyvoice.utils.file_utils import logging
|
||||||
@@ -41,6 +42,7 @@ ENGINE_ARGS = {
|
|||||||
"max_num_seqs": 256,
|
"max_num_seqs": 256,
|
||||||
"disable_log_requests": True,
|
"disable_log_requests": True,
|
||||||
"disable_log_stats": True,
|
"disable_log_stats": True,
|
||||||
|
"dtype": "float16"
|
||||||
}
|
}
|
||||||
|
|
||||||
from vllm.sampling_params import RequestOutputKind
|
from vllm.sampling_params import RequestOutputKind
|
||||||
@@ -84,13 +86,42 @@ class VllmQwen2LM(Qwen2LM):
|
|||||||
self.task_token_id = self.sos_eos_token_id + 1
|
self.task_token_id = self.sos_eos_token_id + 1
|
||||||
self.zero_token_id = self.task_token_id + 1
|
self.zero_token_id = self.task_token_id + 1
|
||||||
|
|
||||||
|
# 不能直接在同步函数正确的使用 异步的生成器函数,即使使用协程也会对vllm造成崩溃
|
||||||
|
# 使用 queue 的方式,后台线程运行推理任务
|
||||||
|
self.task_queue = queue.Queue()
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
|
||||||
|
self.loop_thread.start()
|
||||||
|
# 运行后台协程,用于处理任务队列中的任务
|
||||||
|
# TODO: 目前只能单任务运行,多任务运行需要对 inference_processor 进行修改
|
||||||
|
asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop)
|
||||||
|
|
||||||
|
def _run_event_loop(self):
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
self.loop.run_forever()
|
||||||
|
|
||||||
|
async def inference_processor(self, task_queue):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print(f"inference_processor")
|
||||||
|
out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get()
|
||||||
|
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
||||||
|
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
||||||
|
if max_tokens:
|
||||||
|
sampling_params.max_tokens = max_tokens
|
||||||
|
async for output in self.llm_engine.generate(
|
||||||
|
{
|
||||||
|
"prompt_token_ids": prompt_token_ids,
|
||||||
|
},
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
request_id=request_id or f"{time.time()}",
|
||||||
|
):
|
||||||
|
out_queue.put((output.outputs[0], output.finished))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in inference_processor: {e}")
|
||||||
|
|
||||||
async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
|
async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
|
||||||
-> AsyncGenerator[CompletionOutput, None]:
|
-> AsyncGenerator[CompletionOutput, None]:
|
||||||
assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
|
|
||||||
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
|
|
||||||
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
|
|
||||||
# logging.debug('prompt_token_ids:', prompt_token_ids)
|
|
||||||
# TODO: 增加上下文控制,取消请求时
|
|
||||||
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
||||||
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
||||||
if max_tokens:
|
if max_tokens:
|
||||||
@@ -104,49 +135,16 @@ class VllmQwen2LM(Qwen2LM):
|
|||||||
):
|
):
|
||||||
yield output.outputs[0]
|
yield output.outputs[0]
|
||||||
|
|
||||||
|
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
|
||||||
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
|
# 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务
|
||||||
-> Generator[CompletionOutput, None, None]:
|
# 提交推理任务到队列中
|
||||||
assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
|
out_queue = queue.Queue()
|
||||||
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
|
self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens))
|
||||||
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
|
# 将 out_queue 的结果返回
|
||||||
# logging.debug('prompt_token_ids:', prompt_token_ids)
|
finished = False
|
||||||
# TODO: 增加上下文控制,取消请求时
|
while not finished:
|
||||||
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
(output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
|
||||||
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
yield output
|
||||||
if max_tokens:
|
|
||||||
sampling_params.max_tokens = max_tokens
|
|
||||||
|
|
||||||
# 创建独立事件循环
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
# 初始化异步生成器
|
|
||||||
async_gen = self.llm_engine.generate(
|
|
||||||
{
|
|
||||||
"prompt_token_ids": prompt_token_ids,
|
|
||||||
},
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
request_id=request_id or f"{time.time()}",
|
|
||||||
)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# 同步获取异步结果
|
|
||||||
output = loop.run_until_complete(async_gen.__anext__())
|
|
||||||
yield output.outputs[0]
|
|
||||||
except StopAsyncIteration:
|
|
||||||
break
|
|
||||||
except GeneratorExit:
|
|
||||||
if async_gen is not None:
|
|
||||||
loop.run_until_complete(async_gen.aclose())
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# 资源清理
|
|
||||||
print("资源清理...")
|
|
||||||
if async_gen is not None:
|
|
||||||
loop.run_until_complete(async_gen.aclose())
|
|
||||||
loop.close()
|
|
||||||
print("资源清理成功")
|
|
||||||
|
|
||||||
def inference(
|
def inference(
|
||||||
self,
|
self,
|
||||||
@@ -194,6 +192,9 @@ class VllmQwen2LM(Qwen2LM):
|
|||||||
max_token_text_ratio: float = 20,
|
max_token_text_ratio: float = 20,
|
||||||
min_token_text_ratio: float = 2,
|
min_token_text_ratio: float = 2,
|
||||||
) -> Generator[torch.Tensor, None, None]:
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
|
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
||||||
|
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
||||||
|
|
||||||
last_tokens = []
|
last_tokens = []
|
||||||
prompt_token_ids = [self.sos_eos_token_id]
|
prompt_token_ids = [self.sos_eos_token_id]
|
||||||
text_tokens_cache = prompt_text
|
text_tokens_cache = prompt_text
|
||||||
@@ -202,18 +203,18 @@ class VllmQwen2LM(Qwen2LM):
|
|||||||
# text need tokens
|
# text need tokens
|
||||||
assert isinstance(this_text, list), "text need token ids List[int]."
|
assert isinstance(this_text, list), "text need token ids List[int]."
|
||||||
text_tokens_cache += this_text
|
text_tokens_cache += this_text
|
||||||
while len(llm_prompt_speech_token) != 0:
|
while len(prompt_speech_token) != 0:
|
||||||
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
||||||
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
|
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
|
||||||
speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
|
speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
|
||||||
prompt_token_ids += text_input_token + speech_input_token
|
prompt_token_ids += text_input_token + speech_input_token
|
||||||
# reset the last cache
|
# reset the last cache
|
||||||
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
||||||
llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
|
prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
|
||||||
else:
|
else:
|
||||||
logging.info('not enough text token to decode, wait for more')
|
logging.info('not enough text token to decode, wait for more')
|
||||||
break
|
break
|
||||||
if len(llm_prompt_speech_token) == 0:
|
if len(prompt_speech_token) == 0:
|
||||||
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
|
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
|
||||||
logging.info('get fill token, need to append more text token')
|
logging.info('get fill token, need to append more text token')
|
||||||
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
||||||
|
|||||||
Reference in New Issue
Block a user