Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)

Commit: fix lint
.github/workflows/lint.yml (vendored): 2 lines changed
@@ -52,5 +52,5 @@ jobs:
 set -eux
 pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
 flake8 --version
-flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
+flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
 if [ $? != 0 ]; then exit 1; fi
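
The four checks newly added to the ignore list are F401 (module imported but unused), F403 (star import used), F405 (name may be undefined, or defined from star imports) and F841 (local variable assigned but never used). They are presumably silenced because the new vllm model module added further below pulls most of its names in via a star import from vllm.model_executor.models.qwen2, which flake8 would otherwise flag.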
@@ -48,7 +48,7 @@ class CosyVoice:
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
@@ -59,6 +59,7 @@ class CosyVoice:
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
                                 self.fp16)
         del configs

@@ -162,7 +163,7 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
@@ -173,6 +174,7 @@ class CosyVoice2(CosyVoice):
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
                                 self.fp16)
         del configs

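
Taken together, these hunks move trt_concurrent out of the CosyVoiceModel/CosyVoice2Model constructors and hand it to load_trt instead. A minimal sketch of the resulting call pattern, assuming the import path cosyvoice.cli.model used elsewhere in the repo; build_model is a hypothetical helper, and configs/model_dir stand in for the values CosyVoice.__init__ prepares:

from cosyvoice.cli.model import CosyVoiceModel

def build_model(configs, model_dir, fp16=False, trt_concurrent=1):
    # trt_concurrent is no longer a constructor argument ...
    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
    model.load('{}/llm.pt'.format(model_dir),
               '{}/flow.pt'.format(model_dir),
               '{}/hift.pt'.format(model_dir))
    # ... it now travels with the TensorRT plan/onnx paths and the fp16 flag into
    # load_trt, which forwards it to TrtContextWrapper (see the hunks below).
    model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if fp16 else 'fp32'),
                   '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                   trt_concurrent,
                   fp16)
    return model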
@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
 from typing import Generator
-import queue
 import torch
 import numpy as np
 import threading
@@ -33,14 +32,12 @@ class CosyVoiceModel:
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 trt_concurrent: int = 1):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -85,7 +82,7 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
@@ -94,7 +91,7 @@ class CosyVoiceModel:
         with open(flow_decoder_estimator_model, 'rb') as f:
             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
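
get_trt_kwargs (context lines above) returns per-input (min, opt, max) shapes plus input_names, which describe a TensorRT optimization profile for the flow decoder estimator's dynamic axes. A hypothetical illustration of how a converter such as convert_onnx_to_trt might consume that dictionary; this is not the repo's actual implementation, only a sketch of what the kwargs encode:

def add_optimization_profile(builder, config, trt_kwargs):
    # builder: tensorrt.Builder, config: tensorrt.IBuilderConfig (hypothetical helper).
    # One profile covering every dynamic input, with its (min, opt, max) shapes.
    profile = builder.create_optimization_profile()
    for name, min_s, opt_s, max_s in zip(trt_kwargs['input_names'],
                                         trt_kwargs['min_shape'],
                                         trt_kwargs['opt_shape'],
                                         trt_kwargs['max_shape']):
        profile.set_shape(name, min=min_s, opt=opt_s, max=max_s)
    config.add_optimization_profile(profile)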
@@ -104,7 +101,7 @@ class CosyVoiceModel:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                 for i in self.llm.inference_bistream(text=text,
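
The reworked guard enables fp16 autocast only when the LLM is the in-process PyTorch module; when a vLLM backend is attached (the diff detects this through a vllm attribute on self.llm), vLLM manages precision itself and autocast stays off. Restated as a standalone predicate with the same names:

import torch

def llm_autocast_enabled(llm: torch.nn.Module, fp16: bool) -> bool:
    # fp16 autocast applies only to the native PyTorch LLM; a vLLM-backed llm
    # (one exposing a 'vllm' attribute) handles its own dtype.
    return fp16 is True and hasattr(llm, 'vllm') is False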
@@ -246,14 +243,12 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 trt_concurrent: int = 1):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,212 +0,0 @@
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import time
-import queue
-import asyncio
-import threading
-from typing import List, Generator, AsyncGenerator
-import torch
-from cosyvoice.utils.file_utils import logging
-from cosyvoice.llm.llm import Qwen2LM
-
-# enable the vllm V1 engine
-import os
-os.environ["VLLM_USE_V1"] = '1'
-from vllm import ModelRegistry
-from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
-from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
-from vllm.sampling_params import SamplingParams
-
-from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
-ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
-
-# EngineArgs
-ENGINE_ARGS = {
-    "block_size": 16,
-    "swap_space": 0,
-    # "enforce_eager": True,
-    "gpu_memory_utilization": 0.4,
-    "max_num_batched_tokens": 1024,
-    "max_model_len": 1024,
-    "max_num_seqs": 256,
-    "disable_log_requests": True,
-    "disable_log_stats": True,
-    "dtype": "float16"
-}
-
-from vllm.sampling_params import RequestOutputKind
-# SamplingParams
-SAMPLING_PARAMS = {
-    "temperature": 1,  # must not go below 0.8, otherwise the model produces large amounts of empty audio or fails to generate speech tokens
-    "top_p": 1,  # must not go below 0.8, otherwise the model produces large amounts of empty audio or fails to generate speech tokens
-    "top_k": 25,
-    # "min_tokens": 80,  # setting a minimum token count is not supported; enabling it crashes vllm at startup
-    # "presence_penalty": 1.0,  # not supported
-    # "frequency_penalty": 0.0,  # not supported
-    "max_tokens": 1024,
-    "detokenize": False,  # currently has no effect in vllm 0.7.3 V1; revisit in later releases to cut unnecessary computation
-    "ignore_eos": False,
-    "output_kind": RequestOutputKind.DELTA  # set to DELTA; if this is changed, adjust the handling code in llm_inference accordingly
-}
-
-def tensor_to_list(tensor: torch.tensor):
-    return tensor.view(-1).cpu().numpy().tolist()
-
-class VllmQwen2LM(Qwen2LM):
-    def __init__(
-            self,
-            model_dir,
-            mix_ratio: List[int] = [5, 15],
-    ):
-        self.fp16 = False
-        self.half = lambda: None
-        self.mix_ratio = mix_ratio
-        # ---------------------------------------------
-        # vllm engine configuration
-        engine_args = AsyncEngineArgs(
-            model=model_dir,
-            **ENGINE_ARGS,
-        )
-        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
-
-        self.speech_token_size = 6564  # 6561 + 3
-        self.llm_token_size = 151936  # llm vocab_size
-        self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
-        self.task_token_id = self.sos_eos_token_id + 1
-        self.zero_token_id = self.task_token_id + 1
-
-        # vllm inference has to run inside one fixed event loop, so start a background thread dedicated to it
-        self.loop = asyncio.new_event_loop()
-        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
-        self.loop_thread.start()
-
-    def _run_event_loop(self):
-        asyncio.set_event_loop(self.loop)
-        self.loop.run_forever()
-
-    async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
-        sampling_params = SamplingParams(**SAMPLING_PARAMS)
-        sampling_params.stop_token_ids = stop_token_ids or [6561]
-        if max_tokens:
-            sampling_params.max_tokens = max_tokens
-        async for output in self.llm_engine.generate(
-                {
-                    "prompt_token_ids": prompt_token_ids,
-                },
-                sampling_params=sampling_params,
-                request_id=request_id or f"{time.time()}",
-        ):
-            out_queue.put((output.outputs[0], output.finished))
-
-    def llm_inference(self, prompt_token_ids: List[int], request_id: str = None, stop_token_ids=None, max_tokens=None):
-        out_queue = queue.Queue()
-        asyncio.run_coroutine_threadsafe(
-            self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
-        )
-        # collect the results coming back through out_queue
-        finished = False
-        while not finished:
-            (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
-            yield output
-
-    def inference(
-            self,
-            text: torch.Tensor,
-            text_len: torch.Tensor,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-    ) -> Generator[torch.Tensor | int, None, None]:
-        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
-        prompt_speech_token = tensor_to_list(prompt_speech_token)
-
-        text = tensor_to_list(text + torch.tensor(6564))
-        prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
-                           [self.task_token_id] + prompt_speech_token
-        max_tokens = len(text) * 20
-        for output in self.llm_inference(
-                prompt_token_ids,
-                stop_token_ids=[6561],
-                max_tokens=max_tokens,
-        ):
-            if output.token_ids[-1] == 6561:
-                need_add_tokens = output.token_ids[:-1]
-            else:
-                need_add_tokens = output.token_ids
-            for token in need_add_tokens:
-                yield token
-
-    def inference_bistream(
-            self,
-            text: Generator,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-    ) -> Generator[torch.Tensor, None, None]:
-        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
-        prompt_speech_token = tensor_to_list(prompt_speech_token)
-
-        last_tokens = []
-        prompt_token_ids = [self.sos_eos_token_id]
-        text_tokens_cache = prompt_text
-        for this_text in text:
-            this_text = tensor_to_list(this_text + torch.tensor(6564))
-            # text need tokens
-            assert isinstance(this_text, list), "text need token ids List[int]."
-            text_tokens_cache += this_text
-            while len(prompt_speech_token) != 0:
-                if len(text_tokens_cache) >= self.mix_ratio[0]:
-                    text_input_token = text_tokens_cache[:self.mix_ratio[0]]
-                    speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
-                    prompt_token_ids += text_input_token + speech_input_token
-                    # reset the last cache
-                    text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
-                    prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
-                else:
-                    break
-            if len(prompt_speech_token) == 0:
-                if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
-                    if len(text_tokens_cache) >= self.mix_ratio[0]:
-                        text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
-                        prompt_token_ids += text_tokens_temp
-                        text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
-                    else:
-                        continue
-                for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
-                    last_tokens = output.token_ids
-                    if last_tokens[-1] == 6563:
-                        need_add_tokens = last_tokens[:-1]
-                    else:
-                        need_add_tokens = last_tokens
-                    for token in need_add_tokens:
-                        yield token
-                    prompt_token_ids.extend(need_add_tokens)
-        prompt_token_ids += text_tokens_cache + [self.task_token_id]
-        for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
-            if output.token_ids[-1] == 6561:
-                need_add_tokens = output.token_ids[:-1]
-            else:
-                need_add_tokens = output.token_ids
-            for token in need_add_tokens:
-                yield token
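
The core trick in the deleted wrapper is bridging vLLM's async generate() API into the synchronous generator interface the rest of CosyVoice expects: a dedicated event loop runs on a daemon thread and results flow back through a queue. A generic, self-contained sketch of that pattern (class and method names here are illustrative, not from the repo):

import asyncio
import queue
import threading

class AsyncToSyncBridge:
    """Run coroutines on a background event loop and consume async generators synchronously."""

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        threading.Thread(target=self._run_loop, daemon=True).start()

    def _run_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def iterate(self, agen_factory):
        # agen_factory: zero-argument callable returning an async generator,
        # e.g. lambda: engine.generate(prompt, sampling_params, request_id)
        out = queue.Queue()
        sentinel = object()

        async def pump():
            async for item in agen_factory():
                out.put(item)
            out.put(sentinel)

        asyncio.run_coroutine_threadsafe(pump(), self.loop)
        while (item := out.get()) is not sentinel:
            yield item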
@@ -16,7 +16,8 @@

 import os
 import json
-import torch, torchaudio
+import torch
+import torchaudio
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 logging.basicConfig(level=logging.DEBUG,
cosyvoice/vllm/cosyvoice2.py (new file): 103 lines added
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2 model compatible with HuggingFace weights."""
+from vllm.model_executor.models.qwen2 import *
+
+
+class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              True,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata, self.lm_head.bias)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
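
This module only defines the vLLM-side model class, so it still has to be registered with vLLM before an engine is built. A sketch of the expected wiring, assuming the registration key matches the architecture string in the served model's config.json; the deleted wrapper above registered its older class the same way:

from vllm import ModelRegistry

from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM

# Register before constructing LLMEngine/AsyncLLMEngine so that vLLM can map
# the "CosyVoice2ForCausalLM" architecture name to this implementation.
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)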