mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
Merge pull request #1337 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
This commit is contained in:
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -52,5 +52,5 @@ jobs:
|
||||
set -eux
|
||||
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
|
||||
flake8 --version
|
||||
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||
if [ $? != 0 ]; then exit 1; fi
|
||||
29
README.md
29
README.md
@@ -26,6 +26,10 @@
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] 2025/05
|
||||
|
||||
- [x] add cosyvoice 2.0 vllm support
|
||||
|
||||
- [x] 2024/12
|
||||
|
||||
- [x] 25hz cosyvoice 2.0 released
|
||||
@@ -126,7 +130,7 @@ import torchaudio
|
||||
|
||||
#### CosyVoice2 Usage
|
||||
```python
|
||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
|
||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=False, fp16=False)
|
||||
|
||||
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
||||
# zero_shot usage
|
||||
@@ -159,7 +163,28 @@ for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
```
|
||||
|
||||
#### CosyVoice Usage
|
||||
If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vllm version do not support CosyVoice2 inference.
|
||||
|
||||
Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torch==2.7.0`. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
|
||||
|
||||
``` sh
|
||||
conda create -n cosyvoice_vllm --clone cosyvoice
|
||||
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
conda activate cosyvoice_vllm
|
||||
```
|
||||
|
||||
```python
|
||||
import sys
|
||||
sys.path.append('third_party/Matcha-TTS')
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from vllm import ModelRegistry
|
||||
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
|
||||
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
|
||||
|
||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=True, fp16=False)
|
||||
```
|
||||
|
||||
**CosyVoice Usage**
|
||||
```python
|
||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
|
||||
# sft usage
|
||||
|
||||
@@ -48,7 +48,7 @@ class CosyVoice:
|
||||
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
||||
load_jit, load_trt, fp16 = False, False, False
|
||||
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
||||
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
|
||||
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
@@ -59,6 +59,7 @@ class CosyVoice:
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
@@ -140,7 +141,7 @@ class CosyVoice:
|
||||
|
||||
class CosyVoice2(CosyVoice):
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
|
||||
self.instruct = True if '-Instruct' in model_dir else False
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
@@ -162,15 +163,18 @@ class CosyVoice2(CosyVoice):
|
||||
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
||||
load_jit, load_trt, fp16 = False, False, False
|
||||
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
|
||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_vllm:
|
||||
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Generator
|
||||
import queue
|
||||
import torch
|
||||
import numpy as np
|
||||
import threading
|
||||
@@ -23,7 +22,7 @@ from torch.nn import functional as F
|
||||
from contextlib import nullcontext
|
||||
import uuid
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
|
||||
|
||||
@@ -33,14 +32,12 @@ class CosyVoiceModel:
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False,
|
||||
trt_concurrent: int = 1):
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
self.trt_concurrent = trt_concurrent
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
@@ -59,9 +56,6 @@ class CosyVoiceModel:
|
||||
self.stream_scale_factor = 1
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
for _ in range(trt_concurrent):
|
||||
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
@@ -69,7 +63,6 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict = {}
|
||||
self.flow_cache_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.trt_context_dict = {}
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
@@ -89,7 +82,7 @@ class CosyVoiceModel:
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||
@@ -98,7 +91,7 @@ class CosyVoiceModel:
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||
@@ -108,7 +101,7 @@ class CosyVoiceModel:
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
with self.llm_context, torch.cuda.amp.autocast(self.fp16):
|
||||
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
|
||||
if isinstance(text, Generator):
|
||||
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
|
||||
for i in self.llm.inference_bistream(text=text,
|
||||
@@ -125,7 +118,8 @@ class CosyVoiceModel:
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
embedding=llm_embedding.to(self.device),
|
||||
uuid=uuid):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
self.llm_end_dict[uuid] = True
|
||||
|
||||
@@ -180,13 +174,11 @@ class CosyVoiceModel:
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
this_trt_context = self.trt_context_pool.get()
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
||||
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
||||
self.trt_context_dict[this_uuid] = this_trt_context
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
@@ -240,8 +232,6 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
|
||||
self.trt_context_dict.pop(this_uuid)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
@@ -253,14 +243,12 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False,
|
||||
trt_concurrent: int = 1):
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
self.trt_concurrent = trt_concurrent
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
@@ -273,22 +261,28 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||
# rtf and decoding related
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
for _ in range(trt_concurrent):
|
||||
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.trt_context_dict = {}
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_vllm(self, model_dir):
|
||||
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
|
||||
from vllm import EngineArgs, LLMEngine
|
||||
engine_args = EngineArgs(model=model_dir,
|
||||
skip_tokenizer_init=True,
|
||||
enable_prompt_embeds=True,
|
||||
gpu_memory_utilization=0.2)
|
||||
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
|
||||
del self.llm.llm.model.model.layers
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
@@ -330,11 +324,9 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
this_trt_context = self.trt_context_pool.get()
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.trt_context_dict[this_uuid] = this_trt_context
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
@@ -388,8 +380,6 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
|
||||
self.trt_context_dict.pop(this_uuid)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
@@ -12,10 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
from cosyvoice.utils.common import set_all_random_seed
|
||||
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
@@ -32,7 +32,6 @@ class ConditionalCFM(BASECFM):
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
||||
@@ -127,26 +126,27 @@ class ConditionalCFM(BASECFM):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
||||
else:
|
||||
estimator, trt_engine = self.estimator.acquire_estimator()
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator)
|
||||
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
|
||||
with stream:
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator, stream)
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||
@@ -194,6 +194,7 @@ class ConditionalCFM(BASECFM):
|
||||
class CausalConditionalCFM(ConditionalCFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||
set_all_random_seed(0)
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,7 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import queue
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Optional, Callable, List, Generator
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -170,6 +174,7 @@ class TransformerLM(torch.nn.Module):
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
@@ -270,7 +275,6 @@ class Qwen2LM(TransformerLM):
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
|
||||
# 2. build speech token language model related modules
|
||||
self.sos_eos = 0
|
||||
self.task_id = 1
|
||||
@@ -293,6 +297,11 @@ class Qwen2LM(TransformerLM):
|
||||
self.sampling = sampling
|
||||
self.mix_ratio = mix_ratio
|
||||
|
||||
# 5. vllm related
|
||||
self.stop_token_ids = [speech_token_size + i for i in range(3)]
|
||||
self.vllm_output_queue = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
|
||||
lm_target, lm_input = [], []
|
||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||
@@ -382,6 +391,7 @@ class Qwen2LM(TransformerLM):
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
@@ -402,22 +412,55 @@ class Qwen2LM(TransformerLM):
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
|
||||
# 5. step by step decode
|
||||
out_tokens = []
|
||||
cache = None
|
||||
for i in range(max_len):
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
if top_ids > self.speech_token_size:
|
||||
continue
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||
yield token
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
|
||||
if hasattr(self, 'vllm'):
|
||||
from vllm import SamplingParams, RequestOutput
|
||||
sampling_params = SamplingParams(top_k=sampling,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
min_tokens=min_len,
|
||||
max_tokens=max_len)
|
||||
with self.lock:
|
||||
self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
|
||||
self.vllm_output_queue[uuid] = queue.Queue()
|
||||
out_tokens = []
|
||||
while True:
|
||||
with self.lock:
|
||||
if self.vllm_output_queue[uuid].empty() is True:
|
||||
request_outputs: List[RequestOutput] = self.vllm.step()
|
||||
for request_output in request_outputs:
|
||||
top_ids = list(request_output.outputs[0].token_ids)[-1]
|
||||
self.vllm_output_queue[request_output.request_id].put(top_ids)
|
||||
if self.vllm_output_queue[uuid].empty() is False:
|
||||
top_ids = self.vllm_output_queue[uuid].get()
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
time.sleep(0.001)
|
||||
with self.lock:
|
||||
self.vllm_output_queue.pop(uuid)
|
||||
else:
|
||||
out_tokens = []
|
||||
cache = None
|
||||
for i in range(max_len):
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
if top_ids > self.speech_token_size:
|
||||
continue
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_bistream(
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
import queue
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import List, Generator, AsyncGenerator
|
||||
import torch
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.llm.llm import Qwen2LM
|
||||
|
||||
# 启用vllm V1版本
|
||||
import os
|
||||
os.environ["VLLM_USE_V1"] = '1'
|
||||
from vllm import ModelRegistry
|
||||
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
|
||||
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
|
||||
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
|
||||
|
||||
# EngineArgs
|
||||
ENGINE_ARGS = {
|
||||
"block_size": 16,
|
||||
"swap_space": 0,
|
||||
# "enforce_eager": True,
|
||||
"gpu_memory_utilization": 0.4,
|
||||
"max_num_batched_tokens": 1024,
|
||||
"max_model_len": 1024,
|
||||
"max_num_seqs": 256,
|
||||
"disable_log_requests": True,
|
||||
"disable_log_stats": True,
|
||||
"dtype": "float16"
|
||||
}
|
||||
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
# SamplingParams
|
||||
SAMPLING_PARAMS = {
|
||||
"temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
|
||||
"top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
|
||||
"top_k": 25,
|
||||
# "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动
|
||||
# "presence_penalty": 1.0, # 不支持设置
|
||||
# "frequency_penalty": 0.0, # 不支持设置
|
||||
"max_tokens": 1024,
|
||||
"detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算
|
||||
"ignore_eos": False,
|
||||
"output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码
|
||||
}
|
||||
|
||||
def tensor_to_list(tensor: torch.tensor):
|
||||
return tensor.view(-1).cpu().numpy().tolist()
|
||||
|
||||
class VllmQwen2LM(Qwen2LM):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
self.fp16 = False
|
||||
self.half = lambda: None
|
||||
self.mix_ratio = mix_ratio
|
||||
# ---------------------------------------------
|
||||
# vllm engine 的参数配置
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_dir,
|
||||
**ENGINE_ARGS,
|
||||
)
|
||||
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
self.speech_token_size = 6564 # 6561 + 3
|
||||
self.llm_token_size = 151936 # llm vocab_size
|
||||
self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
|
||||
self.task_token_id = self.sos_eos_token_id + 1
|
||||
self.zero_token_id = self.task_token_id + 1
|
||||
|
||||
# vllm 的推理任务需要在一个固定的事件循环中,因此启动一个后台线程运行转用于推理任务
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
|
||||
self.loop_thread.start()
|
||||
|
||||
def _run_event_loop(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self.loop.run_forever()
|
||||
|
||||
async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
|
||||
sampling_params = SamplingParams(**SAMPLING_PARAMS)
|
||||
sampling_params.stop_token_ids = stop_token_ids or [6561]
|
||||
if max_tokens:
|
||||
sampling_params.max_tokens = max_tokens
|
||||
async for output in self.llm_engine.generate(
|
||||
{
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id or f"{time.time()}",
|
||||
):
|
||||
out_queue.put((output.outputs[0], output.finished))
|
||||
|
||||
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
|
||||
out_queue = queue.Queue()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
|
||||
)
|
||||
# 接收 out_queue 返回的结果
|
||||
finished = False
|
||||
while not finished:
|
||||
(output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
|
||||
yield output
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_len: torch.Tensor,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
) -> Generator[torch.Tensor|int, None, None]:
|
||||
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
||||
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
||||
|
||||
text = tensor_to_list(text + torch.tensor(6564))
|
||||
prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
|
||||
[self.task_token_id] + prompt_speech_token
|
||||
max_tokens = len(text) * 20
|
||||
for output in self.llm_inference(
|
||||
prompt_token_ids,
|
||||
stop_token_ids=[6561],
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
if output.token_ids[-1] == 6561:
|
||||
need_add_tokens = output.token_ids[:-1]
|
||||
else:
|
||||
need_add_tokens = output.token_ids
|
||||
for token in need_add_tokens:
|
||||
yield token
|
||||
|
||||
def inference_bistream(
|
||||
self,
|
||||
text: Generator,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
|
||||
prompt_speech_token = tensor_to_list(prompt_speech_token)
|
||||
|
||||
last_tokens = []
|
||||
prompt_token_ids = [self.sos_eos_token_id]
|
||||
text_tokens_cache = prompt_text
|
||||
for this_text in text:
|
||||
this_text = tensor_to_list(this_text + torch.tensor(6564))
|
||||
# text need tokens
|
||||
assert isinstance(this_text, list), "text need token ids List[int]."
|
||||
text_tokens_cache += this_text
|
||||
while len(prompt_speech_token) != 0:
|
||||
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
||||
text_input_token = text_tokens_cache[:self.mix_ratio[0]]
|
||||
speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
|
||||
prompt_token_ids += text_input_token + speech_input_token
|
||||
# reset the last cache
|
||||
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
||||
prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
|
||||
else:
|
||||
break
|
||||
if len(prompt_speech_token) == 0:
|
||||
if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
|
||||
if len(text_tokens_cache) >= self.mix_ratio[0]:
|
||||
text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
|
||||
prompt_token_ids += text_tokens_temp
|
||||
text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
|
||||
else:
|
||||
continue
|
||||
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
|
||||
last_tokens = output.token_ids
|
||||
if last_tokens[-1] == 6563:
|
||||
need_add_tokens = last_tokens[:-1]
|
||||
else:
|
||||
need_add_tokens = last_tokens
|
||||
for token in need_add_tokens:
|
||||
yield token
|
||||
prompt_token_ids.extend(need_add_tokens)
|
||||
prompt_token_ids += text_tokens_cache + [self.task_token_id]
|
||||
for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
|
||||
if output.token_ids[-1] == 6561:
|
||||
need_add_tokens = output.token_ids[:-1]
|
||||
else:
|
||||
need_add_tokens = output.token_ids
|
||||
for token in need_add_tokens:
|
||||
yield token
|
||||
@@ -1,263 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm.model_executor.models.interfaces import T
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_ID = -1
|
||||
|
||||
|
||||
class CosyVoice2Model(nn.Module):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.llm_input_size = 896
|
||||
self.llm_output_size = 896
|
||||
|
||||
self.speech_token_size = 6561+3
|
||||
self.llm_token_size = config.vocab_size
|
||||
|
||||
# 2. build speech token language model related modules
|
||||
self.sos_eos = 0
|
||||
self.task_id = 1
|
||||
self.fill_token = 2
|
||||
|
||||
|
||||
self.allow_patterns_overrides = ["llm.*"]
|
||||
self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
|
||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
# self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
|
||||
self.llm_decoder = ParallelLMHead(self.speech_token_size,
|
||||
self.llm_output_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "llm_decoder"))
|
||||
self.logits_processor = LogitsProcessor(self.speech_token_size)
|
||||
|
||||
# length_normalized_loss: bool = True,
|
||||
# lsm_weight: float = 0.0,
|
||||
# self.criterion_ce = LabelSmoothingLoss(
|
||||
# size=self.speech_token_size,
|
||||
# padding_idx=IGNORE_ID,
|
||||
# smoothing=lsm_weight,
|
||||
# normalize_length=length_normalized_loss,
|
||||
# )
|
||||
|
||||
# 3. [Optional] build speech token related modules
|
||||
self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
## use vllm sampling method
|
||||
self.sampler = get_sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.mix_ratio: List[int] = [5, 15]
|
||||
|
||||
# 定义特殊token常量
|
||||
self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
|
||||
self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404
|
||||
self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405
|
||||
self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
|
||||
|
||||
self.zero_embed_buffer = torch.zeros(
|
||||
(vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
|
||||
dtype=self.llm_embedding.weight.dtype,
|
||||
device=self.llm_embedding.weight.device
|
||||
)
|
||||
self.inputs_embed_buffer = torch.zeros(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
|
||||
dtype=self.llm_embedding.weight.dtype,
|
||||
device=self.llm_embedding.weight.device,
|
||||
)
|
||||
|
||||
def get_sos_eos_emb(self):
|
||||
return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
|
||||
def get_task_id_emb(self):
|
||||
return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[T] = None,
|
||||
attn_metadata: Optional["AttentionMetadata"] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the input embeddings merged from the text embeddings from
|
||||
input_ids and the multimodal embeddings generated from multimodal
|
||||
kwargs.
|
||||
"""
|
||||
# 创建掩码,标记哪些 token_id 属于音频 Token
|
||||
mask = input_ids < self.speech_token_size
|
||||
|
||||
# 获取 input_ids 的原始形状
|
||||
input_shape = input_ids.shape
|
||||
# 展平 input_ids 和掩码以便统一处理
|
||||
flat_input_ids = input_ids.view(-1)
|
||||
flat_mask = mask.view(-1)
|
||||
|
||||
inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
|
||||
inputs_embeds.zero_()
|
||||
|
||||
# Process speech tokens
|
||||
if flat_mask.any():
|
||||
speech_token_ids = flat_input_ids[flat_mask]
|
||||
inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
|
||||
|
||||
# 处理大于 delta 的 token_id
|
||||
if (~flat_mask).any():
|
||||
llm_token_ids = flat_input_ids[~flat_mask]
|
||||
llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
|
||||
|
||||
sos_eos_mask = llm_token_ids == self.sos_eos_token_id
|
||||
task_mask = llm_token_ids == self.task_token_id
|
||||
zero_mask = llm_token_ids == self.zero_token_id
|
||||
normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
|
||||
|
||||
# 分层处理逻辑
|
||||
# 第一优先级:SOS/EOS标记
|
||||
if sos_eos_mask.any():
|
||||
llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
|
||||
|
||||
# 第二优先级:任务标记
|
||||
if task_mask.any():
|
||||
llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
|
||||
|
||||
# 第二优先级:空音频标记
|
||||
if zero_mask.any():
|
||||
llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
|
||||
|
||||
# 常规LLM token
|
||||
if normal_mask.any():
|
||||
original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
|
||||
# print('original_ids: ',original_ids)
|
||||
llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
|
||||
|
||||
inputs_embeds[~flat_mask] = llm_embeds
|
||||
|
||||
inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
|
||||
|
||||
# 合并多模态嵌入(如果有)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.audio_token_index
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.llm_decoder, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
@staticmethod
|
||||
def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
for name, param in weights:
|
||||
# 处理Qwen2Model核心参数
|
||||
if name.startswith("llm."):
|
||||
if name.startswith("llm.model.model."):
|
||||
name = name.replace("llm.model.model.", "model.")
|
||||
else:
|
||||
continue
|
||||
# print('weights name: ', name)
|
||||
yield name, param
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
weights = self.convert_weights(weights)
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
@@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1):
|
||||
self.trt_context_pool = queue.Queue()
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
|
||||
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||
self.trt_context_pool.put(trt_context)
|
||||
self.trt_context_pool.put([trt_context, trt_stream])
|
||||
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context):
|
||||
self.trt_context_pool.put(context)
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import torchaudio
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
@@ -56,7 +59,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
profile = builder.create_optimization_profile()
|
||||
@@ -83,3 +86,44 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
def export_cosyvoice2_vllm(model, model_path, device):
|
||||
if os.path.exists(model_path):
|
||||
return
|
||||
pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
vocab_size = model.speech_embedding.num_embeddings
|
||||
feature_size = model.speech_embedding.embedding_dim
|
||||
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
dtype = torch.bfloat16
|
||||
# lm_head
|
||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
|
||||
with torch.no_grad():
|
||||
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
|
||||
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
|
||||
new_lm_head.weight[vocab_size:] = 0
|
||||
new_lm_head.bias[vocab_size:] = 0
|
||||
model.llm.model.lm_head = new_lm_head
|
||||
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
|
||||
# embed_tokens
|
||||
embed_tokens = model.llm.model.model.embed_tokens
|
||||
with torch.no_grad():
|
||||
new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
|
||||
new_codec_embed.weight[vocab_size:] = 0
|
||||
model.llm.model.set_input_embeddings(new_codec_embed)
|
||||
model.llm.model.to(device)
|
||||
model.llm.model.to(dtype)
|
||||
tmp_vocab_size = model.llm.model.config.vocab_size
|
||||
tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
|
||||
del model.llm.model.generation_config.eos_token_id
|
||||
del model.llm.model.config.bos_token_id
|
||||
del model.llm.model.config.eos_token_id
|
||||
model.llm.model.config.vocab_size = pad_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = False
|
||||
model.llm.model.config.use_bias = True
|
||||
model.llm.model.save_pretrained(model_path)
|
||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||
model.llm.model.set_input_embeddings(embed_tokens)
|
||||
|
||||
103
cosyvoice/vllm/cosyvoice2.py
Normal file
103
cosyvoice/vllm/cosyvoice2.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from vllm.model_executor.models.qwen2 import *
|
||||
|
||||
|
||||
class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
|
||||
conformer==0.3.2
|
||||
deepspeed==0.14.2; sys_platform == 'linux'
|
||||
deepspeed==0.15.1; sys_platform == 'linux'
|
||||
diffusers==0.29.0
|
||||
fastapi==0.115.6
|
||||
fastapi-cli==0.0.4
|
||||
|
||||
Reference in New Issue
Block a user