Add vllm_codec_engine (optional vLLM-backed codec engine for CosyVoice2)

This commit is contained in:
雾聪
2025-02-25 17:43:33 +08:00
parent c37c00ff94
commit 4df0683a37
2 changed files with 10 additions and 0 deletions

View File

@@ -49,6 +49,7 @@ class CosyVoice:
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
self.vllm_codec_engine = None
if load_jit:
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
@@ -149,8 +150,16 @@ class CosyVoice2(CosyVoice):
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
self.vllm_codec_engine = None
if use_vllm:
from vllm import EngineArgs, LLMEngine
self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
engine_args = EngineArgs(model=''.join([model_dir, '/codec_vllm_model']),
skip_tokenizer_init=True,
gpu_memory_utilization=0.1)
self.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
self.model.llm.vllm_codec_engine = self.vllm_codec_engine
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:

View File

@@ -282,6 +282,7 @@ class Qwen2LM(TransformerLM):
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
self.vllm_codec_engine = None
@torch.inference_mode()
def inference(