add vllm example

This commit is contained in:
lyuxiang.lx
2025-05-30 10:06:14 +00:00
parent 1f50ae259b
commit a8e1774e82
2 changed files with 22 additions and 13 deletions

View File

@@ -172,19 +172,7 @@ Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torc
conda create -n cosyvoice_vllm --clone cosyvoice conda create -n cosyvoice_vllm --clone cosyvoice
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
conda activate cosyvoice_vllm conda activate cosyvoice_vllm
``` python vllm_example.py
Remember to register `CosyVoice2ForCausalLM` for vllm inference at the start of the code.
```python
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice2
from vllm import ModelRegistry
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=True, fp16=False)
``` ```
#### CosyVoice Usage #### CosyVoice Usage

21
vllm_example.py Normal file
View File

@@ -0,0 +1,21 @@
import sys
sys.path.append('third_party/Matcha-TTS')
from vllm import ModelRegistry
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from cosyvoice.utils.common import set_all_random_seed
from tqdm import tqdm
def main():
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i in tqdm(range(100)):
set_all_random_seed(i)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
continue
if __name__=='__main__':
main()