mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add vllm_codec_engine
This commit is contained in:
@@ -49,6 +49,7 @@ class CosyVoice:
|
|||||||
self.model.load('{}/llm.pt'.format(model_dir),
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
'{}/flow.pt'.format(model_dir),
|
'{}/flow.pt'.format(model_dir),
|
||||||
'{}/hift.pt'.format(model_dir))
|
'{}/hift.pt'.format(model_dir))
|
||||||
|
self.vllm_codec_engine = None
|
||||||
if load_jit:
|
if load_jit:
|
||||||
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
@@ -149,8 +150,16 @@ class CosyVoice2(CosyVoice):
|
|||||||
self.model.load('{}/llm.pt'.format(model_dir),
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
'{}/flow.pt'.format(model_dir),
|
'{}/flow.pt'.format(model_dir),
|
||||||
'{}/hift.pt'.format(model_dir))
|
'{}/hift.pt'.format(model_dir))
|
||||||
|
self.vllm_codec_engine = None
|
||||||
if use_vllm:
|
if use_vllm:
|
||||||
|
from vllm import EngineArgs, LLMEngine
|
||||||
self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
|
self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
|
||||||
|
engine_args = EngineArgs(model=''.join([model_dir, '/codec_vllm_model']),
|
||||||
|
skip_tokenizer_init=True,
|
||||||
|
gpu_memory_utilization=0.1)
|
||||||
|
self.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
self.model.llm.vllm_codec_engine = self.vllm_codec_engine
|
||||||
|
|
||||||
if load_jit:
|
if load_jit:
|
||||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
if load_trt:
|
if load_trt:
|
||||||
|
|||||||
@@ -282,6 +282,7 @@ class Qwen2LM(TransformerLM):
|
|||||||
# 4. sampling method
|
# 4. sampling method
|
||||||
self.sampling = sampling
|
self.sampling = sampling
|
||||||
self.mix_ratio = mix_ratio
|
self.mix_ratio = mix_ratio
|
||||||
|
self.vllm_codec_engine = None
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference(
|
def inference(
|
||||||
|
|||||||
Reference in New Issue
Block a user