From 90b666ea2060ced8e69ce1c6d43ec695d1c9f2f1 Mon Sep 17 00:00:00 2001 From: qihua Date: Fri, 7 Mar 2025 20:26:19 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=90=88=E5=B9=B6vllm?= =?UTF-8?q?=E6=94=AF=E6=8C=81=EF=BC=8C=E5=BC=82=E6=AD=A5=E6=8E=A8=E7=90=86?= =?UTF-8?q?=E7=9A=84=E9=80=9A=E9=81=93=E5=A4=84=E7=90=86=E8=BF=98=E5=AD=98?= =?UTF-8?q?=E5=9C=A8bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosyvoice/cli/cosyvoice.py | 43 +++- cosyvoice/cli/frontend.py | 89 ++++++- cosyvoice/cli/model.py | 23 ++ cosyvoice/llm/llm_vllm.py | 248 +++++++++++++++++++ cosyvoice/llm/vllm_use_cosyvoice2_model.py | 263 +++++++++++++++++++++ 5 files changed, 658 insertions(+), 8 deletions(-) create mode 100644 cosyvoice/llm/llm_vllm.py create mode 100644 cosyvoice/llm/vllm_use_cosyvoice2_model.py diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index e2d62e2..39464ca 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml from modelscope import snapshot_download import torch from cosyvoice.cli.frontend import CosyVoiceFrontEnd -from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model +from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model from cosyvoice.utils.file_utils import logging from cosyvoice.utils.class_utils import get_model_type @@ -63,6 +63,9 @@ class CosyVoice: spks = list(self.frontend.spk2info.keys()) return spks + def add_spk_info(self, spk_id, spk_info): + self.frontend.add_spk_info(spk_id, spk_info) + def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): model_input = self.frontend.frontend_sft(i, spk_id) @@ -88,6 +91,22 @@ class CosyVoice: yield model_output start_time = time.time() + def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): + """使用预定义的说话人执行 zero_shot 推理""" + for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): + model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id) + start_time = time.time() + last_time = start_time + chunk_index = 0 + logging.info('synthesis text {}'.format(i)) + for model_output in self.model.tts(**model_input, stream=stream, speed=speed): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format( + chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time)) + yield model_output + last_time = time.time() + chunk_index += 1 + def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate) @@ -126,7 +145,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 @@ -145,7 +164,14 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) + if use_vllm: + try: + self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16) + except Exception as e: + logging.warning(f'use vllm inference failed. \n{e}') + raise e + else: + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) @@ -171,3 +197,14 @@ class CosyVoice2(CosyVoice): logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) yield model_output start_time = time.time() + + def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True): + for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): + model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) + for model_output in self.model.tts(**model_input, stream=stream, speed=speed): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 6e10f00..5aa2d34 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Generator +from typing import Generator, Optional import json import onnxruntime import torch @@ -24,6 +24,8 @@ import torchaudio import os import re import inflect +from pydantic import BaseModel, ConfigDict + try: import ttsfrd use_ttsfrd = True @@ -36,6 +38,18 @@ from cosyvoice.utils.file_utils import logging from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation +class SpeakerInfo(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: Optional[str] = None + spk_id: str + prompt_text: str + prompt_text_token: torch.Tensor + speech_feat: torch.Tensor + speech_token: torch.Tensor + embedding: torch.Tensor + + class CosyVoiceFrontEnd: def __init__(self, @@ -55,8 +69,9 @@ class CosyVoiceFrontEnd: self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]) + self.spk2info_path = spk2info if os.path.exists(spk2info): - self.spk2info = torch.load(spk2info, map_location=self.device) + self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False) else: self.spk2info = {} self.allowed_special = allowed_special @@ -68,7 +83,8 @@ class CosyVoiceFrontEnd: 'failed to initialize ttsfrd resource' self.frd.set_lang_type('pinyinvg') else: - self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + # self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False) self.en_tn_model = EnNormalizer() self.inflect_parser = inflect.engine() @@ -86,8 +102,9 @@ class CosyVoiceFrontEnd: def _extract_text_token_generator(self, text_generator): for text in text_generator: text_token, _ = self._extract_text_token(text) - for i in range(text_token.shape[1]): - yield text_token[:, i: i + 1] + # for i in range(text_token.shape[1]): + # yield text_token[:, i: i + 1] + yield text_token def _extract_speech_token(self, speech): assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' @@ -138,11 +155,15 @@ class CosyVoiceFrontEnd: text = text.replace(" - ", ",") text = remove_bracket(text) text = re.sub(r'[,,、]+$', '。', text) + if not split: + return text texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False)) else: text = self.en_tn_model.normalize(text) text = spell_out_number(text, self.inflect_parser) + if not split: + return text texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False)) texts = [i for i in texts if not is_only_punctuation(i)] @@ -151,6 +172,7 @@ class CosyVoiceFrontEnd: def frontend_sft(self, tts_text, spk_id): tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) embedding = self.spk2info[spk_id]['embedding'] + assert embedding is not None model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding} return model_input @@ -209,3 +231,60 @@ class CosyVoiceFrontEnd: 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len, 'flow_embedding': embedding} return model_input + + def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate:int=24000, name: str=None): + assert isinstance(spk_id, str) + assert spk_id not in self.spk2info, "spk_id already exists" + prompt_text_token, _ = self._extract_text_token(prompt_text) + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) + speech_feat, _ = self._extract_speech_feat(prompt_speech_resample) + speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k) + if resample_rate == 24000: + # cosyvoice2, force speech_feat % speech_token = 2 + token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1]) + speech_feat = speech_feat[:, :2 * token_len] + speech_token = speech_token[:, :token_len] + embedding = self._extract_spk_embedding(prompt_speech_16k) + spk_info = SpeakerInfo( + name=name, + spk_id=spk_id, + prompt_text=prompt_text, + prompt_text_token=prompt_text_token, + speech_feat=speech_feat, + speech_token=speech_token, + embedding=embedding, + ) + self.add_spk_info(spk_id, spk_info) + + def add_spk_info(self, spk_id: str, spk_info: dict|SpeakerInfo): + if isinstance(spk_info, BaseModel): + spk_info = spk_info.model_dump() + self.spk2info[spk_id] = spk_info + if self.spk2info_path: + torch.save(self.spk2info, self.spk2info_path) + + def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id): + assert spk_id in self.spk2info + tts_text_token, _ = self._extract_text_token(tts_text) + prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>') + model_input = {'text': tts_text_token, + 'prompt_text': prompt_text_token, + 'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'], + 'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'], + 'llm_embedding': self.spk2info[spk_id]['embedding'], + 'flow_embedding': self.spk2info[spk_id]['embedding'], + } + return model_input + + def frontend_zero_shot_by_spk_id(self, tts_text, spk_id): + assert spk_id in self.spk2info + tts_text_token, _ = self._extract_text_token(tts_text) + model_input = {'text': tts_text_token, + 'prompt_text': self.spk2info[spk_id]['prompt_text_token'], + 'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'], + 'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'], + 'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'], + 'llm_embedding': self.spk2info[spk_id]['embedding'], + 'flow_embedding': self.spk2info[spk_id]['embedding'] + } + return model_input \ No newline at end of file diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 9ebf8cb..c0d25ba 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -409,3 +409,26 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) torch.cuda.empty_cache() + + +class VllmCosyVoice2Model(CosyVoice2Model): + def __init__(self, + model_dir: str, + flow: torch.nn.Module, + hift: torch.nn.Module, + fp16: bool): + try: + from cosyvoice.llm.llm_vllm import VllmQwen2LM + except Exception as e: + raise e + llm = VllmQwen2LM(model_dir) + super().__init__(llm,flow,hift,fp16) + + def load(self, llm_model, flow_model, hift_model): + self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True) + self.flow.to(self.device).eval() + # in case hift_model is a hifigan model + hift_state_dict = {k.replace('generator.', ''): v for k, v in + torch.load(hift_model, weights_only=True, map_location=self.device).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py new file mode 100644 index 0000000..c43c53a --- /dev/null +++ b/cosyvoice/llm/llm_vllm.py @@ -0,0 +1,248 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import contextlib +import time +from typing import List, Generator, AsyncGenerator +import torch +from cosyvoice.utils.file_utils import logging +from cosyvoice.llm.llm import Qwen2LM + +# 启用vllm V1版本 +import os +os.environ["VLLM_USE_V1"] = '1' +from vllm import ModelRegistry +from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput +from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.sampling_params import SamplingParams + +from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM +ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM) + +# EngineArgs +ENGINE_ARGS = { + "block_size": 16, + "swap_space": 0, + # "enforce_eager": True, + "gpu_memory_utilization": 0.4, + "max_num_batched_tokens": 1024, + "max_model_len": 1024, + "max_num_seqs": 256, + "disable_log_requests": True, + "disable_log_stats": True, +} + +from vllm.sampling_params import RequestOutputKind +# SamplingParams +SAMPLING_PARAMS = { + "temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token + "top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token + "top_k": 25, + # "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动 + # "presence_penalty": 1.0, # 不支持设置 + # "frequency_penalty": 0.0, # 不支持设置 + "max_tokens": 1024, + "detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算 + "ignore_eos": False, + "output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码 +} + +def tensor_to_list(tensor: torch.tensor): + return tensor.view(-1).cpu().numpy().tolist() + +class VllmQwen2LM(Qwen2LM): + def __init__( + self, + model_dir, + mix_ratio: List[int] = [5, 15], + ): + self.fp16 = False + self.half = lambda: None + self.mix_ratio = mix_ratio + # --------------------------------------------- + # vllm engine 的参数配置 + engine_args = AsyncEngineArgs( + model=model_dir, + **ENGINE_ARGS, + ) + self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args) + + self.speech_token_size = 6564 # 6561 + 3 + self.llm_token_size = 151936 # llm vocab_size + self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1 + self.task_token_id = self.sos_eos_token_id + 1 + self.zero_token_id = self.task_token_id + 1 + + async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ + -> AsyncGenerator[CompletionOutput, None]: + assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" + invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) + assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" + # logging.debug('prompt_token_ids:', prompt_token_ids) + # TODO: 增加上下文控制,取消请求时 + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + async for output in self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ): + yield output.outputs[0] + + + def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ + -> Generator[CompletionOutput, None, None]: + assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" + invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) + assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" + # logging.debug('prompt_token_ids:', prompt_token_ids) + # TODO: 增加上下文控制,取消请求时 + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + + # 创建独立事件循环 + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + # 初始化异步生成器 + async_gen = self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ) + while True: + try: + # 同步获取异步结果 + output = loop.run_until_complete(async_gen.__anext__()) + yield output.outputs[0] + except StopAsyncIteration: + break + except GeneratorExit: + if async_gen is not None: + loop.run_until_complete(async_gen.aclose()) + raise + finally: + # 资源清理 + print("资源清理...") + if async_gen is not None: + loop.run_until_complete(async_gen.aclose()) + loop.close() + print("资源清理成功") + + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor|int, None, None]: + prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) + prompt_speech_token = tensor_to_list(prompt_speech_token) + + text = tensor_to_list(text + torch.tensor(6564)) + prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \ + [self.task_token_id] + prompt_speech_token + max_tokens = len(text) * 20 + for output in self.llm_inference( + prompt_token_ids, + stop_token_ids=[6561], + max_tokens=max_tokens, + ): + if output.token_ids[-1] == 6561: + need_add_tokens = output.token_ids[:-1] + else: + need_add_tokens = output.token_ids + # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 + # yield need_add_tokens + for token in need_add_tokens: + yield token + + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + last_tokens = [] + prompt_token_ids = [self.sos_eos_token_id] + text_tokens_cache = prompt_text + for this_text in text: + this_text = tensor_to_list(this_text + torch.tensor(6564)) + # text need tokens + assert isinstance(this_text, list), "text need token ids List[int]." + text_tokens_cache += this_text + while len(llm_prompt_speech_token) != 0: + if len(text_tokens_cache) >= self.mix_ratio[0]: + text_input_token = text_tokens_cache[:self.mix_ratio[0]] + speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]] + prompt_token_ids += text_input_token + speech_input_token + # reset the last cache + text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] + llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:] + else: + logging.info('not enough text token to decode, wait for more') + break + if len(llm_prompt_speech_token) == 0: + if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: + logging.info('get fill token, need to append more text token') + if len(text_tokens_cache) >= self.mix_ratio[0]: + text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]] + prompt_token_ids += text_tokens_temp + logging.info('append {} text token'.format(len(text_tokens_temp))) + text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] + else: + logging.info('not enough text token to decode, wait for more') + continue + for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]): + last_tokens = output.token_ids + if last_tokens[-1] == 6563: + need_add_tokens = last_tokens[:-1] + else: + need_add_tokens = last_tokens + # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 + # yield need_add_tokens + for token in need_add_tokens: + yield token + prompt_token_ids.extend(need_add_tokens) + prompt_token_ids += text_tokens_cache + [self.task_token_id] + logging.info('no more text token, decode until met eos') + for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]): + if output.token_ids[-1] == 6561: + need_add_tokens = output.token_ids[:-1] + else: + need_add_tokens = output.token_ids + # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 + # yield need_add_tokens + for token in need_add_tokens: + yield token diff --git a/cosyvoice/llm/vllm_use_cosyvoice2_model.py b/cosyvoice/llm/vllm_use_cosyvoice2_model.py new file mode 100644 index 0000000..6e36ef3 --- /dev/null +++ b/cosyvoice/llm/vllm_use_cosyvoice2_model.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any +from typing_extensions import TypeVar + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import T +from vllm.model_executor.models.qwen2 import Qwen2Model + +from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings + +logger = init_logger(__name__) + +IGNORE_ID = -1 + + +class CosyVoice2Model(nn.Module): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.llm_input_size = 896 + self.llm_output_size = 896 + + self.speech_token_size = 6561+3 + self.llm_token_size = config.vocab_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + + self.allow_patterns_overrides = ["llm.*"] + self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size) + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size) + self.llm_decoder = ParallelLMHead(self.speech_token_size, + self.llm_output_size, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "llm_decoder")) + self.logits_processor = LogitsProcessor(self.speech_token_size) + + # length_normalized_loss: bool = True, + # lsm_weight: float = 0.0, + # self.criterion_ce = LabelSmoothingLoss( + # size=self.speech_token_size, + # padding_idx=IGNORE_ID, + # smoothing=lsm_weight, + # normalize_length=length_normalized_loss, + # ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size) + + # 4. sampling method + ## use vllm sampling method + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.mix_ratio: List[int] = [5, 15] + + # 定义特殊token常量 + self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32) + self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404 + self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405 + self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32) + + self.zero_embed_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size), + dtype=self.llm_embedding.weight.dtype, + device=self.llm_embedding.weight.device + ) + self.inputs_embed_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size), + dtype=self.llm_embedding.weight.dtype, + device=self.llm_embedding.weight.device, + ) + + def get_sos_eos_emb(self): + return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + + def get_task_id_emb(self): + return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> torch.Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ + # 创建掩码,标记哪些 token_id 属于音频 Token + mask = input_ids < self.speech_token_size + + # 获取 input_ids 的原始形状 + input_shape = input_ids.shape + # 展平 input_ids 和掩码以便统一处理 + flat_input_ids = input_ids.view(-1) + flat_mask = mask.view(-1) + + inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]] + inputs_embeds.zero_() + + # Process speech tokens + if flat_mask.any(): + speech_token_ids = flat_input_ids[flat_mask] + inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids) + + # 处理大于 delta 的 token_id + if (~flat_mask).any(): + llm_token_ids = flat_input_ids[~flat_mask] + llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask]) + + sos_eos_mask = llm_token_ids == self.sos_eos_token_id + task_mask = llm_token_ids == self.task_token_id + zero_mask = llm_token_ids == self.zero_token_id + normal_mask = ~(sos_eos_mask | task_mask | zero_mask) + + # 分层处理逻辑 + # 第一优先级:SOS/EOS标记 + if sos_eos_mask.any(): + llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0) + + # 第二优先级:任务标记 + if task_mask.any(): + llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0) + + # 第二优先级:空音频标记 + if zero_mask.any(): + llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])] + + # 常规LLM token + if normal_mask.any(): + original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta + # print('original_ids: ',original_ids) + llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids) + + inputs_embeds[~flat_mask] = llm_embeds + + inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size) + + # 合并多模态嵌入(如果有) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.audio_token_index + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings( + input_ids, + attn_metadata=attn_metadata, + ) + return self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.llm_decoder, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + @staticmethod + def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]: + for name, param in weights: + # 处理Qwen2Model核心参数 + if name.startswith("llm."): + if name.startswith("llm.model.model."): + name = name.replace("llm.model.model.", "model.") + else: + continue + # print('weights name: ', name) + yield name, param + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.convert_weights(weights) + loader = AutoWeightsLoader(self) + loader.load_weights(weights) From d4d187bd8c8b96763ac64f6d1a7b4a4b9a2c5392 Mon Sep 17 00:00:00 2001 From: qihua Date: Fri, 7 Mar 2025 23:53:50 +0800 Subject: [PATCH 2/7] =?UTF-8?q?refactor(llm):=20=E9=87=8D=E6=9E=84=20VLLM?= =?UTF-8?q?=20=E6=8E=A8=E7=90=86=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增基于队列和线程的异步推理机制 - 优化同步推理接口,使用新机制实现 --- cosyvoice/llm/llm_vllm.py | 109 +++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py index c43c53a..61b1090 100644 --- a/cosyvoice/llm/llm_vllm.py +++ b/cosyvoice/llm/llm_vllm.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import contextlib import time +import queue +import asyncio +import threading from typing import List, Generator, AsyncGenerator import torch from cosyvoice.utils.file_utils import logging @@ -41,6 +42,7 @@ ENGINE_ARGS = { "max_num_seqs": 256, "disable_log_requests": True, "disable_log_stats": True, + "dtype": "float16" } from vllm.sampling_params import RequestOutputKind @@ -84,13 +86,42 @@ class VllmQwen2LM(Qwen2LM): self.task_token_id = self.sos_eos_token_id + 1 self.zero_token_id = self.task_token_id + 1 + # 不能直接在同步函数正确的使用 异步的生成器函数,即使使用协程也会对vllm造成崩溃 + # 使用 queue 的方式,后台线程运行推理任务 + self.task_queue = queue.Queue() + self.loop = asyncio.new_event_loop() + self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True) + self.loop_thread.start() + # 运行后台协程,用于处理任务队列中的任务 + # TODO: 目前只能单任务运行,多任务运行需要对 inference_processor 进行修改 + asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop) + + def _run_event_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + async def inference_processor(self, task_queue): + while True: + try: + print(f"inference_processor") + out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get() + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + async for output in self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ): + out_queue.put((output.outputs[0], output.finished)) + except Exception as e: + logging.error(f"Error in inference_processor: {e}") + async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ -> AsyncGenerator[CompletionOutput, None]: - assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" - invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) - assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" - # logging.debug('prompt_token_ids:', prompt_token_ids) - # TODO: 增加上下文控制,取消请求时 sampling_params = SamplingParams(**SAMPLING_PARAMS) sampling_params.stop_token_ids = stop_token_ids or [6561] if max_tokens: @@ -104,49 +135,16 @@ class VllmQwen2LM(Qwen2LM): ): yield output.outputs[0] - - def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ - -> Generator[CompletionOutput, None, None]: - assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" - invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) - assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" - # logging.debug('prompt_token_ids:', prompt_token_ids) - # TODO: 增加上下文控制,取消请求时 - sampling_params = SamplingParams(**SAMPLING_PARAMS) - sampling_params.stop_token_ids = stop_token_ids or [6561] - if max_tokens: - sampling_params.max_tokens = max_tokens - - # 创建独立事件循环 - loop = asyncio.new_event_loop() - try: - asyncio.set_event_loop(loop) - # 初始化异步生成器 - async_gen = self.llm_engine.generate( - { - "prompt_token_ids": prompt_token_ids, - }, - sampling_params=sampling_params, - request_id=request_id or f"{time.time()}", - ) - while True: - try: - # 同步获取异步结果 - output = loop.run_until_complete(async_gen.__anext__()) - yield output.outputs[0] - except StopAsyncIteration: - break - except GeneratorExit: - if async_gen is not None: - loop.run_until_complete(async_gen.aclose()) - raise - finally: - # 资源清理 - print("资源清理...") - if async_gen is not None: - loop.run_until_complete(async_gen.aclose()) - loop.close() - print("资源清理成功") + def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None): + # 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务 + # 提交推理任务到队列中 + out_queue = queue.Queue() + self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens)) + # 将 out_queue 的结果返回 + finished = False + while not finished: + (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get() + yield output def inference( self, @@ -194,6 +192,9 @@ class VllmQwen2LM(Qwen2LM): max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, ) -> Generator[torch.Tensor, None, None]: + prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) + prompt_speech_token = tensor_to_list(prompt_speech_token) + last_tokens = [] prompt_token_ids = [self.sos_eos_token_id] text_tokens_cache = prompt_text @@ -202,18 +203,18 @@ class VllmQwen2LM(Qwen2LM): # text need tokens assert isinstance(this_text, list), "text need token ids List[int]." text_tokens_cache += this_text - while len(llm_prompt_speech_token) != 0: + while len(prompt_speech_token) != 0: if len(text_tokens_cache) >= self.mix_ratio[0]: text_input_token = text_tokens_cache[:self.mix_ratio[0]] - speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]] + speech_input_token = prompt_speech_token[:self.mix_ratio[1]] prompt_token_ids += text_input_token + speech_input_token # reset the last cache text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] - llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:] + prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:] else: logging.info('not enough text token to decode, wait for more') break - if len(llm_prompt_speech_token) == 0: + if len(prompt_speech_token) == 0: if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: logging.info('get fill token, need to append more text token') if len(text_tokens_cache) >= self.mix_ratio[0]: From 2fbeba50ae077cc0082eb9e9ebd8dc66eebb5df9 Mon Sep 17 00:00:00 2001 From: qihua Date: Sat, 8 Mar 2025 00:04:01 +0800 Subject: [PATCH 3/7] =?UTF-8?q?refactor(llm):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E6=9C=AA=E4=BD=BF=E7=94=A8=E7=9A=84=E5=BC=82=E6=AD=A5=E6=8E=A8?= =?UTF-8?q?=E7=90=86=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除了 LLM 类中的 async_llm_inference 方法 - 该方法尚未使用,且再在loop_thread之外运行后会导致 vllm 崩溃,因此将其移除 --- cosyvoice/llm/llm_vllm.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py index 61b1090..3fd7152 100644 --- a/cosyvoice/llm/llm_vllm.py +++ b/cosyvoice/llm/llm_vllm.py @@ -120,21 +120,6 @@ class VllmQwen2LM(Qwen2LM): except Exception as e: logging.error(f"Error in inference_processor: {e}") - async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ - -> AsyncGenerator[CompletionOutput, None]: - sampling_params = SamplingParams(**SAMPLING_PARAMS) - sampling_params.stop_token_ids = stop_token_ids or [6561] - if max_tokens: - sampling_params.max_tokens = max_tokens - async for output in self.llm_engine.generate( - { - "prompt_token_ids": prompt_token_ids, - }, - sampling_params=sampling_params, - request_id=request_id or f"{time.time()}", - ): - yield output.outputs[0] - def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None): # 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务 # 提交推理任务到队列中 From a1314e573a662d7284312c1ebfb99b5d651e46d6 Mon Sep 17 00:00:00 2001 From: qihua Date: Sat, 8 Mar 2025 00:40:17 +0800 Subject: [PATCH 4/7] =?UTF-8?q?chore:=20=E6=96=B0=E5=A2=9E=20requirements?= =?UTF-8?q?=5Fvllm.txt=20=E6=96=87=E4=BB=B6=EF=BC=8C=E6=8C=87=E5=AE=9AVLLM?= =?UTF-8?q?=20=E6=A8=A1=E5=9E=8B=E6=89=80=E9=9C=80=E7=9A=84=E4=BE=9D?= =?UTF-8?q?=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosyvoice/llm/llm_vllm.py | 2 +- requirements_vllm.txt | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 requirements_vllm.txt diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py index 3fd7152..1e9bc28 100644 --- a/cosyvoice/llm/llm_vllm.py +++ b/cosyvoice/llm/llm_vllm.py @@ -103,7 +103,7 @@ class VllmQwen2LM(Qwen2LM): async def inference_processor(self, task_queue): while True: try: - print(f"inference_processor") + logging.debug(f"inference_processor") out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get() sampling_params = SamplingParams(**SAMPLING_PARAMS) sampling_params.stop_token_ids = stop_token_ids or [6561] diff --git a/requirements_vllm.txt b/requirements_vllm.txt new file mode 100644 index 0000000..f3dcb25 --- /dev/null +++ b/requirements_vllm.txt @@ -0,0 +1,40 @@ +vllm==0.7.3 +pydantic==2.10.6 +torch==2.5.1 +torchaudio==2.5.1 + +conformer==0.3.2 + +diffusers==0.32.2 +gdown==5.1.0 +grpcio==1.57.0 +grpcio-tools==1.57.0 +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 + +lightning==2.5.0.post0 +matplotlib==3.7.5 +modelscope==1.15.0 + +networkx==3.4.2 +omegaconf==2.3.0 +onnx==1.17.0 + +onnxruntime-gpu==1.19.0; sys_platform == 'linux' + +#openai-whisper==20231117 +openai-whisper==20240930 +protobuf==4.25 +pyworld==0.3.4 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +wget==3.2 +WeTextProcessing==1.0.3 + +# trt use +tensorrt-cu12==10.0.1 +tensorrt-cu12-bindings==10.0.1 +tensorrt-cu12-libs==10.0.1 \ No newline at end of file From b4fe05d466a97995e1005f60d00115438d58296a Mon Sep 17 00:00:00 2001 From: qihua Date: Sat, 8 Mar 2025 00:41:34 +0800 Subject: [PATCH 5/7] =?UTF-8?q?docs:=20=E6=B7=BB=E5=8A=A0speed=5Ftest.ipyn?= =?UTF-8?q?b=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 speed_test.ipynb 文件,用于测试 CosyVoice2模型的性能 - 包含测试环境配置、默认情况下的使用示例、使用 vllm 加速 LLM 推理的步骤 --- speed_test.ipynb | 486 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 486 insertions(+) create mode 100644 speed_test.ipynb diff --git a/speed_test.ipynb b/speed_test.ipynb new file mode 100644 index 0000000..0444806 --- /dev/null +++ b/speed_test.ipynb @@ -0,0 +1,486 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 测试效果\n", + "\n", + "- 测试代码: [speed_test.ipynb](speed_test.ipynb)\n", + "- 测试环境: Intel i5-12400 CPU, 48GB RAM, 1x NVIDIA GeForce RTX 4070\n", + "- 运行环境: Ubuntu 24.04.1 LTS, cuda 12.4, python 3.10.16\n", + "- 测试说明: 单任务执行的数据(非并发测试)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 默认情况下使用" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import asyncio\n", + "import torchaudio\n", + "\n", + "import sys\n", + "sys.path.append('third_party/Matcha-TTS')\n", + "\n", + "from cosyvoice.cli.cosyvoice import CosyVoice2\n", + "from cosyvoice.utils.file_utils import load_wav\n", + "\n", + "prompt_text = '希望你以后能够做得比我还好哟'\n", + "prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)\n", + "\n", + "# cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=True)\n", + "cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, fp16=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 使用vllm加速llm推理\n", + "\n", + "#### 1. **安装依赖**\n", + "\n", + "(该依赖环境下可以运行原本cosyvoice2代码)\n", + "```bash\n", + "pip install -r requirements_vllm.txt\n", + "```\n", + "\n", + "#### 2. **文件复制**\n", + "将 pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN 文件夹下的部分文件复制到下载的CosyVoice2-0.5B模型文件夹下,并替换 config.json 文件中的 Qwen2ForCausalLM 为 CosyVoice2Model。\n", + "```bash\n", + "cp pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN/{config.json,tokenizer_config.json,vocab.json,merges.txt} pretrained_models/CosyVoice2-0.5B/\n", + "sed -i 's/Qwen2ForCausalLM/CosyVoice2Model/' pretrained_models/CosyVoice2-0.5B/config.json\n", + "```\n", + "\n", + "#### **注意:**\n", + "\n", + "- 使用 load_trt 后,需要进行 **预热** 10次推理以上,使用流式推理预热效果较好\n", + "- 在 jupyter notebook 中,如果要使用 **vllm** 运行下列代码,需要将vllm_use_cosyvoice2_model.py正确复制到 vllm 包中,并注册到 _VLLM_MODELS 字典中。运行下面的 code 完成" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "\n", + "# 获取vllm包的安装路径\n", + "try:\n", + " import vllm\n", + "except ImportError:\n", + " raise ImportError(\"vllm package not installed\")\n", + "\n", + "\n", + "vllm_path = os.path.dirname(vllm.__file__)\n", + "print(f\"vllm package path: {vllm_path}\")\n", + "\n", + "# 定义目标路径\n", + "target_dir = os.path.join(vllm_path, \"model_executor\", \"models\")\n", + "target_file = os.path.join(target_dir, \"cosyvoice2.py\")\n", + "\n", + "# 复制模型文件\n", + "source_file = \"./cosyvoice/llm/vllm_use_cosyvoice2_model.py\"\n", + "if not os.path.exists(source_file):\n", + " raise FileNotFoundError(f\"Source file {source_file} not found\")\n", + "\n", + "shutil.copy(source_file, target_file)\n", + "print(f\"Copied {source_file} to {target_file}\")\n", + "\n", + "# 修改registry.py文件\n", + "registry_path = os.path.join(target_dir, \"registry.py\")\n", + "new_entry = ' \"CosyVoice2Model\": (\"cosyvoice2\", \"CosyVoice2Model\"), # noqa: E501\\n'\n", + "\n", + "# 读取并修改文件内容\n", + "with open(registry_path, \"r\") as f:\n", + " lines = f.readlines()\n", + "\n", + "# 检查是否已存在条目\n", + "entry_exists = any(\"CosyVoice2Model\" in line for line in lines)\n", + "\n", + "if not entry_exists:\n", + " # 寻找插入位置\n", + " insert_pos = None\n", + " for i, line in enumerate(lines):\n", + " if line.strip().startswith(\"**_FALLBACK_MODEL\"):\n", + " insert_pos = i + 1\n", + " break\n", + " \n", + " if insert_pos is None:\n", + " raise ValueError(\"Could not find insertion point in registry.py\")\n", + " \n", + " # 插入新条目\n", + " lines.insert(insert_pos, new_entry)\n", + " \n", + " # 写回文件\n", + " with open(registry_path, \"w\") as f:\n", + " f.writelines(lines)\n", + " print(\"Successfully updated registry.py\")\n", + "else:\n", + " print(\"Entry already exists in registry.py, skipping modification\")\n", + "\n", + "print(\"All operations completed successfully!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "failed to import ttsfrd, use WeTextProcessing instead\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n", + "/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/diffusers/models/lora.py:393: FutureWarning: `LoRACompatibleLinear` is deprecated and will be removed in version 1.0.0. Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`.\n", + " deprecate(\"LoRACompatibleLinear\", \"1.0.0\", deprecation_message)\n", + "2025-03-08 00:37:04,867 INFO input frame rate=25\n", + "/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:115: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n", + " warnings.warn(\n", + "2025-03-08 00:37:06,103 WETEXT INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_tagger.fst\n", + "2025-03-08 00:37:06,103 INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_tagger.fst\n", + "2025-03-08 00:37:06,104 WETEXT INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_verbalizer.fst\n", + "2025-03-08 00:37:06,104 INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/zh_tn_verbalizer.fst\n", + "2025-03-08 00:37:06,104 WETEXT INFO skip building fst for zh_normalizer ...\n", + "2025-03-08 00:37:06,104 INFO skip building fst for zh_normalizer ...\n", + "2025-03-08 00:37:06,313 WETEXT INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_tagger.fst\n", + "2025-03-08 00:37:06,313 INFO found existing fst: /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_tagger.fst\n", + "2025-03-08 00:37:06,314 WETEXT INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_verbalizer.fst\n", + "2025-03-08 00:37:06,314 INFO /opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/tn/en_tn_verbalizer.fst\n", + "2025-03-08 00:37:06,314 WETEXT INFO skip building fst for en_normalizer ...\n", + "2025-03-08 00:37:06,314 INFO skip building fst for en_normalizer ...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 03-08 00:37:07 __init__.py:207] Automatically detected platform cuda.\n", + "WARNING 03-08 00:37:07 registry.py:352] Model architecture CosyVoice2Model is already registered, and will be overwritten by the new model class .\n", + "WARNING 03-08 00:37:07 config.py:2517] Casting torch.bfloat16 to torch.float16.\n", + "INFO 03-08 00:37:07 config.py:560] This model supports multiple tasks: {'embed', 'classify', 'reward', 'generate', 'score'}. Defaulting to 'generate'.\n", + "INFO 03-08 00:37:07 config.py:1624] Chunked prefill is enabled with max_num_batched_tokens=1024.\n", + "WARNING 03-08 00:37:08 utils.py:2164] CUDA was previously initialized. We must use the `spawn` multiprocessing start method. Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing for more information.\n", + "INFO 03-08 00:37:10 __init__.py:207] Automatically detected platform cuda.\n", + "INFO 03-08 00:37:11 core.py:50] Initializing a V1 LLM engine (v0.7.3.dev213+gede41bc7.d20250219) with config: model='./pretrained_models/CosyVoice2-0.5B', speculative_config=None, tokenizer='./pretrained_models/CosyVoice2-0.5B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=./pretrained_models/CosyVoice2-0.5B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"level\":3,\"custom_ops\":[\"none\"],\"splitting_ops\":[\"vllm.unified_attention\",\"vllm.unified_attention_with_output\"],\"use_inductor\":true,\"compile_sizes\":[],\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"max_capture_size\":512}\n", + "WARNING 03-08 00:37:11 utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,list_loras,load_config,pin_lora,remove_lora,scheduler_config not implemented in \n", + "INFO 03-08 00:37:11 parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n", + "INFO 03-08 00:37:11 gpu_model_runner.py:1055] Starting to load model ./pretrained_models/CosyVoice2-0.5B...\n", + "INFO 03-08 00:37:11 cuda.py:157] Using Flash Attention backend on V1 engine.\n", + "WARNING 03-08 00:37:11 topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.\n", + "WARNING 03-08 00:37:11 rejection_sampler.py:47] FlashInfer is not available. Falling back to the PyTorch-native implementation of rejection sampling. For the best performance, please install FlashInfer.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/anaconda3/envs/cosyvoice/lib/python3.10/site-packages/torch/utils/_device.py:106: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " return func(*args, **kwargs)\n", + "Loading pt checkpoint shards: 0% Completed | 0/1 [00:00\n", + "2025-03-08 00:39:03,237 INFO not enough text token to decode, wait for more\n", + "2025-03-08 00:39:03,252 INFO get fill token, need to append more text token\n", + "2025-03-08 00:39:03,253 INFO append 5 text token\n", + "2025-03-08 00:39:03,311 INFO get fill token, need to append more text token\n", + "2025-03-08 00:39:03,312 INFO append 5 text token\n", + "2025-03-08 00:39:03,456 INFO no more text token, decode until met eos\n", + "2025-03-08 00:39:04,861 INFO yield speech len 15.16, rtf 0.1072180145334128\n", + "100%|██████████| 1/1 [00:01<00:00, 1.88s/it]\n" + ] + } + ], + "source": [ + "def text_generator():\n", + " yield '收到好友从远方寄来的生日礼物,'\n", + " yield '那份意外的惊喜与深深的祝福'\n", + " yield '让我心中充满了甜蜜的快乐,'\n", + " yield '笑容如花儿般绽放。'\n", + "\n", + " \n", + "for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), prompt_text, prompt_speech_16k, stream=False)):\n", + " torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-03-08 00:39:04,878 INFO get tts_text generator, will skip text_normalize!\n", + " 0%| | 0/1 [00:00\n", + "2025-03-08 00:39:05,152 INFO not enough text token to decode, wait for more\n", + "2025-03-08 00:39:05,169 INFO get fill token, need to append more text token\n", + "2025-03-08 00:39:05,169 INFO append 5 text token\n", + "2025-03-08 00:39:05,292 INFO get fill token, need to append more text token\n", + "2025-03-08 00:39:05,293 INFO append 5 text token\n", + "2025-03-08 00:39:05,438 INFO no more text token, decode until met eos\n", + "2025-03-08 00:39:05,638 INFO yield speech len 1.84, rtf 0.26492670826289966\n", + "2025-03-08 00:39:05,841 INFO yield speech len 2.0, rtf 0.10065567493438721\n", + "2025-03-08 00:39:06,164 INFO yield speech len 2.0, rtf 0.16065263748168945\n", + "2025-03-08 00:39:06,422 INFO yield speech len 2.0, rtf 0.12791669368743896\n", + "2025-03-08 00:39:06,697 INFO yield speech len 2.0, rtf 0.13690149784088135\n", + "2025-03-08 00:39:06,998 INFO yield speech len 2.0, rtf 0.14957869052886963\n", + "2025-03-08 00:39:07,335 INFO yield speech len 1.0, rtf 0.3356931209564209\n", + "100%|██████████| 1/1 [00:02<00:00, 2.46s/it]\n" + ] + } + ], + "source": [ + "def text_generator():\n", + " yield '收到好友从远方寄来的生日礼物,'\n", + " yield '那份意外的惊喜与深深的祝福'\n", + " yield '让我心中充满了甜蜜的快乐,'\n", + " yield '笑容如花儿般绽放。'\n", + "for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), prompt_text, prompt_speech_16k, stream=True)):\n", + " torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1 [00:00 Date: Sat, 8 Mar 2025 10:41:49 +0800 Subject: [PATCH 6/7] =?UTF-8?q?refactor(llm):=20=E9=87=8D=E6=9E=84=20vLLM?= =?UTF-8?q?=20=E6=8E=A8=E7=90=86=E4=BB=BB=E5=8A=A1=E5=A4=84=E7=90=86?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=EF=BC=8C=E6=94=AF=E6=8C=81=E5=A4=9A=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除任务队列和单任务处理限制 - 使用 asyncio.run_coroutine_threadsafe() 在后台线程中运行推理任务 --- cosyvoice/llm/llm_vllm.py | 47 +++++++++++++++------------------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py index 1e9bc28..839bf88 100644 --- a/cosyvoice/llm/llm_vllm.py +++ b/cosyvoice/llm/llm_vllm.py @@ -86,46 +86,35 @@ class VllmQwen2LM(Qwen2LM): self.task_token_id = self.sos_eos_token_id + 1 self.zero_token_id = self.task_token_id + 1 - # 不能直接在同步函数正确的使用 异步的生成器函数,即使使用协程也会对vllm造成崩溃 - # 使用 queue 的方式,后台线程运行推理任务 - self.task_queue = queue.Queue() + # vllm 的推理任务需要在一个固定的事件循环中,因此启动一个后台线程运行转用于推理任务 self.loop = asyncio.new_event_loop() self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True) self.loop_thread.start() - # 运行后台协程,用于处理任务队列中的任务 - # TODO: 目前只能单任务运行,多任务运行需要对 inference_processor 进行修改 - asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop) def _run_event_loop(self): asyncio.set_event_loop(self.loop) self.loop.run_forever() - async def inference_processor(self, task_queue): - while True: - try: - logging.debug(f"inference_processor") - out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get() - sampling_params = SamplingParams(**SAMPLING_PARAMS) - sampling_params.stop_token_ids = stop_token_ids or [6561] - if max_tokens: - sampling_params.max_tokens = max_tokens - async for output in self.llm_engine.generate( - { - "prompt_token_ids": prompt_token_ids, - }, - sampling_params=sampling_params, - request_id=request_id or f"{time.time()}", - ): - out_queue.put((output.outputs[0], output.finished)) - except Exception as e: - logging.error(f"Error in inference_processor: {e}") + async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens): + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + async for output in self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ): + out_queue.put((output.outputs[0], output.finished)) def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None): - # 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务 - # 提交推理任务到队列中 out_queue = queue.Queue() - self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens)) - # 将 out_queue 的结果返回 + asyncio.run_coroutine_threadsafe( + self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop + ) + # 接收 out_queue 返回的结果 finished = False while not finished: (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get() From c0f6a474f36643fb061338115e71593776d8345b Mon Sep 17 00:00:00 2001 From: qihua Date: Sat, 8 Mar 2025 16:03:35 +0800 Subject: [PATCH 7/7] =?UTF-8?q?fix(async=5Fcosyvoice):=20=E6=81=A2?= =?UTF-8?q?=E5=A4=8D=E5=8E=9F=E6=9C=AC=E6=96=87=E6=9C=AC=E4=BB=A4=E7=89=8C?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 Frontend 中,恢复原本逐个生成文本令牌 - 在 Model 类中,移除了不必要的日志信息和断言,简化了文本令牌的处理流程 --- cosyvoice/cli/frontend.py | 5 ++--- cosyvoice/llm/llm_vllm.py | 11 ----------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 5aa2d34..834f0b0 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -102,9 +102,8 @@ class CosyVoiceFrontEnd: def _extract_text_token_generator(self, text_generator): for text in text_generator: text_token, _ = self._extract_text_token(text) - # for i in range(text_token.shape[1]): - # yield text_token[:, i: i + 1] - yield text_token + for i in range(text_token.shape[1]): + yield text_token[:, i: i + 1] def _extract_speech_token(self, speech): assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py index 839bf88..a864a04 100644 --- a/cosyvoice/llm/llm_vllm.py +++ b/cosyvoice/llm/llm_vllm.py @@ -149,8 +149,6 @@ class VllmQwen2LM(Qwen2LM): need_add_tokens = output.token_ids[:-1] else: need_add_tokens = output.token_ids - # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 - # yield need_add_tokens for token in need_add_tokens: yield token @@ -186,18 +184,14 @@ class VllmQwen2LM(Qwen2LM): text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:] else: - logging.info('not enough text token to decode, wait for more') break if len(prompt_speech_token) == 0: if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: - logging.info('get fill token, need to append more text token') if len(text_tokens_cache) >= self.mix_ratio[0]: text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]] prompt_token_ids += text_tokens_temp - logging.info('append {} text token'.format(len(text_tokens_temp))) text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] else: - logging.info('not enough text token to decode, wait for more') continue for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]): last_tokens = output.token_ids @@ -205,19 +199,14 @@ class VllmQwen2LM(Qwen2LM): need_add_tokens = last_tokens[:-1] else: need_add_tokens = last_tokens - # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 - # yield need_add_tokens for token in need_add_tokens: yield token prompt_token_ids.extend(need_add_tokens) prompt_token_ids += text_tokens_cache + [self.task_token_id] - logging.info('no more text token, decode until met eos') for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]): if output.token_ids[-1] == 6561: need_add_tokens = output.token_ids[:-1] else: need_add_tokens = output.token_ids - # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 - # yield need_add_tokens for token in need_add_tokens: yield token