diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index e2d62e2..39464ca 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml from modelscope import snapshot_download import torch from cosyvoice.cli.frontend import CosyVoiceFrontEnd -from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model +from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model from cosyvoice.utils.file_utils import logging from cosyvoice.utils.class_utils import get_model_type @@ -63,6 +63,9 @@ class CosyVoice: spks = list(self.frontend.spk2info.keys()) return spks + def add_spk_info(self, spk_id, spk_info): + self.frontend.add_spk_info(spk_id, spk_info) + def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): model_input = self.frontend.frontend_sft(i, spk_id) @@ -88,6 +91,22 @@ class CosyVoice: yield model_output start_time = time.time() + def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): + """使用预定义的说话人执行 zero_shot 推理""" + for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): + model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id) + start_time = time.time() + last_time = start_time + chunk_index = 0 + logging.info('synthesis text {}'.format(i)) + for model_output in self.model.tts(**model_input, stream=stream, speed=speed): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format( + chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time)) + yield model_output + last_time = time.time() + chunk_index += 1 + def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate) @@ -126,7 +145,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 @@ -145,7 +164,14 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) + if use_vllm: + try: + self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16) + except Exception as e: + logging.warning(f'use vllm inference failed. \n{e}') + raise e + else: + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) @@ -171,3 +197,14 @@ class CosyVoice2(CosyVoice): logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) yield model_output start_time = time.time() + + def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True): + for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): + model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) + for model_output in self.model.tts(**model_input, stream=stream, speed=speed): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 6e10f00..5aa2d34 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Generator +from typing import Generator, Optional import json import onnxruntime import torch @@ -24,6 +24,8 @@ import torchaudio import os import re import inflect +from pydantic import BaseModel, ConfigDict + try: import ttsfrd use_ttsfrd = True @@ -36,6 +38,18 @@ from cosyvoice.utils.file_utils import logging from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation +class SpeakerInfo(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: Optional[str] = None + spk_id: str + prompt_text: str + prompt_text_token: torch.Tensor + speech_feat: torch.Tensor + speech_token: torch.Tensor + embedding: torch.Tensor + + class CosyVoiceFrontEnd: def __init__(self, @@ -55,8 +69,9 @@ class CosyVoiceFrontEnd: self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]) + self.spk2info_path = spk2info if os.path.exists(spk2info): - self.spk2info = torch.load(spk2info, map_location=self.device) + self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False) else: self.spk2info = {} self.allowed_special = allowed_special @@ -68,7 +83,8 @@ class CosyVoiceFrontEnd: 'failed to initialize ttsfrd resource' self.frd.set_lang_type('pinyinvg') else: - self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + # self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False) self.en_tn_model = EnNormalizer() self.inflect_parser = inflect.engine() @@ -86,8 +102,9 @@ class CosyVoiceFrontEnd: def _extract_text_token_generator(self, text_generator): for text in text_generator: text_token, _ = self._extract_text_token(text) - for i in range(text_token.shape[1]): - yield text_token[:, i: i + 1] + # for i in range(text_token.shape[1]): + # yield text_token[:, i: i + 1] + yield text_token def _extract_speech_token(self, speech): assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' @@ -138,11 +155,15 @@ class CosyVoiceFrontEnd: text = text.replace(" - ", ",") text = remove_bracket(text) text = re.sub(r'[,,、]+$', '。', text) + if not split: + return text texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False)) else: text = self.en_tn_model.normalize(text) text = spell_out_number(text, self.inflect_parser) + if not split: + return text texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False)) texts = [i for i in texts if not is_only_punctuation(i)] @@ -151,6 +172,7 @@ class CosyVoiceFrontEnd: def frontend_sft(self, tts_text, spk_id): tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) embedding = self.spk2info[spk_id]['embedding'] + assert embedding is not None model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding} return model_input @@ -209,3 +231,60 @@ class CosyVoiceFrontEnd: 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len, 'flow_embedding': embedding} return model_input + + def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate:int=24000, name: str=None): + assert isinstance(spk_id, str) + assert spk_id not in self.spk2info, "spk_id already exists" + prompt_text_token, _ = self._extract_text_token(prompt_text) + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k) + speech_feat, _ = self._extract_speech_feat(prompt_speech_resample) + speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k) + if resample_rate == 24000: + # cosyvoice2, force speech_feat % speech_token = 2 + token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1]) + speech_feat = speech_feat[:, :2 * token_len] + speech_token = speech_token[:, :token_len] + embedding = self._extract_spk_embedding(prompt_speech_16k) + spk_info = SpeakerInfo( + name=name, + spk_id=spk_id, + prompt_text=prompt_text, + prompt_text_token=prompt_text_token, + speech_feat=speech_feat, + speech_token=speech_token, + embedding=embedding, + ) + self.add_spk_info(spk_id, spk_info) + + def add_spk_info(self, spk_id: str, spk_info: dict|SpeakerInfo): + if isinstance(spk_info, BaseModel): + spk_info = spk_info.model_dump() + self.spk2info[spk_id] = spk_info + if self.spk2info_path: + torch.save(self.spk2info, self.spk2info_path) + + def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id): + assert spk_id in self.spk2info + tts_text_token, _ = self._extract_text_token(tts_text) + prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>') + model_input = {'text': tts_text_token, + 'prompt_text': prompt_text_token, + 'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'], + 'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'], + 'llm_embedding': self.spk2info[spk_id]['embedding'], + 'flow_embedding': self.spk2info[spk_id]['embedding'], + } + return model_input + + def frontend_zero_shot_by_spk_id(self, tts_text, spk_id): + assert spk_id in self.spk2info + tts_text_token, _ = self._extract_text_token(tts_text) + model_input = {'text': tts_text_token, + 'prompt_text': self.spk2info[spk_id]['prompt_text_token'], + 'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'], + 'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'], + 'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'], + 'llm_embedding': self.spk2info[spk_id]['embedding'], + 'flow_embedding': self.spk2info[spk_id]['embedding'] + } + return model_input \ No newline at end of file diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 9ebf8cb..c0d25ba 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -409,3 +409,26 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) torch.cuda.empty_cache() + + +class VllmCosyVoice2Model(CosyVoice2Model): + def __init__(self, + model_dir: str, + flow: torch.nn.Module, + hift: torch.nn.Module, + fp16: bool): + try: + from cosyvoice.llm.llm_vllm import VllmQwen2LM + except Exception as e: + raise e + llm = VllmQwen2LM(model_dir) + super().__init__(llm,flow,hift,fp16) + + def load(self, llm_model, flow_model, hift_model): + self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True) + self.flow.to(self.device).eval() + # in case hift_model is a hifigan model + hift_state_dict = {k.replace('generator.', ''): v for k, v in + torch.load(hift_model, weights_only=True, map_location=self.device).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py new file mode 100644 index 0000000..c43c53a --- /dev/null +++ b/cosyvoice/llm/llm_vllm.py @@ -0,0 +1,248 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import contextlib +import time +from typing import List, Generator, AsyncGenerator +import torch +from cosyvoice.utils.file_utils import logging +from cosyvoice.llm.llm import Qwen2LM + +# 启用vllm V1版本 +import os +os.environ["VLLM_USE_V1"] = '1' +from vllm import ModelRegistry +from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput +from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.sampling_params import SamplingParams + +from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM +ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM) + +# EngineArgs +ENGINE_ARGS = { + "block_size": 16, + "swap_space": 0, + # "enforce_eager": True, + "gpu_memory_utilization": 0.4, + "max_num_batched_tokens": 1024, + "max_model_len": 1024, + "max_num_seqs": 256, + "disable_log_requests": True, + "disable_log_stats": True, +} + +from vllm.sampling_params import RequestOutputKind +# SamplingParams +SAMPLING_PARAMS = { + "temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token + "top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token + "top_k": 25, + # "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动 + # "presence_penalty": 1.0, # 不支持设置 + # "frequency_penalty": 0.0, # 不支持设置 + "max_tokens": 1024, + "detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算 + "ignore_eos": False, + "output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码 +} + +def tensor_to_list(tensor: torch.tensor): + return tensor.view(-1).cpu().numpy().tolist() + +class VllmQwen2LM(Qwen2LM): + def __init__( + self, + model_dir, + mix_ratio: List[int] = [5, 15], + ): + self.fp16 = False + self.half = lambda: None + self.mix_ratio = mix_ratio + # --------------------------------------------- + # vllm engine 的参数配置 + engine_args = AsyncEngineArgs( + model=model_dir, + **ENGINE_ARGS, + ) + self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args) + + self.speech_token_size = 6564 # 6561 + 3 + self.llm_token_size = 151936 # llm vocab_size + self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1 + self.task_token_id = self.sos_eos_token_id + 1 + self.zero_token_id = self.task_token_id + 1 + + async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ + -> AsyncGenerator[CompletionOutput, None]: + assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" + invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) + assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" + # logging.debug('prompt_token_ids:', prompt_token_ids) + # TODO: 增加上下文控制,取消请求时 + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + async for output in self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ): + yield output.outputs[0] + + + def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\ + -> Generator[CompletionOutput, None, None]: + assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]" + invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None) + assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}" + # logging.debug('prompt_token_ids:', prompt_token_ids) + # TODO: 增加上下文控制,取消请求时 + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + + # 创建独立事件循环 + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + # 初始化异步生成器 + async_gen = self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ) + while True: + try: + # 同步获取异步结果 + output = loop.run_until_complete(async_gen.__anext__()) + yield output.outputs[0] + except StopAsyncIteration: + break + except GeneratorExit: + if async_gen is not None: + loop.run_until_complete(async_gen.aclose()) + raise + finally: + # 资源清理 + print("资源清理...") + if async_gen is not None: + loop.run_until_complete(async_gen.aclose()) + loop.close() + print("资源清理成功") + + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor|int, None, None]: + prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) + prompt_speech_token = tensor_to_list(prompt_speech_token) + + text = tensor_to_list(text + torch.tensor(6564)) + prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \ + [self.task_token_id] + prompt_speech_token + max_tokens = len(text) * 20 + for output in self.llm_inference( + prompt_token_ids, + stop_token_ids=[6561], + max_tokens=max_tokens, + ): + if output.token_ids[-1] == 6561: + need_add_tokens = output.token_ids[:-1] + else: + need_add_tokens = output.token_ids + # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 + # yield need_add_tokens + for token in need_add_tokens: + yield token + + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + last_tokens = [] + prompt_token_ids = [self.sos_eos_token_id] + text_tokens_cache = prompt_text + for this_text in text: + this_text = tensor_to_list(this_text + torch.tensor(6564)) + # text need tokens + assert isinstance(this_text, list), "text need token ids List[int]." + text_tokens_cache += this_text + while len(llm_prompt_speech_token) != 0: + if len(text_tokens_cache) >= self.mix_ratio[0]: + text_input_token = text_tokens_cache[:self.mix_ratio[0]] + speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]] + prompt_token_ids += text_input_token + speech_input_token + # reset the last cache + text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] + llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:] + else: + logging.info('not enough text token to decode, wait for more') + break + if len(llm_prompt_speech_token) == 0: + if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: + logging.info('get fill token, need to append more text token') + if len(text_tokens_cache) >= self.mix_ratio[0]: + text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]] + prompt_token_ids += text_tokens_temp + logging.info('append {} text token'.format(len(text_tokens_temp))) + text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] + else: + logging.info('not enough text token to decode, wait for more') + continue + for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]): + last_tokens = output.token_ids + if last_tokens[-1] == 6563: + need_add_tokens = last_tokens[:-1] + else: + need_add_tokens = last_tokens + # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 + # yield need_add_tokens + for token in need_add_tokens: + yield token + prompt_token_ids.extend(need_add_tokens) + prompt_token_ids += text_tokens_cache + [self.task_token_id] + logging.info('no more text token, decode until met eos') + for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]): + if output.token_ids[-1] == 6561: + need_add_tokens = output.token_ids[:-1] + else: + need_add_tokens = output.token_ids + # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环 + # yield need_add_tokens + for token in need_add_tokens: + yield token diff --git a/cosyvoice/llm/vllm_use_cosyvoice2_model.py b/cosyvoice/llm/vllm_use_cosyvoice2_model.py new file mode 100644 index 0000000..6e36ef3 --- /dev/null +++ b/cosyvoice/llm/vllm_use_cosyvoice2_model.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any +from typing_extensions import TypeVar + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import T +from vllm.model_executor.models.qwen2 import Qwen2Model + +from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings + +logger = init_logger(__name__) + +IGNORE_ID = -1 + + +class CosyVoice2Model(nn.Module): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.llm_input_size = 896 + self.llm_output_size = 896 + + self.speech_token_size = 6561+3 + self.llm_token_size = config.vocab_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + + self.allow_patterns_overrides = ["llm.*"] + self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size) + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size) + self.llm_decoder = ParallelLMHead(self.speech_token_size, + self.llm_output_size, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "llm_decoder")) + self.logits_processor = LogitsProcessor(self.speech_token_size) + + # length_normalized_loss: bool = True, + # lsm_weight: float = 0.0, + # self.criterion_ce = LabelSmoothingLoss( + # size=self.speech_token_size, + # padding_idx=IGNORE_ID, + # smoothing=lsm_weight, + # normalize_length=length_normalized_loss, + # ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size) + + # 4. sampling method + ## use vllm sampling method + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.mix_ratio: List[int] = [5, 15] + + # 定义特殊token常量 + self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32) + self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404 + self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405 + self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32) + + self.zero_embed_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size), + dtype=self.llm_embedding.weight.dtype, + device=self.llm_embedding.weight.device + ) + self.inputs_embed_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size), + dtype=self.llm_embedding.weight.dtype, + device=self.llm_embedding.weight.device, + ) + + def get_sos_eos_emb(self): + return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + + def get_task_id_emb(self): + return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> torch.Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ + # 创建掩码,标记哪些 token_id 属于音频 Token + mask = input_ids < self.speech_token_size + + # 获取 input_ids 的原始形状 + input_shape = input_ids.shape + # 展平 input_ids 和掩码以便统一处理 + flat_input_ids = input_ids.view(-1) + flat_mask = mask.view(-1) + + inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]] + inputs_embeds.zero_() + + # Process speech tokens + if flat_mask.any(): + speech_token_ids = flat_input_ids[flat_mask] + inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids) + + # 处理大于 delta 的 token_id + if (~flat_mask).any(): + llm_token_ids = flat_input_ids[~flat_mask] + llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask]) + + sos_eos_mask = llm_token_ids == self.sos_eos_token_id + task_mask = llm_token_ids == self.task_token_id + zero_mask = llm_token_ids == self.zero_token_id + normal_mask = ~(sos_eos_mask | task_mask | zero_mask) + + # 分层处理逻辑 + # 第一优先级:SOS/EOS标记 + if sos_eos_mask.any(): + llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0) + + # 第二优先级:任务标记 + if task_mask.any(): + llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0) + + # 第二优先级:空音频标记 + if zero_mask.any(): + llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])] + + # 常规LLM token + if normal_mask.any(): + original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta + # print('original_ids: ',original_ids) + llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids) + + inputs_embeds[~flat_mask] = llm_embeds + + inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size) + + # 合并多模态嵌入(如果有) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.audio_token_index + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings( + input_ids, + attn_metadata=attn_metadata, + ) + return self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.llm_decoder, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + @staticmethod + def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]: + for name, param in weights: + # 处理Qwen2Model核心参数 + if name.startswith("llm."): + if name.startswith("llm.model.model."): + name = name.replace("llm.model.model.", "model.") + else: + continue + # print('weights name: ', name) + yield name, param + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.convert_weights(weights) + loader = AutoWeightsLoader(self) + loader.load_weights(weights)