初步合并vllm支持,异步推理的通道处理还存在bug

This commit is contained in:
qihua
2025-03-07 20:26:19 +08:00
parent fd45708e4b
commit 90b666ea20
5 changed files with 658 additions and 8 deletions

View File

@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download from modelscope import snapshot_download
import torch import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, VllmCosyVoice2Model
from cosyvoice.utils.file_utils import logging from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type from cosyvoice.utils.class_utils import get_model_type
@@ -63,6 +63,9 @@ class CosyVoice:
spks = list(self.frontend.spk2info.keys()) spks = list(self.frontend.spk2info.keys())
return spks return spks
def add_spk_info(self, spk_id, spk_info):
self.frontend.add_spk_info(spk_id, spk_info)
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True): def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id) model_input = self.frontend.frontend_sft(i, spk_id)
@@ -88,6 +91,22 @@ class CosyVoice:
yield model_output yield model_output
start_time = time.time() start_time = time.time()
def inference_zero_shot_by_spk_id(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
"""使用预定义的说话人执行 zero_shot 推理"""
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_zero_shot_by_spk_id(i, spk_id)
start_time = time.time()
last_time = start_time
chunk_index = 0
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech index:{}, len {:.2f}, rtf {:.3f}, cost {:.3f}s, all cost time {:.3f}s'.format(
chunk_index, speech_len, (time.time()-last_time)/speech_len, time.time()-last_time, time.time()-start_time))
yield model_output
last_time = time.time()
chunk_index += 1
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate) model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
@@ -126,7 +145,7 @@ class CosyVoice:
class CosyVoice2(CosyVoice): class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vllm=False):
self.instruct = True if '-Instruct' in model_dir else False self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir self.model_dir = model_dir
self.fp16 = fp16 self.fp16 = fp16
@@ -145,7 +164,14 @@ class CosyVoice2(CosyVoice):
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16) if use_vllm:
try:
self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
except Exception as e:
logging.warning(f'use vllm inference failed. \n{e}')
raise e
else:
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir), self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir), '{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir)) '{}/hift.pt'.format(model_dir))
@@ -171,3 +197,14 @@ class CosyVoice2(CosyVoice):
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output yield model_output
start_time = time.time() start_time = time.time()
def inference_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2_by_spk_id(i, instruct_text, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
from typing import Generator from typing import Generator, Optional
import json import json
import onnxruntime import onnxruntime
import torch import torch
@@ -24,6 +24,8 @@ import torchaudio
import os import os
import re import re
import inflect import inflect
from pydantic import BaseModel, ConfigDict
try: try:
import ttsfrd import ttsfrd
use_ttsfrd = True use_ttsfrd = True
@@ -36,6 +38,18 @@ from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
class SpeakerInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: Optional[str] = None
spk_id: str
prompt_text: str
prompt_text_token: torch.Tensor
speech_feat: torch.Tensor
speech_token: torch.Tensor
embedding: torch.Tensor
class CosyVoiceFrontEnd: class CosyVoiceFrontEnd:
def __init__(self, def __init__(self,
@@ -55,8 +69,9 @@ class CosyVoiceFrontEnd:
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"]) "CPUExecutionProvider"])
self.spk2info_path = spk2info
if os.path.exists(spk2info): if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device) self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=False)
else: else:
self.spk2info = {} self.spk2info = {}
self.allowed_special = allowed_special self.allowed_special = allowed_special
@@ -68,7 +83,8 @@ class CosyVoiceFrontEnd:
'failed to initialize ttsfrd resource' 'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg') self.frd.set_lang_type('pinyinvg')
else: else:
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) # self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=False)
self.en_tn_model = EnNormalizer() self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine() self.inflect_parser = inflect.engine()
@@ -86,8 +102,9 @@ class CosyVoiceFrontEnd:
def _extract_text_token_generator(self, text_generator): def _extract_text_token_generator(self, text_generator):
for text in text_generator: for text in text_generator:
text_token, _ = self._extract_text_token(text) text_token, _ = self._extract_text_token(text)
for i in range(text_token.shape[1]): # for i in range(text_token.shape[1]):
yield text_token[:, i: i + 1] # yield text_token[:, i: i + 1]
yield text_token
def _extract_speech_token(self, speech): def _extract_speech_token(self, speech):
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -138,11 +155,15 @@ class CosyVoiceFrontEnd:
text = text.replace(" - ", "") text = text.replace(" - ", "")
text = remove_bracket(text) text = remove_bracket(text)
text = re.sub(r'[,、]+$', '', text) text = re.sub(r'[,、]+$', '', text)
if not split:
return text
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False)) token_min_n=60, merge_len=20, comma_split=False))
else: else:
text = self.en_tn_model.normalize(text) text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser) text = spell_out_number(text, self.inflect_parser)
if not split:
return text
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False)) token_min_n=60, merge_len=20, comma_split=False))
texts = [i for i in texts if not is_only_punctuation(i)] texts = [i for i in texts if not is_only_punctuation(i)]
@@ -151,6 +172,7 @@ class CosyVoiceFrontEnd:
def frontend_sft(self, tts_text, spk_id): def frontend_sft(self, tts_text, spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
embedding = self.spk2info[spk_id]['embedding'] embedding = self.spk2info[spk_id]['embedding']
assert embedding is not None
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding} model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input return model_input
@@ -209,3 +231,60 @@ class CosyVoiceFrontEnd:
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len, 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
'flow_embedding': embedding} 'flow_embedding': embedding}
return model_input return model_input
def generate_spk_info(self, spk_id: str, prompt_text: str, prompt_speech_16k: torch.Tensor, resample_rate:int=24000, name: str=None):
assert isinstance(spk_id, str)
assert spk_id not in self.spk2info, "spk_id already exists"
prompt_text_token, _ = self._extract_text_token(prompt_text)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
speech_feat, _ = self._extract_speech_feat(prompt_speech_resample)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
if resample_rate == 24000:
# cosyvoice2, force speech_feat % speech_token = 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat = speech_feat[:, :2 * token_len]
speech_token = speech_token[:, :token_len]
embedding = self._extract_spk_embedding(prompt_speech_16k)
spk_info = SpeakerInfo(
name=name,
spk_id=spk_id,
prompt_text=prompt_text,
prompt_text_token=prompt_text_token,
speech_feat=speech_feat,
speech_token=speech_token,
embedding=embedding,
)
self.add_spk_info(spk_id, spk_info)
def add_spk_info(self, spk_id: str, spk_info: dict|SpeakerInfo):
if isinstance(spk_info, BaseModel):
spk_info = spk_info.model_dump()
self.spk2info[spk_id] = spk_info
if self.spk2info_path:
torch.save(self.spk2info, self.spk2info_path)
def frontend_instruct2_by_spk_id(self, tts_text, instruct_text, spk_id):
assert spk_id in self.spk2info
tts_text_token, _ = self._extract_text_token(tts_text)
prompt_text_token, _ = self._extract_text_token(instruct_text + '<|endofprompt|>')
model_input = {'text': tts_text_token,
'prompt_text': prompt_text_token,
'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
'llm_embedding': self.spk2info[spk_id]['embedding'],
'flow_embedding': self.spk2info[spk_id]['embedding'],
}
return model_input
def frontend_zero_shot_by_spk_id(self, tts_text, spk_id):
assert spk_id in self.spk2info
tts_text_token, _ = self._extract_text_token(tts_text)
model_input = {'text': tts_text_token,
'prompt_text': self.spk2info[spk_id]['prompt_text_token'],
'llm_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
'flow_prompt_speech_token': self.spk2info[spk_id]['speech_token'],
'prompt_speech_feat': self.spk2info[spk_id]['speech_feat'],
'llm_embedding': self.spk2info[spk_id]['embedding'],
'flow_embedding': self.spk2info[spk_id]['embedding']
}
return model_input

View File

@@ -409,3 +409,26 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict.pop(this_uuid) self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid)
torch.cuda.empty_cache() torch.cuda.empty_cache()
class VllmCosyVoice2Model(CosyVoice2Model):
def __init__(self,
model_dir: str,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool):
try:
from cosyvoice.llm.llm_vllm import VllmQwen2LM
except Exception as e:
raise e
llm = VllmQwen2LM(model_dir)
super().__init__(llm,flow,hift,fp16)
def load(self, llm_model, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in
torch.load(hift_model, weights_only=True, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()

248
cosyvoice/llm/llm_vllm.py Normal file
View File

@@ -0,0 +1,248 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import time
from typing import List, Generator, AsyncGenerator
import torch
from cosyvoice.utils.file_utils import logging
from cosyvoice.llm.llm import Qwen2LM
# 启用vllm V1版本
import os
os.environ["VLLM_USE_V1"] = '1'
from vllm import ModelRegistry
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
# EngineArgs
ENGINE_ARGS = {
"block_size": 16,
"swap_space": 0,
# "enforce_eager": True,
"gpu_memory_utilization": 0.4,
"max_num_batched_tokens": 1024,
"max_model_len": 1024,
"max_num_seqs": 256,
"disable_log_requests": True,
"disable_log_stats": True,
}
from vllm.sampling_params import RequestOutputKind
# SamplingParams
SAMPLING_PARAMS = {
"temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频或者无法正常生成语音Token
"top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频或者无法正常生成语音Token
"top_k": 25,
# "min_tokens": 80, # 不支持设置最小的tokens数量设置开启后vllm直接崩溃无法启动
# "presence_penalty": 1.0, # 不支持设置
# "frequency_penalty": 0.0, # 不支持设置
"max_tokens": 1024,
"detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效待后续版本更新后减少计算
"ignore_eos": False,
"output_kind": RequestOutputKind.DELTA # 设置为DELTA如调整该参数请同时调整llm_inference的处理代码
}
def tensor_to_list(tensor: torch.tensor):
return tensor.view(-1).cpu().numpy().tolist()
class VllmQwen2LM(Qwen2LM):
def __init__(
self,
model_dir,
mix_ratio: List[int] = [5, 15],
):
self.fp16 = False
self.half = lambda: None
self.mix_ratio = mix_ratio
# ---------------------------------------------
# vllm engine 的参数配置
engine_args = AsyncEngineArgs(
model=model_dir,
**ENGINE_ARGS,
)
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
self.speech_token_size = 6564 # 6561 + 3
self.llm_token_size = 151936 # llm vocab_size
self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
self.task_token_id = self.sos_eos_token_id + 1
self.zero_token_id = self.task_token_id + 1
async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> AsyncGenerator[CompletionOutput, None]:
assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
# logging.debug('prompt_token_ids:', prompt_token_ids)
# TODO: 增加上下文控制,取消请求时
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
sampling_params.max_tokens = max_tokens
async for output in self.llm_engine.generate(
{
"prompt_token_ids": prompt_token_ids,
},
sampling_params=sampling_params,
request_id=request_id or f"{time.time()}",
):
yield output.outputs[0]
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-> Generator[CompletionOutput, None, None]:
assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
# logging.debug('prompt_token_ids:', prompt_token_ids)
# TODO: 增加上下文控制,取消请求时
sampling_params = SamplingParams(**SAMPLING_PARAMS)
sampling_params.stop_token_ids = stop_token_ids or [6561]
if max_tokens:
sampling_params.max_tokens = max_tokens
# 创建独立事件循环
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
# 初始化异步生成器
async_gen = self.llm_engine.generate(
{
"prompt_token_ids": prompt_token_ids,
},
sampling_params=sampling_params,
request_id=request_id or f"{time.time()}",
)
while True:
try:
# 同步获取异步结果
output = loop.run_until_complete(async_gen.__anext__())
yield output.outputs[0]
except StopAsyncIteration:
break
except GeneratorExit:
if async_gen is not None:
loop.run_until_complete(async_gen.aclose())
raise
finally:
# 资源清理
print("资源清理...")
if async_gen is not None:
loop.run_until_complete(async_gen.aclose())
loop.close()
print("资源清理成功")
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor|int, None, None]:
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
prompt_speech_token = tensor_to_list(prompt_speech_token)
text = tensor_to_list(text + torch.tensor(6564))
prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
[self.task_token_id] + prompt_speech_token
max_tokens = len(text) * 20
for output in self.llm_inference(
prompt_token_ids,
stop_token_ids=[6561],
max_tokens=max_tokens,
):
if output.token_ids[-1] == 6561:
need_add_tokens = output.token_ids[:-1]
else:
need_add_tokens = output.token_ids
# 单个token 循环处理比较耗时建议是在model中进行批量extend处理减少循环
# yield need_add_tokens
for token in need_add_tokens:
yield token
def inference_bistream(
self,
text: Generator,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
last_tokens = []
prompt_token_ids = [self.sos_eos_token_id]
text_tokens_cache = prompt_text
for this_text in text:
this_text = tensor_to_list(this_text + torch.tensor(6564))
# text need tokens
assert isinstance(this_text, list), "text need token ids List[int]."
text_tokens_cache += this_text
while len(llm_prompt_speech_token) != 0:
if len(text_tokens_cache) >= self.mix_ratio[0]:
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
prompt_token_ids += text_input_token + speech_input_token
# reset the last cache
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
else:
logging.info('not enough text token to decode, wait for more')
break
if len(llm_prompt_speech_token) == 0:
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
logging.info('get fill token, need to append more text token')
if len(text_tokens_cache) >= self.mix_ratio[0]:
text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
prompt_token_ids += text_tokens_temp
logging.info('append {} text token'.format(len(text_tokens_temp)))
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
else:
logging.info('not enough text token to decode, wait for more')
continue
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
last_tokens = output.token_ids
if last_tokens[-1] == 6563:
need_add_tokens = last_tokens[:-1]
else:
need_add_tokens = last_tokens
# 单个token 循环处理比较耗时建议是在model中进行批量extend处理减少循环
# yield need_add_tokens
for token in need_add_tokens:
yield token
prompt_token_ids.extend(need_add_tokens)
prompt_token_ids += text_tokens_cache + [self.task_token_id]
logging.info('no more text token, decode until met eos')
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
if output.token_ids[-1] == 6561:
need_add_tokens = output.token_ids[:-1]
else:
need_add_tokens = output.token_ids
# 单个token 循环处理比较耗时建议是在model中进行批量extend处理减少循环
# yield need_add_tokens
for token in need_add_tokens:
yield token

View File

@@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
from typing_extensions import TypeVar
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import T
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
logger = init_logger(__name__)
IGNORE_ID = -1
class CosyVoice2Model(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.llm_input_size = 896
self.llm_output_size = 896
self.speech_token_size = 6561+3
self.llm_token_size = config.vocab_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.task_id = 1
self.fill_token = 2
self.allow_patterns_overrides = ["llm.*"]
self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
# self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
self.llm_decoder = ParallelLMHead(self.speech_token_size,
self.llm_output_size,
bias=True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "llm_decoder"))
self.logits_processor = LogitsProcessor(self.speech_token_size)
# length_normalized_loss: bool = True,
# lsm_weight: float = 0.0,
# self.criterion_ce = LabelSmoothingLoss(
# size=self.speech_token_size,
# padding_idx=IGNORE_ID,
# smoothing=lsm_weight,
# normalize_length=length_normalized_loss,
# )
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
# 4. sampling method
## use vllm sampling method
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.mix_ratio: List[int] = [5, 15]
# 定义特殊token常量
self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404
self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405
self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
self.zero_embed_buffer = torch.zeros(
(vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
dtype=self.llm_embedding.weight.dtype,
device=self.llm_embedding.weight.device
)
self.inputs_embed_buffer = torch.zeros(
(vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
dtype=self.llm_embedding.weight.dtype,
device=self.llm_embedding.weight.device,
)
def get_sos_eos_emb(self):
return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
def get_task_id_emb(self):
return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
kwargs.
"""
# 创建掩码,标记哪些 token_id 属于音频 Token
mask = input_ids < self.speech_token_size
# 获取 input_ids 的原始形状
input_shape = input_ids.shape
# 展平 input_ids 和掩码以便统一处理
flat_input_ids = input_ids.view(-1)
flat_mask = mask.view(-1)
inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
inputs_embeds.zero_()
# Process speech tokens
if flat_mask.any():
speech_token_ids = flat_input_ids[flat_mask]
inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
# 处理大于 delta 的 token_id
if (~flat_mask).any():
llm_token_ids = flat_input_ids[~flat_mask]
llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
sos_eos_mask = llm_token_ids == self.sos_eos_token_id
task_mask = llm_token_ids == self.task_token_id
zero_mask = llm_token_ids == self.zero_token_id
normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
# 分层处理逻辑
# 第一优先级SOS/EOS标记
if sos_eos_mask.any():
llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
# 第二优先级:任务标记
if task_mask.any():
llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
# 第二优先级:空音频标记
if zero_mask.any():
llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
# 常规LLM token
if normal_mask.any():
original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
# print('original_ids: ',original_ids)
llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
inputs_embeds[~flat_mask] = llm_embeds
inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
# 合并多模态嵌入(如果有)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(
input_ids,
attn_metadata=attn_metadata,
)
return self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.llm_decoder, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
@staticmethod
def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
for name, param in weights:
# 处理Qwen2Model核心参数
if name.startswith("llm."):
if name.startswith("llm.model.model."):
name = name.replace("llm.model.model.", "model.")
else:
continue
# print('weights name: ', name)
yield name, param
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weights = self.convert_weights(weights)
loader = AutoWeightsLoader(self)
loader.load_weights(weights)