From 9f55c5af8fb336935b87f708a08a5ef788d22a74 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 9 May 2025 09:45:01 +0000 Subject: [PATCH 1/4] add vllm export --- cosyvoice/cli/cosyvoice.py | 2 + cosyvoice/cli/model.py | 2 +- cosyvoice/llm/vllm_use_cosyvoice2_model.py | 263 --------------------- cosyvoice/utils/file_utils.py | 44 +++- requirements.txt | 2 +- 5 files changed, 47 insertions(+), 266 deletions(-) delete mode 100644 cosyvoice/llm/vllm_use_cosyvoice2_model.py diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index a51a304..fb1cd7f 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -166,6 +166,8 @@ class CosyVoice2(CosyVoice): self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + if load_vllm: + self.model.load_vllm('{}/vllm'.format(model_dir)) if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index c1e441f..19bedd3 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -23,7 +23,7 @@ from torch.nn import functional as F from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out -from cosyvoice.utils.file_utils import convert_onnx_to_trt +from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm from cosyvoice.utils.common import TrtContextWrapper diff --git a/cosyvoice/llm/vllm_use_cosyvoice2_model.py b/cosyvoice/llm/vllm_use_cosyvoice2_model.py deleted file mode 100644 index 6e36ef3..0000000 --- a/cosyvoice/llm/vllm_use_cosyvoice2_model.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py -# Copyright 2024 The Qwen team. -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any -from typing_extensions import TypeVar - -import torch -from torch import nn - -from vllm.attention import AttentionMetadata -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from vllm.model_executor.models.interfaces import T -from vllm.model_executor.models.qwen2 import Qwen2Model - -from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings - -logger = init_logger(__name__) - -IGNORE_ID = -1 - - -class CosyVoice2Model(nn.Module): - - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - self.quant_config = quant_config - - self.llm_input_size = 896 - self.llm_output_size = 896 - - self.speech_token_size = 6561+3 - self.llm_token_size = config.vocab_size - - # 2. build speech token language model related modules - self.sos_eos = 0 - self.task_id = 1 - self.fill_token = 2 - - - self.allow_patterns_overrides = ["llm.*"] - self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size) - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size) - self.llm_decoder = ParallelLMHead(self.speech_token_size, - self.llm_output_size, - bias=True, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "llm_decoder")) - self.logits_processor = LogitsProcessor(self.speech_token_size) - - # length_normalized_loss: bool = True, - # lsm_weight: float = 0.0, - # self.criterion_ce = LabelSmoothingLoss( - # size=self.speech_token_size, - # padding_idx=IGNORE_ID, - # smoothing=lsm_weight, - # normalize_length=length_normalized_loss, - # ) - - # 3. [Optional] build speech token related modules - self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size) - - # 4. 
sampling method - ## use vllm sampling method - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - self.mix_ratio: List[int] = [5, 15] - - # 定义特殊token常量 - self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32) - self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404 - self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405 - self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32) - - self.zero_embed_buffer = torch.zeros( - (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size), - dtype=self.llm_embedding.weight.dtype, - device=self.llm_embedding.weight.device - ) - self.inputs_embed_buffer = torch.zeros( - (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size), - dtype=self.llm_embedding.weight.dtype, - device=self.llm_embedding.weight.device, - ) - - def get_sos_eos_emb(self): - return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) - - def get_task_id_emb(self): - return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[T] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - ) -> torch.Tensor: - """ - Returns the input embeddings merged from the text embeddings from - input_ids and the multimodal embeddings generated from multimodal - kwargs. - """ - # 创建掩码,标记哪些 token_id 属于音频 Token - mask = input_ids < self.speech_token_size - - # 获取 input_ids 的原始形状 - input_shape = input_ids.shape - # 展平 input_ids 和掩码以便统一处理 - flat_input_ids = input_ids.view(-1) - flat_mask = mask.view(-1) - - inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]] - inputs_embeds.zero_() - - # Process speech tokens - if flat_mask.any(): - speech_token_ids = flat_input_ids[flat_mask] - inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids) - - # 处理大于 delta 的 token_id - if (~flat_mask).any(): - llm_token_ids = flat_input_ids[~flat_mask] - llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask]) - - sos_eos_mask = llm_token_ids == self.sos_eos_token_id - task_mask = llm_token_ids == self.task_token_id - zero_mask = llm_token_ids == self.zero_token_id - normal_mask = ~(sos_eos_mask | task_mask | zero_mask) - - # 分层处理逻辑 - # 第一优先级:SOS/EOS标记 - if sos_eos_mask.any(): - llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0) - - # 第二优先级:任务标记 - if task_mask.any(): - llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0) - - # 第二优先级:空音频标记 - if zero_mask.any(): - llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])] - - # 常规LLM token - if normal_mask.any(): - original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta - # print('original_ids: ',original_ids) - llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids) - - inputs_embeds[~flat_mask] = llm_embeds - - inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size) - - # 合并多模态嵌入(如果有) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.audio_token_index - ) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: 
Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings( - input_ids, - attn_metadata=attn_metadata, - ) - return self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.llm_decoder, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - @staticmethod - def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]: - for name, param in weights: - # 处理Qwen2Model核心参数 - if name.startswith("llm."): - if name.startswith("llm.model.model."): - name = name.replace("llm.model.model.", "model.") - else: - continue - # print('weights name: ', name) - yield name, param - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - weights = self.convert_weights(weights) - loader = AutoWeightsLoader(self) - loader.load_weights(weights) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index ae860c9..fb849e6 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu) +# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import json -import torchaudio +import torch, torchaudio import logging logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.basicConfig(level=logging.DEBUG, @@ -83,3 +85,43 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): with open(trt_model, "wb") as f: f.write(engine_bytes) logging.info("Succesfully convert onnx to trt...") + + +def export_cosyvoice2_vllm(model, model_path, device): + if os.path.exists(model_path): + return + pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64 + vocab_size = model.speech_embedding.num_embeddings + feature_size = model.speech_embedding.embedding_dim + pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to + + dtype = torch.bfloat16 + # lm_head + new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True) + with torch.no_grad(): + new_lm_head.weight[:vocab_size] = model.llm_decoder.weight + new_lm_head.bias[:vocab_size] = model.llm_decoder.bias + new_lm_head.weight[vocab_size:] = 0 + new_lm_head.bias[vocab_size:] = 0 + model.llm.model.lm_head = new_lm_head + new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size) + # embed_tokens + embed_tokens = model.llm.model.model.embed_tokens + with torch.no_grad(): + new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight + new_codec_embed.weight[vocab_size:] = 0 + model.llm.model.set_input_embeddings(new_codec_embed) + model.llm.model.to(device) + model.llm.model.to(dtype) + tmp_vocab_size = model.llm.model.config.vocab_size + tmp_tie_embedding = model.llm.model.config.tie_word_embeddings + del model.llm.model.generation_config.eos_token_id + del model.llm.model.config.bos_token_id + del model.llm.model.config.eos_token_id + model.llm.model.config.vocab_size = pad_vocab_size + model.llm.model.config.tie_word_embeddings = False + model.llm.model.config.use_bias = True + model.llm.model.save_pretrained(model_path) + model.llm.model.config.vocab_size = tmp_vocab_size + model.llm.model.config.tie_word_embeddings = tmp_tie_embedding + model.llm.model.set_input_embeddings(embed_tokens) diff --git a/requirements.txt b/requirements.txt index 4166dac..781a8fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684 conformer==0.3.2 -deepspeed==0.14.2; sys_platform == 'linux' +deepspeed==0.15.1; sys_platform == 'linux' diffusers==0.29.0 gdown==5.1.0 gradio==5.4.0 From 6dd68b9d5eee233c90b10d696a19e9882d058b2e Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Fri, 30 May 2025 07:22:35 +0000 Subject: [PATCH 2/4] add vllm inference --- cosyvoice/cli/cosyvoice.py | 2 +- cosyvoice/cli/model.py | 33 ++++++-------- cosyvoice/flow/flow_matching.py | 44 ++++++++++--------- cosyvoice/llm/llm.py | 76 +++++++++++++++++++++++++-------- cosyvoice/utils/common.py | 11 ++--- cosyvoice/utils/file_utils.py | 3 +- 6 files changed, 105 insertions(+), 64 deletions(-) diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index fb1cd7f..71351a2 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -140,7 +140,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, 
trt_concurrent=1): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 19bedd3..6ebbe52 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -59,9 +59,6 @@ class CosyVoiceModel: self.stream_scale_factor = 1 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() - self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) - for _ in range(trt_concurrent): - self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -69,7 +66,6 @@ class CosyVoiceModel: self.mel_overlap_dict = {} self.flow_cache_dict = {} self.hift_cache_dict = {} - self.trt_context_dict = {} def load(self, llm_model, flow_model, hift_model): self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True) @@ -98,7 +94,7 @@ class CosyVoiceModel: with open(flow_decoder_estimator_model, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device) def get_trt_kwargs(self): min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] @@ -125,7 +121,8 @@ class CosyVoiceModel: prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), prompt_speech_token=llm_prompt_speech_token.to(self.device), prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), - embedding=llm_embedding.to(self.device)): + embedding=llm_embedding.to(self.device), + uuid=uuid): self.tts_speech_token_dict[uuid].append(i) self.llm_end_dict[uuid] = True @@ -180,13 +177,11 @@ class CosyVoiceModel: prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) - this_trt_context = self.trt_context_pool.get() with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) - self.trt_context_dict[this_uuid] = this_trt_context if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) else: @@ -240,8 +235,6 @@ class CosyVoiceModel: self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) - self.trt_context_pool.put(self.trt_context_dict[this_uuid]) - self.trt_context_dict.pop(this_uuid) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.current_stream().synchronize() @@ -273,22 +266,28 @@ class CosyVoice2Model(CosyVoiceModel): self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and 
decoding related self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() - self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) - for _ in range(trt_concurrent): - self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} self.hift_cache_dict = {} - self.trt_context_dict = {} def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder + def load_vllm(self, model_dir): + export_cosyvoice2_vllm(self.llm, model_dir, self.device) + from vllm import EngineArgs, LLMEngine + engine_args = EngineArgs(model=model_dir, + skip_tokenizer_init=True, + enable_prompt_embeds=True, + gpu_memory_utilization=0.2) + self.llm.vllm = LLMEngine.from_engine_args(engine_args) + del self.llm.llm.model.model.layers + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]: + with torch.cuda.amp.autocast(self.fp16): tts_mel, _ = self.flow.inference(token=token.to(self.device), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), @@ -330,11 +329,9 @@ class CosyVoice2Model(CosyVoiceModel): prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) - this_trt_context = self.trt_context_pool.get() with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None - self.trt_context_dict[this_uuid] = this_trt_context if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) else: @@ -388,8 +385,6 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - self.trt_context_pool.put(self.trt_context_dict[this_uuid]) - self.trt_context_dict.pop(this_uuid) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.current_stream().synchronize() diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 704ced3..9f7d0be 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -16,6 +16,7 @@ import threading import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM +from cosyvoice.utils.common import set_all_random_seed class ConditionalCFM(BASECFM): @@ -32,7 +33,6 @@ class ConditionalCFM(BASECFM): in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) # Just change the architecture of the estimator here self.estimator = estimator - self.lock = threading.Lock() @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)): @@ -127,26 +127,27 @@ class ConditionalCFM(BASECFM): if isinstance(self.estimator, torch.nn.Module): return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) else: - estimator, trt_engine = 
self.estimator.acquire_estimator() - estimator.set_input_shape('x', (2, 80, x.size(2))) - estimator.set_input_shape('mask', (2, 1, x.size(2))) - estimator.set_input_shape('mu', (2, 80, x.size(2))) - estimator.set_input_shape('t', (2,)) - estimator.set_input_shape('spks', (2, 80)) - estimator.set_input_shape('cond', (2, 80, x.size(2))) - data_ptrs = [x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()] - for i, j in enumerate(data_ptrs): - estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) - # run trt engine - assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() - self.estimator.release_estimator(estimator) + [estimator, stream], trt_engine = self.estimator.acquire_estimator() + with stream: + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) return x def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): @@ -194,6 +195,7 @@ class ConditionalCFM(BASECFM): class CausalConditionalCFM(ConditionalCFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) + set_all_random_seed(0) self.rand_noise = torch.randn([1, 80, 50 * 300]) @torch.inference_mode() diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 670ae69..c5899ac 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import queue import random +import time +import threading from typing import Dict, Optional, Callable, List, Generator import torch from torch import nn @@ -170,6 +173,7 @@ class TransformerLM(torch.nn.Module): sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, + uuid: str = '', ) -> Generator[torch.Tensor, None, None]: device = text.device text = torch.concat([prompt_text, text], dim=1) @@ -270,7 +274,6 @@ class Qwen2LM(TransformerLM): self.llm_input_size = llm_input_size self.llm_output_size = llm_output_size self.speech_token_size = speech_token_size - # 2. build speech token language model related modules self.sos_eos = 0 self.task_id = 1 @@ -292,6 +295,11 @@ class Qwen2LM(TransformerLM): # 4. sampling method self.sampling = sampling self.mix_ratio = mix_ratio + + # 5. 
vllm related + self.stop_token_ids = [speech_token_size + i for i in range(3)] + self.vllm_output_queue = {} + self.lock = threading.Lock() def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len): lm_target, lm_input = [], [] @@ -382,6 +390,7 @@ class Qwen2LM(TransformerLM): sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, + uuid: str = '', ) -> Generator[torch.Tensor, None, None]: device = text.device text = torch.concat([prompt_text, text], dim=1) @@ -402,22 +411,55 @@ class Qwen2LM(TransformerLM): max_len = int((text_len - prompt_text_len) * max_token_text_ratio) # 5. step by step decode - out_tokens = [] - cache = None - for i in range(max_len): - y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), - cache=cache) - logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) - top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() - if top_ids == self.speech_token_size: - break - if top_ids > self.speech_token_size: - continue - # in stream mode, yield token one by one - yield top_ids - out_tokens.append(top_ids) - lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid): + yield token + + @torch.inference_mode() + def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid): + if hasattr(self, 'vllm'): + from vllm import SamplingParams, RequestOutput + sampling_params = SamplingParams(top_k=sampling, + stop_token_ids=self.stop_token_ids, + min_tokens=min_len, + max_tokens=max_len) + with self.lock: + self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params) + self.vllm_output_queue[uuid] = queue.Queue() + out_tokens = [] + while True: + with self.lock: + if self.vllm_output_queue[uuid].empty() is True: + request_outputs: List[RequestOutput] = self.vllm.step() + for request_output in request_outputs: + top_ids = list(request_output.outputs[0].token_ids)[-1] + self.vllm_output_queue[request_output.request_id].put(top_ids) + if self.vllm_output_queue[uuid].empty() is False: + top_ids = self.vllm_output_queue[uuid].get() + if top_ids in self.stop_token_ids: + break + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + time.sleep(0.001) + with self.lock: + self.vllm_output_queue.pop(uuid) + else: + out_tokens = [] + cache = None + for i in range(max_len): + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + if top_ids > self.speech_token_size: + continue + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) @torch.inference_mode() def inference_bistream( diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 088ca69..6f5a3dd 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -169,17 +169,18 @@ def 
mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: class TrtContextWrapper: - def __init__(self, trt_engine, trt_concurrent=1): - self.trt_context_pool = queue.Queue() + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) self.trt_engine = trt_engine for _ in range(trt_concurrent): trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(device)) assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) - self.trt_context_pool.put(trt_context) + self.trt_context_pool.put([trt_context, trt_stream]) assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' def acquire_estimator(self): return self.trt_context_pool.get(), self.trt_engine - def release_estimator(self, context): - self.trt_context_pool.put(context) + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index fb849e6..1fbddae 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() @@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device): model.llm.model.config.tie_word_embeddings = False model.llm.model.config.use_bias = True model.llm.model.save_pretrained(model_path) + os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path))) model.llm.model.config.vocab_size = tmp_vocab_size model.llm.model.config.tie_word_embeddings = tmp_tie_embedding model.llm.model.set_input_embeddings(embed_tokens) From 9b052a94c45a6d00781bb36f8fea558984a18111 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Fri, 30 May 2025 07:51:49 +0000 Subject: [PATCH 3/4] fix lint --- .github/workflows/lint.yml | 2 +- cosyvoice/cli/cosyvoice.py | 6 +- cosyvoice/cli/model.py | 15 +-- cosyvoice/flow/flow_matching.py | 13 +- cosyvoice/llm/llm.py | 7 +- cosyvoice/llm/llm_vllm.py | 212 -------------------------------- cosyvoice/utils/file_utils.py | 3 +- cosyvoice/vllm/cosyvoice2.py | 103 ++++++++++++++++ 8 files changed, 125 insertions(+), 236 deletions(-) delete mode 100644 cosyvoice/llm/llm_vllm.py create mode 100644 cosyvoice/vllm/cosyvoice2.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 884011d..ef28761 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -52,5 +52,5 @@ jobs: set -eux pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 flake8 --version - flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py + flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py if [ $? 
!= 0 ]; then exit 1; fi \ No newline at end of file diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 71351a2..cc443be 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -48,7 +48,7 @@ class CosyVoice: if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent) + self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) @@ -59,6 +59,7 @@ class CosyVoice: if load_trt: self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + trt_concurrent, self.fp16) del configs @@ -162,7 +163,7 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent) + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) @@ -173,6 +174,7 @@ class CosyVoice2(CosyVoice): if load_trt: self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + trt_concurrent, self.fp16) del configs diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 6ebbe52..0a1068c 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -14,7 +14,6 @@ # limitations under the License. import os from typing import Generator -import queue import torch import numpy as np import threading @@ -33,14 +32,12 @@ class CosyVoiceModel: llm: torch.nn.Module, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool = False, - trt_concurrent: int = 1): + fp16: bool = False): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 - self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() @@ -85,7 +82,7 @@ class CosyVoiceModel: flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) @@ -94,7 +91,7 @@ class CosyVoiceModel: with open(flow_decoder_estimator_model, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) def get_trt_kwargs(self): min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] @@ -104,7 +101,7 @@ class CosyVoiceModel: return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): - with self.llm_context, torch.cuda.amp.autocast(self.fp16): + with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False): if isinstance(text, Generator): assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' for i in self.llm.inference_bistream(text=text, @@ -246,14 +243,12 @@ class CosyVoice2Model(CosyVoiceModel): llm: torch.nn.Module, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool = False, - trt_concurrent: int = 1): + fp16: bool = False): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 - self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 9f7d0be..39b3415 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import threading import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM @@ -136,12 +135,12 @@ class ConditionalCFM(BASECFM): estimator.set_input_shape('spks', (2, 80)) estimator.set_input_shape('cond', (2, 80, x.size(2))) data_ptrs = [x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()] + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] for i, j in enumerate(data_ptrs): estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) # run trt engine diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index c5899ac..a316e5d 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -295,7 +296,7 @@ class Qwen2LM(TransformerLM): # 4. sampling method self.sampling = sampling self.mix_ratio = mix_ratio - + # 5. 
vllm related self.stop_token_ids = [speech_token_size + i for i in range(3)] self.vllm_output_queue = {} @@ -448,8 +449,8 @@ class Qwen2LM(TransformerLM): cache = None for i in range(max_len): y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), - cache=cache) + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() if top_ids == self.speech_token_size: diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py deleted file mode 100644 index a864a04..0000000 --- a/cosyvoice/llm/llm_vllm.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import time -import queue -import asyncio -import threading -from typing import List, Generator, AsyncGenerator -import torch -from cosyvoice.utils.file_utils import logging -from cosyvoice.llm.llm import Qwen2LM - -# 启用vllm V1版本 -import os -os.environ["VLLM_USE_V1"] = '1' -from vllm import ModelRegistry -from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput -from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs -from vllm.sampling_params import SamplingParams - -from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM -ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM) - -# EngineArgs -ENGINE_ARGS = { - "block_size": 16, - "swap_space": 0, - # "enforce_eager": True, - "gpu_memory_utilization": 0.4, - "max_num_batched_tokens": 1024, - "max_model_len": 1024, - "max_num_seqs": 256, - "disable_log_requests": True, - "disable_log_stats": True, - "dtype": "float16" -} - -from vllm.sampling_params import RequestOutputKind -# SamplingParams -SAMPLING_PARAMS = { - "temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token - "top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token - "top_k": 25, - # "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动 - # "presence_penalty": 1.0, # 不支持设置 - # "frequency_penalty": 0.0, # 不支持设置 - "max_tokens": 1024, - "detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算 - "ignore_eos": False, - "output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码 -} - -def tensor_to_list(tensor: torch.tensor): - return tensor.view(-1).cpu().numpy().tolist() - -class VllmQwen2LM(Qwen2LM): - def __init__( - self, - model_dir, - mix_ratio: List[int] = [5, 15], - ): - self.fp16 = False - self.half = lambda: None - self.mix_ratio = mix_ratio - # --------------------------------------------- - # vllm engine 的参数配置 - engine_args = AsyncEngineArgs( - model=model_dir, - **ENGINE_ARGS, - ) - self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args) - - self.speech_token_size = 6564 # 6561 + 3 - 
self.llm_token_size = 151936 # llm vocab_size - self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1 - self.task_token_id = self.sos_eos_token_id + 1 - self.zero_token_id = self.task_token_id + 1 - - # vllm 的推理任务需要在一个固定的事件循环中,因此启动一个后台线程运行转用于推理任务 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True) - self.loop_thread.start() - - def _run_event_loop(self): - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens): - sampling_params = SamplingParams(**SAMPLING_PARAMS) - sampling_params.stop_token_ids = stop_token_ids or [6561] - if max_tokens: - sampling_params.max_tokens = max_tokens - async for output in self.llm_engine.generate( - { - "prompt_token_ids": prompt_token_ids, - }, - sampling_params=sampling_params, - request_id=request_id or f"{time.time()}", - ): - out_queue.put((output.outputs[0], output.finished)) - - def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None): - out_queue = queue.Queue() - asyncio.run_coroutine_threadsafe( - self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop - ) - # 接收 out_queue 返回的结果 - finished = False - while not finished: - (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get() - yield output - - def inference( - self, - text: torch.Tensor, - text_len: torch.Tensor, - prompt_text: torch.Tensor, - prompt_text_len: torch.Tensor, - prompt_speech_token: torch.Tensor, - prompt_speech_token_len: torch.Tensor, - embedding: torch.Tensor, - sampling: int = 25, - max_token_text_ratio: float = 20, - min_token_text_ratio: float = 2, - ) -> Generator[torch.Tensor|int, None, None]: - prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) - prompt_speech_token = tensor_to_list(prompt_speech_token) - - text = tensor_to_list(text + torch.tensor(6564)) - prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \ - [self.task_token_id] + prompt_speech_token - max_tokens = len(text) * 20 - for output in self.llm_inference( - prompt_token_ids, - stop_token_ids=[6561], - max_tokens=max_tokens, - ): - if output.token_ids[-1] == 6561: - need_add_tokens = output.token_ids[:-1] - else: - need_add_tokens = output.token_ids - for token in need_add_tokens: - yield token - - def inference_bistream( - self, - text: Generator, - prompt_text: torch.Tensor, - prompt_text_len: torch.Tensor, - prompt_speech_token: torch.Tensor, - prompt_speech_token_len: torch.Tensor, - embedding: torch.Tensor, - sampling: int = 25, - max_token_text_ratio: float = 20, - min_token_text_ratio: float = 2, - ) -> Generator[torch.Tensor, None, None]: - prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) - prompt_speech_token = tensor_to_list(prompt_speech_token) - - last_tokens = [] - prompt_token_ids = [self.sos_eos_token_id] - text_tokens_cache = prompt_text - for this_text in text: - this_text = tensor_to_list(this_text + torch.tensor(6564)) - # text need tokens - assert isinstance(this_text, list), "text need token ids List[int]." 
- text_tokens_cache += this_text - while len(prompt_speech_token) != 0: - if len(text_tokens_cache) >= self.mix_ratio[0]: - text_input_token = text_tokens_cache[:self.mix_ratio[0]] - speech_input_token = prompt_speech_token[:self.mix_ratio[1]] - prompt_token_ids += text_input_token + speech_input_token - # reset the last cache - text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] - prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:] - else: - break - if len(prompt_speech_token) == 0: - if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: - if len(text_tokens_cache) >= self.mix_ratio[0]: - text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]] - prompt_token_ids += text_tokens_temp - text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] - else: - continue - for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]): - last_tokens = output.token_ids - if last_tokens[-1] == 6563: - need_add_tokens = last_tokens[:-1] - else: - need_add_tokens = last_tokens - for token in need_add_tokens: - yield token - prompt_token_ids.extend(need_add_tokens) - prompt_token_ids += text_tokens_cache + [self.task_token_id] - for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]): - if output.token_ids[-1] == 6561: - need_add_tokens = output.token_ids[:-1] - else: - need_add_tokens = output.token_ids - for token in need_add_tokens: - yield token diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index 1fbddae..a92f8e7 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -16,7 +16,8 @@ import os import json -import torch, torchaudio +import torch +import torchaudio import logging logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.basicConfig(level=logging.DEBUG, diff --git a/cosyvoice/vllm/cosyvoice2.py b/cosyvoice/vllm/cosyvoice2.py new file mode 100644 index 0000000..de0bc76 --- /dev/null +++ b/cosyvoice/vllm/cosyvoice2.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from vllm.model_executor.models.qwen2 import * + + +class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) From 7d9d84d32d115d84303ce80c8b576ed1509607bb Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Fri, 30 May 2025 09:12:03 +0000 Subject: [PATCH 4/4] update --- README.md | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c7a724d..3b0bc71 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,10 @@ ## Roadmap +- [x] 2025/05 + + - [x] add cosyvoice 2.0 vllm support + - [x] 2024/12 - [x] 25hz cosyvoice 2.0 released @@ -126,7 +130,7 @@ import torchaudio **CosyVoice2 Usage** ```python -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False) +cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=False, fp16=False) # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference # zero_shot usage @@ -159,6 +163,27 @@ for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你 torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) ``` +If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vllm version do not support CosyVoice2 inference. + +Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torch==2.7.0`. 
You can create a new env in case your hardware does not support vllm, so that your existing env is not corrupted.
+
+``` sh
+conda create -n cosyvoice_vllm --clone cosyvoice
+conda activate cosyvoice_vllm
+pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+```
+
+```python
+import sys
+sys.path.append('third_party/Matcha-TTS')
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from vllm import ModelRegistry
+from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
+ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
+
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=True, fp16=False)
+```
+
 **CosyVoice Usage**
 ```python
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)