From 96950745a6a2b9b9c5952474456930dd4c8063e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E8=81=AA?=
Date: Fri, 21 Mar 2025 16:17:35 +0800
Subject: [PATCH] Revert "mv AsyncLLMEngine init to CosyVoice2"

This reverts commit 9b3f35149620681af225c3a61e614f307ac5aacd.
---
 cosyvoice/cli/cosyvoice.py | 22 ----------------------
 cosyvoice/llm/llm_vllm.py  | 22 +++++++++++++++++++++-
 2 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index b9f2392..39464ca 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -166,29 +166,7 @@ class CosyVoice2(CosyVoice):
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         if use_vllm:
             try:
-                os.environ["VLLM_USE_V1"] = '1'
-                from vllm import AsyncLLMEngine
-                from vllm.engine.arg_utils import AsyncEngineArgs
-                # EngineArgs
-                ENGINE_ARGS = {
-                    "block_size": 16,
-                    "swap_space": 0,
-                    # "enforce_eager": True,
-                    "gpu_memory_utilization": 0.4,
-                    "max_num_batched_tokens": 1024,
-                    "max_model_len": 1024,
-                    "max_num_seqs": 256,
-                    "disable_log_requests": True,
-                    "disable_log_stats": True,
-                    "dtype": "bfloat16"
-                }
                 self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
-                engine_args = AsyncEngineArgs(
-                    model=model_dir,
-                    **ENGINE_ARGS,
-                )
-                self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
-                self.model.llm_engine = self.llm_engine
             except Exception as e:
                 logging.warning(f'use vllm inference failed. \n{e}')
                 raise e
diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py
index 4f6699f..a864a04 100644
--- a/cosyvoice/llm/llm_vllm.py
+++ b/cosyvoice/llm/llm_vllm.py
@@ -31,6 +31,20 @@ from vllm.sampling_params import SamplingParams
 from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
 ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
 
+# EngineArgs
+ENGINE_ARGS = {
+    "block_size": 16,
+    "swap_space": 0,
+    # "enforce_eager": True,
+    "gpu_memory_utilization": 0.4,
+    "max_num_batched_tokens": 1024,
+    "max_model_len": 1024,
+    "max_num_seqs": 256,
+    "disable_log_requests": True,
+    "disable_log_stats": True,
+    "dtype": "float16"
+}
+
 from vllm.sampling_params import RequestOutputKind
 # SamplingParams
 SAMPLING_PARAMS = {
@@ -58,7 +72,13 @@ class VllmQwen2LM(Qwen2LM):
         self.fp16 = False
         self.half = lambda: None
         self.mix_ratio = mix_ratio
-        self.llm_engine = None
+        # ---------------------------------------------
+        # vLLM engine argument configuration
+        engine_args = AsyncEngineArgs(
+            model=model_dir,
+            **ENGINE_ARGS,
+        )
+        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
         self.speech_token_size = 6564  # 6561 + 3
         self.llm_token_size = 151936  # llm vocab_size
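
A minimal usage sketch of the engine this revert moves back into VllmQwen2LM.__init__, for reviewers who want to exercise it in isolation. It uses only vLLM's public async API (AsyncEngineArgs, AsyncLLMEngine, TokensPrompt); the model path, prompt token IDs, and sampling values below are placeholders, and the repository's real SAMPLING_PARAMS dict lies outside the hunks shown above, so none of these values are the project's actual configuration.

    # Sketch only: placeholder model dir, prompt ids, and sampling values.
    import asyncio

    from vllm import AsyncLLMEngine, SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.inputs import TokensPrompt

    # Mirrors the ENGINE_ARGS dict restored to llm_vllm.py in this patch.
    ENGINE_ARGS = {
        "block_size": 16,
        "swap_space": 0,
        "gpu_memory_utilization": 0.4,
        "max_num_batched_tokens": 1024,
        "max_model_len": 1024,
        "max_num_seqs": 256,
        "disable_log_requests": True,
        "disable_log_stats": True,
        "dtype": "float16",
    }


    async def main(model_dir: str) -> None:
        # Same construction as the restored __init__ code: kwargs ->
        # AsyncEngineArgs -> engine. from_engine_args() loads the model
        # and starts the async scheduling loop.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=model_dir, **ENGINE_ARGS))

        # Placeholder sampling values; the real ones live in the
        # SAMPLING_PARAMS dict in llm_vllm.py, not shown in this patch.
        sampling = SamplingParams(temperature=1.0, top_p=0.8, top_k=25,
                                  max_tokens=512)

        # generate() is an async generator keyed by request_id; it yields
        # the request's cumulative RequestOutput until finished is True.
        async for output in engine.generate(
            TokensPrompt(prompt_token_ids=[1, 2, 3]),  # placeholder ids
            sampling,
            request_id="demo-0",
        ):
            if output.finished:
                print(output.outputs[0].token_ids)


    asyncio.run(main("pretrained_models/CosyVoice2-0.5B"))  # placeholder path

Run against a local CosyVoice2 checkpoint directory; because each iteration of engine.generate() yields the request's cumulative RequestOutput, a streaming consumer can read output.outputs[0].token_ids incrementally instead of waiting for the request to finish.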