初步合并vllm支持,异步推理的通道处理还存在bug

This commit is contained in:
qihua
2025-03-07 20:26:19 +08:00
parent fd45708e4b
commit 90b666ea20
5 changed files with 658 additions and 8 deletions

View File

@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
@@ -63,6 +63,9 @@ class CosyVoice:
spks = list(self.frontend.spk2info.keys())
return spks
def add_spk_info(self, spk_id, spk_info):
self.frontend.add_spk_info(spk_id, spk_info)
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id)
@@ -88,6 +91,22 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
"""使用预定义的说话人执行 zero_shot 推理"""
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
start_time = time.time()
last_time = start_time
chunk_index = 0
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
yield model_output
last_time = time.time()
chunk_index += 1
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
@@ -126,7 +145,7 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
@@ -145,7 +164,14 @@ class CosyVoice2(CosyVoice):
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
if use_vllm:
try:
self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
except Exception as e:
logging.warning(f'use vllm inference failed. \n{e}')
raise e
else:
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
@@ -171,3 +197,14 @@ class CosyVoice2(CosyVoice):
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Generator
from typing import Generator, Optional
import json
import onnxruntime
import torch
@@ -24,6 +24,8 @@ import torchaudio
import os
import re
import inflect
from pydantic import BaseModel, ConfigDict
try:
import ttsfrd
use_ttsfrd = True
@@ -36,6 +38,18 @@ from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
class SpeakerInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: Optional[str] = None
spk_id: str
prompt_text: str
prompt_text_token: torch.Tensor
speech_feat: torch.Tensor
speech_token: torch.Tensor
embedding: torch.Tensor
class CosyVoiceFrontEnd:
def __init__(self,
@@ -55,8 +69,9 @@ class CosyVoiceFrontEnd:
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"])
self.spk2info_path = spk2info
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device)
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False)
else:
self.spk2info = {}
self.allowed_special = allowed_special
@@ -68,7 +83,8 @@ class CosyVoiceFrontEnd:
'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg')
else:
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
# self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
@@ -86,8 +102,9 @@ class CosyVoiceFrontEnd:
def _extract_text_token_generator(self, text_generator):
for text in text_generator:
text_token, _ = self._extract_text_token(text)
for i in range(text_token.shape[1]):
yield text_token[:, i: i + 1]
# for i in range(text_token.shape[1]):
# yield text_token[:, i: i + 1]
yield text_token
def _extract_speech_token(self, speech):
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -138,11 +155,15 @@ class CosyVoiceFrontEnd:
text = text.replace(" - ", "")
text = remove_bracket(text)
text = re.sub(r'[,、]+$', '', text)
if not split:
return text
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
else:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
if not split:
return text
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
texts = [i for i in texts if not is_only_punctuation(i)]
@@ -151,6 +172,7 @@ class CosyVoiceFrontEnd:
def frontend_sft(self, tts_text, spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
embedding = self.spk2info[spk_id]['embedding']
assert embedding is not None
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
@@ -209,3 +231,60 @@ class CosyVoiceFrontEnd:
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
'flow_embedding': embedding}
return model_input
def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate:int=24000, name: str=None):
assert isinstance(spk_id, str)
assert spk_id not in self.spk2info, "spk_id already exists"
prompt_text_token, _ = self._extract_text_token(prompt_text)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
speech_feat, _ = self._extract_speech_feat(prompt_speech_resample)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
if resample_rate == 24000:
# cosyvoice2, force speech_feat % speech_token = 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat = speech_feat[:, :2 * token_len]
speech_token = speech_token[:, :token_len]
embedding = self._extract_spk_embedding(prompt_speech_16k)
spk_info = SpeakerInfo(
name=name,
spk_id=spk_id,
prompt_text=prompt_text,
prompt_text_token=prompt_text_token,
speech_feat=speech_feat,
speech_token=speech_token,
embedding=embedding,
)
self.add_spk_info(spk_id, spk_info)
def add_spk_info(self, spk_id: str, spk_info: dict|SpeakerInfo):
if isinstance(spk_info, BaseModel):
spk_info = spk_info.model_dump()
self.spk2info[spk_id] = spk_info
if self.spk2info_path:
torch.save(self.spk2info, self.spk2info_path)
def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id):
assert spk_id in self.spk2info
tts_text_token, _ = self._extract_text_token(tts_text)
prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>')
model_input = {'text': tts_text_token,
'prompt_text': prompt_text_token,
'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
'llm_embedding': self.spk2info[spk_id]['embedding'],
'flow_embedding': self.spk2info[spk_id]['embedding'],
}
return model_input
def frontend_zero_shot_by_spk_id(self, tts_text, spk_id):
assert spk_id in self.spk2info
tts_text_token, _ = self._extract_text_token(tts_text)
model_input = {'text': tts_text_token,
'prompt_text': self.spk2info[spk_id]['prompt_text_token'],
'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
'llm_embedding': self.spk2info[spk_id]['embedding'],
'flow_embedding': self.spk2info[spk_id]['embedding']
}
return model_input

View File

@@ -409,3 +409,26 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
torch.cuda.empty_cache()
class VllmCosyVoice2Model(CosyVoice2Model):
def __init__(self,
model_dir: str,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool):
try:
from cosyvoice.llm.llm_vllm import VllmQwen2LM
except Exception as e:
raise e
llm = VllmQwen2LM(model_dir)
super().__init__(llm,flow,hift,fp16)
def load(self, llm_model, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in
torch.load(hift_model, weights_only=True, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()

248
cosyvoice/llm/llm_vllm.py Normal file
View File

@@ -0,0 +1,248 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import time
from typing import List, Generator, AsyncGenerator
import torch
from cosyvoice.utils.file_utils import logging
from cosyvoice.llm.llm import Qwen2LM
# 启用vllm V1版本
import os
os.environ["VLLM_USE_V1"] = '1'
from vllm import ModelRegistry
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
# EngineArgs
ENGINE_ARGS = {
"block_size": 16,
"swap_space": 0,
# "enforce_eager": True,
"gpu_memory_utilization": 0.4,
"max_num_batched_tokens": 1024,
"max_model_len": 1024,
"max_num_seqs": 256,
"disable_log_requests": True,
"disable_log_stats": True,
}
from vllm.sampling_params import RequestOutputKind
# SamplingParams
SAMPLING_PARAMS = {
"temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频或者无法正常生成语音Token
"top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频或者无法正常生成语音Token
"top_k": 25,
# "min_tokens": 80, # 不支持设置最小的tokens数量设置开启后vllm直接崩溃无法启动
# "presence_penalty": 1.0, # 不支持设置
# "frequency_penalty": 0.0, # 不支持设置
"max_tokens": 1024,
"detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效待后续版本更新后减少计算
"ignore_eos": False,
"output_kind": RequestOutputKind.DELTA # 设置为DELTA如调整该参数请同时调整llm_inference的处理代码
}
def tensor_to_list(tensor: torch.tensor):
return tensor.view(-1).cpu().numpy().tolist()
class VllmQwen2LM(Qwen2LM):
def __init__(
self,
model_dir,
mix_ratio: List[int] = [5, 15],
):
self.fp16 = False
self.half = lambda: None
self.mix_ratio = mix_ratio
# ---------------------------------------------
# vllm engine 的参数配置
engine_args = AsyncEngineArgs(
model=model_dir,
**ENGINE_ARGS,
)
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
self.speech_token_size = 6564 # 6561 + 3
self.llm_token_size = 151936 # llm vocab_size
self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
self.task_token_id = self.sos_eos_token_id + 1
self.zero_token_id = self.task_token_id + 1
async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> AsyncGenerator[CompletionOutput, None]:
assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
# logging.debug('prompt_token_ids:', prompt_token_ids)
# TODO: 增加上下文控制,取消请求时
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
sampling_params.max_tokens = max_tokens
async for output in self.llm_engine.generate(
{
"prompt_token_ids": prompt_token_ids,
},
sampling_params=sampling_params,
request_id=request_id or f"{time.time()}",
):
yield output.outputs[0]
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> Generator[CompletionOutput, None, None]:
assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
# logging.debug('prompt_token_ids:', prompt_token_ids)
# TODO: 增加上下文控制,取消请求时
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
sampling_params.max_tokens = max_tokens
# 创建独立事件循环
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
# 初始化异步生成器
async_gen = self.llm_engine.generate(
{
"prompt_token_ids": prompt_token_ids,
},
sampling_params=sampling_params,
request_id=request_id or f"{time.time()}",
)
while True:
try:
# 同步获取异步结果
output = loop.run_until_complete(async_gen.__anext__())
yield output.outputs[0]
except StopAsyncIteration:
break
except GeneratorExit:
if async_gen is not None:
loop.run_until_complete(async_gen.aclose())
raise
finally:
# 资源清理
print("资源清理...")
if async_gen is not None:
loop.run_until_complete(async_gen.aclose())
loop.close()
print("资源清理成功")
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor|int, None, None]:
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
prompt_speech_token = tensor_to_list(prompt_speech_token)
text = tensor_to_list(text + torch.tensor(6564))
prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
[self.task_token_id] + prompt_speech_token
max_tokens = len(text) * 20
for output in self.llm_inference(
prompt_token_ids,
stop_token_ids=[6561],
max_tokens=max_tokens,
):
if output.token_ids[-1] == 6561:
need_add_tokens = output.token_ids[:-1]
else:
need_add_tokens = output.token_ids
# 单个token 循环处理比较耗时建议是在model中进行批量extend处理减少循环
# yield need_add_tokens
for token in need_add_tokens:
yield token
def inference_bistream(
self,
text: Generator,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
last_tokens = []
prompt_token_ids = [self.sos_eos_token_id]
text_tokens_cache = prompt_text
for this_text in text:
this_text = tensor_to_list(this_text + torch.tensor(6564))
# text need tokens
assert isinstance(this_text, list), "text need token ids List[int]."
text_tokens_cache += this_text
while len(llm_prompt_speech_token) != 0:
if len(text_tokens_cache) >= self.mix_ratio[0]:
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
prompt_token_ids += text_input_token + speech_input_token
# reset the last cache
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
else:
logging.info('not enough text token to decode, wait for more')
break
if len(llm_prompt_speech_token) == 0:
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
logging.info('get fill token, need to append more text token')
if len(text_tokens_cache) >= self.mix_ratio[0]:
text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
prompt_token_ids += text_tokens_temp
logging.info('append {} text token'.format(len(text_tokens_temp)))
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
else:
logging.info('not enough text token to decode, wait for more')
continue
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
last_tokens = output.token_ids
if last_tokens[-1] == 6563:
need_add_tokens = last_tokens[:-1]
else:
need_add_tokens = last_tokens
# 单个token 循环处理比较耗时建议是在model中进行批量extend处理减少循环
# yield need_add_tokens
for token in need_add_tokens:
yield token
prompt_token_ids.extend(need_add_tokens)
prompt_token_ids += text_tokens_cache + [self.task_token_id]
logging.info('no more text token, decode until met eos')
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
if output.token_ids[-1] == 6561:
need_add_tokens = output.token_ids[:-1]
else:
need_add_tokens = output.token_ids
# 单个token 循环处理比较耗时建议是在model中进行批量extend处理减少循环
# yield need_add_tokens
for token in need_add_tokens:
yield token

View File

@@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
from typing_extensions import TypeVar
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import T
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
logger = init_logger(__name__)
IGNORE_ID = -1
class CosyVoice2Model(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.llm_input_size = 896
self.llm_output_size = 896
self.speech_token_size = 6561+3
self.llm_token_size = config.vocab_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.task_id = 1
self.fill_token = 2
self.allow_patterns_overrides = ["llm.*"]
self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
self.llm_decoder = ParallelLMHead(self.speech_token_size,
self.llm_output_size,
bias=True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "llm_decoder"))
self.logits_processor = LogitsProcessor(self.speech_token_size)
# length_normalized_loss: bool = True,
# lsm_weight: float = 0.0,
# self.criterion_ce = LabelSmoothingLoss(
# size=self.speech_token_size,
# padding_idx=IGNORE_ID,
# smoothing=lsm_weight,
# normalize_length=length_normalized_loss,
# )
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
# 4. sampling method
## use vllm sampling method
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.mix_ratio: List[int] = [5, 15]
# 定义特殊token常量
self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404
self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405
self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
self.zero_embed_buffer = torch.zeros(
(vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
dtype=self.llm_embedding.weight.dtype,
device=self.llm_embedding.weight.device
)
self.inputs_embed_buffer = torch.zeros(
(vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
dtype=self.llm_embedding.weight.dtype,
device=self.llm_embedding.weight.device,
)
def get_sos_eos_emb(self):
return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
def get_task_id_emb(self):
return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
kwargs.
"""
# 创建掩码,标记哪些 token_id 属于音频 Token
mask = input_ids < self.speech_token_size
# 获取 input_ids 的原始形状
input_shape = input_ids.shape
# 展平 input_ids 和掩码以便统一处理
flat_input_ids = input_ids.view(-1)
flat_mask = mask.view(-1)
inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
inputs_embeds.zero_()
# Process speech tokens
if flat_mask.any():
speech_token_ids = flat_input_ids[flat_mask]
inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
# 处理大于 delta 的 token_id
if (~flat_mask).any():
llm_token_ids = flat_input_ids[~flat_mask]
llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
sos_eos_mask = llm_token_ids == self.sos_eos_token_id
task_mask = llm_token_ids == self.task_token_id
zero_mask = llm_token_ids == self.zero_token_id
normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
# 分层处理逻辑
# 第一优先级SOS/EOS标记
if sos_eos_mask.any():
llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
# 第二优先级:任务标记
if task_mask.any():
llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
# 第二优先级:空音频标记
if zero_mask.any():
llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
# 常规LLM token
if normal_mask.any():
original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
# print('original_ids: ',original_ids)
llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
inputs_embeds[~flat_mask] = llm_embeds
inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
# 合并多模态嵌入(如果有)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(
input_ids,
attn_metadata=attn_metadata,
)
return self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.llm_decoder, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
@staticmethod
def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
for name, param in weights:
# 处理Qwen2Model核心参数
if name.startswith("llm."):
if name.startswith("llm.model.model."):
name = name.replace("llm.model.model.", "model.")
else:
continue
# print('weights name: ', name)
yield name, param
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = self.convert_weights(weights)
loader = AutoWeightsLoader(self)
loader.load_weights(weights)