mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add cosyvoice3 vllm example
This commit is contained in:
@@ -88,6 +88,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
|||||||
logging.info("Succesfully convert onnx to trt...")
|
logging.info("Succesfully convert onnx to trt...")
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE do not support bistream inference as only speech token embedding/head is kept
|
||||||
def export_cosyvoice2_vllm(model, model_path, device):
|
def export_cosyvoice2_vllm(model, model_path, device):
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
return
|
return
|
||||||
@@ -98,11 +99,13 @@ def export_cosyvoice2_vllm(model, model_path, device):
|
|||||||
|
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
# lm_head
|
# lm_head
|
||||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
|
use_bias = True if model.llm_decoder.bias is not None else False
|
||||||
|
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=use_bias)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
|
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
|
||||||
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
|
|
||||||
new_lm_head.weight[vocab_size:] = 0
|
new_lm_head.weight[vocab_size:] = 0
|
||||||
|
if use_bias is True:
|
||||||
|
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
|
||||||
new_lm_head.bias[vocab_size:] = 0
|
new_lm_head.bias[vocab_size:] = 0
|
||||||
model.llm.model.lm_head = new_lm_head
|
model.llm.model.lm_head = new_lm_head
|
||||||
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
|
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
|
||||||
@@ -121,9 +124,12 @@ def export_cosyvoice2_vllm(model, model_path, device):
|
|||||||
del model.llm.model.config.eos_token_id
|
del model.llm.model.config.eos_token_id
|
||||||
model.llm.model.config.vocab_size = pad_vocab_size
|
model.llm.model.config.vocab_size = pad_vocab_size
|
||||||
model.llm.model.config.tie_word_embeddings = False
|
model.llm.model.config.tie_word_embeddings = False
|
||||||
model.llm.model.config.use_bias = True
|
model.llm.model.config.use_bias = use_bias
|
||||||
model.llm.model.save_pretrained(model_path)
|
model.llm.model.save_pretrained(model_path)
|
||||||
|
if use_bias is True:
|
||||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||||
|
else:
|
||||||
|
os.system('sed -i s@Qwen2ForCausalLM@Qwen2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||||
model.llm.model.config.vocab_size = tmp_vocab_size
|
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||||
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||||
model.llm.model.set_input_embeddings(embed_tokens)
|
model.llm.model.set_input_embeddings(embed_tokens)
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ def cosyvoice3_example():
|
|||||||
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# cosyvoice_example()
|
cosyvoice_example()
|
||||||
cosyvoice2_example()
|
cosyvoice2_example()
|
||||||
cosyvoice3_example()
|
cosyvoice3_example()
|
||||||
|
|
||||||
|
|||||||
@@ -4,20 +4,33 @@ from vllm import ModelRegistry
|
|||||||
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
|
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
|
||||||
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
|
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
|
||||||
|
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
from cosyvoice.cli.cosyvoice import CosyVoice2, CosyVoice3
|
||||||
from cosyvoice.utils.file_utils import load_wav
|
|
||||||
from cosyvoice.utils.common import set_all_random_seed
|
from cosyvoice.utils.common import set_all_random_seed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def cosyvoice2_example():
|
||||||
|
""" CosyVoice2 vllm usage
|
||||||
|
"""
|
||||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
|
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
|
||||||
|
for i in tqdm(range(100)):
|
||||||
|
set_all_random_seed(i)
|
||||||
|
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
def cosyvoice3_example():
|
||||||
|
""" CosyVoice3 vllm usage
|
||||||
|
"""
|
||||||
|
cosyvoice = CosyVoice3('pretrained_models/CosyVoice3-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
|
||||||
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
||||||
for i in tqdm(range(100)):
|
for i in tqdm(range(100)):
|
||||||
set_all_random_seed(i)
|
set_all_random_seed(i)
|
||||||
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cosyvoice2_example()
|
||||||
|
cosyvoice3_example()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user