add vllm inference

This commit is contained in:
lyuxiang.lx
2025-05-30 07:22:35 +00:00
parent 9f55c5af8f
commit 6dd68b9d5e
6 changed files with 105 additions and 64 deletions

View File

@@ -140,7 +140,7 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16

View File

@@ -59,9 +59,6 @@ class CosyVoiceModel:
self.stream_scale_factor = 1
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
for _ in range(trt_concurrent):
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
@@ -69,7 +66,6 @@ class CosyVoiceModel:
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.trt_context_dict = {}
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,7 +94,7 @@ class CosyVoiceModel:
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
@@ -125,7 +121,8 @@ class CosyVoiceModel:
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
embedding=llm_embedding.to(self.device),
uuid=uuid):
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
@@ -180,13 +177,11 @@ class CosyVoiceModel:
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
this_trt_context = self.trt_context_pool.get()
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
self.trt_context_dict[this_uuid] = this_trt_context
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
@@ -240,8 +235,6 @@ class CosyVoiceModel:
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
self.trt_context_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
@@ -273,22 +266,28 @@ class CosyVoice2Model(CosyVoiceModel):
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
for _ in range(trt_concurrent):
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.trt_context_dict = {}
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load_vllm(self, model_dir):
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
from vllm import EngineArgs, LLMEngine
engine_args = EngineArgs(model=model_dir,
skip_tokenizer_init=True,
enable_prompt_embeds=True,
gpu_memory_utilization=0.2)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
del self.llm.llm.model.model.layers
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
@@ -330,11 +329,9 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
this_trt_context = self.trt_context_pool.get()
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.trt_context_dict[this_uuid] = this_trt_context
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
@@ -388,8 +385,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
self.trt_context_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()

View File

@@ -16,6 +16,7 @@ import threading
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
from cosyvoice.utils.common import set_all_random_seed
class ConditionalCFM(BASECFM):
@@ -32,7 +33,6 @@ class ConditionalCFM(BASECFM):
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -127,26 +127,27 @@ class ConditionalCFM(BASECFM):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else:
estimator, trt_engine = self.estimator.acquire_estimator()
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator)
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
with stream:
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator, stream)
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -194,6 +195,7 @@ class ConditionalCFM(BASECFM):
class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
set_all_random_seed(0)
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()

View File

@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
@@ -170,6 +173,7 @@ class TransformerLM(torch.nn.Module):
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
@@ -270,7 +274,6 @@ class Qwen2LM(TransformerLM):
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.task_id = 1
@@ -292,6 +295,11 @@ class Qwen2LM(TransformerLM):
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(3)]
self.vllm_output_queue = {}
self.lock = threading.Lock()
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
lm_target, lm_input = [], []
@@ -382,6 +390,7 @@ class Qwen2LM(TransformerLM):
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
@@ -402,22 +411,55 @@ class Qwen2LM(TransformerLM):
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
out_tokens = []
cache = None
for i in range(max_len):
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
if top_ids > self.speech_token_size:
continue
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
yield token
@torch.inference_mode()
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
if hasattr(self, 'vllm'):
from vllm import SamplingParams, RequestOutput
sampling_params = SamplingParams(top_k=sampling,
stop_token_ids=self.stop_token_ids,
min_tokens=min_len,
max_tokens=max_len)
with self.lock:
self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
self.vllm_output_queue[uuid] = queue.Queue()
out_tokens = []
while True:
with self.lock:
if self.vllm_output_queue[uuid].empty() is True:
request_outputs: List[RequestOutput] = self.vllm.step()
for request_output in request_outputs:
top_ids = list(request_output.outputs[0].token_ids)[-1]
self.vllm_output_queue[request_output.request_id].put(top_ids)
if self.vllm_output_queue[uuid].empty() is False:
top_ids = self.vllm_output_queue[uuid].get()
if top_ids in self.stop_token_ids:
break
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
time.sleep(0.001)
with self.lock:
self.vllm_output_queue.pop(uuid)
else:
out_tokens = []
cache = None
for i in range(max_len):
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
if top_ids > self.speech_token_size:
continue
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
@torch.inference_mode()
def inference_bistream(

View File

@@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1):
self.trt_context_pool = queue.Queue()
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put(trt_context)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context):
self.trt_context_pool.put(context)
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])

View File

@@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
@@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device):
model.llm.model.config.tie_word_embeddings = False
model.llm.model.config.use_bias = True
model.llm.model.save_pretrained(model_path)
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
model.llm.model.config.vocab_size = tmp_vocab_size
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
model.llm.model.set_input_embeddings(embed_tokens)