From b207c608858370175ec60a5e0feac7612c49f338 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Thu, 18 Sep 2025 19:07:23 +0800 Subject: [PATCH 01/15] init step-audio2 token2wav --- .../model_repo/cosyvoice2_dit/1/model.py | 455 ++++++++++++++++ .../model_repo/cosyvoice2_dit/config.pbtxt | 73 +++ .../model_repo/token2wav_dit/1/model.py | 278 ++++++++++ .../model_repo/token2wav_dit/config.pbtxt | 80 +++ .../run_stepaudio2_dit_token2wav.sh | 142 +++++ runtime/triton_trtllm/token2wav_dit.py | 496 ++++++++++++++++++ 6 files changed, 1524 insertions(+) create mode 100644 runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py create mode 100644 runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt create mode 100644 runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py create mode 100644 runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt create mode 100644 runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh create mode 100644 runtime/triton_trtllm/token2wav_dit.py diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py new file mode 100644 index 0000000..97659ad --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -0,0 +1,455 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import math +import os +import re +import threading +import time +from typing import Dict, List, Tuple, Optional, Union + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer + +import torchaudio + + +from matcha.utils.audio import mel_spectrogram + +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. + """ + + def initialize(self, args): + """Initialize the model. 
+ + Args: + args: Dictionary containing model configuration + """ + self.logger = pb_utils.Logger + # Parse model parameters + self.model_config = json.loads(args['model_config']) + parameters = self.model_config['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + self.logger.log_info(f"model_params:{model_params}") + self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.prompt_template = "<|sos|>{input_text}<|task_id|>" + self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") + + self.device = torch.device("cuda") + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) + + self.token_frame_rate = 25 + self.flow_pre_lookahead_len = 3 + self.token_hop_len = 15 + + spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + self.default_spk_info = spk_info["001"] + + def forward_llm(self, input_ids): + """ + Prepares the response from the language model based on the provided + inputs. Creates a `pb_utils.InferenceRequest` object with passed + `llm_request_inputs` to send to a decoupled TensorRTLLM model. + For each response from the language model: + - Checks for errors and raise an exception if any are found. + - Extracts the "output_ids" tensor from the response. + - Determines the finish reason based on the presence of the + end-of-sequence token or reaching the maximum length. + - Appends the generated token IDs to `output_ids`. + - If the finish reason is determined, decodes the output IDs to text + and prepares the final response. + + The final response includes the generated text, finish reason, + completion tokens, prompt tokens, and total tokens. + + Parameters + ---------- + - llm_request_inputs (dict): A dictionary containing the inputs for the language model. + + Returns + ------- + - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. 
+ """ + # convert input_ids to numpy, with shape [1, sequence_length] + input_ids = input_ids.cpu().numpy() + max_tokens = 750 + input_dict = { + "request_output_len": np.array([[max_tokens]], dtype=np.int32), + "end_id": np.array([[self.eos_token_id]], dtype=np.int32), + "pad_id": np.array([[self.eos_token_id]], dtype=np.int32), + "streaming": np.array([[self.decoupled]], dtype=np.bool_), + "runtime_top_p": np.array([[0.95]], dtype=np.float32), + "runtime_top_k": np.array([[50]], dtype=np.int32), + "temperature": np.array([[0.8]], dtype=np.float32), + "repetition_penalty": np.array([[1.1]], dtype=np.float32), + "random_seed": np.array([[42]], dtype=np.uint64), + "input_ids": input_ids, + "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), + } + + # Convert inputs to Triton tensors + input_tensor_list = [ + pb_utils.Tensor(k, v) for k, v in input_dict.items() + ] + + # Create and execute inference request + llm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + requested_output_names=["output_ids", "sequence_length"], + inputs=input_tensor_list, + ) + + llm_responses = llm_request.exec(decoupled=self.decoupled) + if self.decoupled: + for llm_response in llm_responses: + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + else: + llm_response = llm_responses + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. + + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['prompt_speech_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') + prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() + + return prompt_speech_tokens + + def forward_speaker_embedding(self, wav): + """Forward pass through the speaker embedding component. 
+ + Args: + wav: Input waveform tensor + + Returns: + Prompt speaker embedding tensor + """ + inference_request = pb_utils.InferenceRequest( + model_name='speaker_embedding', + requested_output_names=['prompt_spk_embedding'], + inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') + prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) + + return prompt_spk_embedding + + def forward_token2wav( + self, + target_speech_tokens: torch.Tensor, + request_id: str, + prompt_speech_tokens: torch.Tensor = None, + prompt_speech_feat: torch.Tensor = None, + prompt_spk_embedding: torch.Tensor = None, + token_offset: int = None, + finalize: bool = None) -> torch.Tensor: + """Forward pass through the vocoder component. + + Args: + prompt_speech_tokens: Prompt speech tokens tensor + prompt_speech_feat: Prompt speech feat tensor + prompt_spk_embedding: Prompt spk embedding tensor + target_speech_tokens: Target speech tokens tensor + + Returns: + Generated waveform tensor + """ + target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + + inputs_tensor = [target_speech_tokens_tensor] + + if token_offset is not None: + assert finalize is not None + token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor.append(token_offset_tensor) + inputs_tensor.append(finalize_tensor) + + if prompt_spk_embedding is not None: + assert prompt_speech_feat is not None + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) + prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) + inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor]) + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='token2wav', + requested_output_names=['waveform'], + inputs=inputs_tensor, + request_id=request_id, + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def parse_input(self, text, prompt_text, prompt_speech_tokens): + total_text = f"{prompt_text}{text}" + prompt = self.prompt_template.format(input_text=total_text) + input_ids = self.tokenizer.encode(prompt) + input_ids = torch.tensor([input_ids], dtype=torch.int32) + input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) + return input_ids + + def _extract_speech_feat(self, speech): + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 
1).to( + self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + return speech_feat + + def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag): + for generated_ids in generated_ids_iter: + generated_ids = generated_ids.tolist() + if len(generated_ids) == 0: + break + semantic_token_ids_arr.extend(generated_ids) + llm_is_done_flag[0] = True + + def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + responses = [] + + for request in requests: + request_id = request.request_id() + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + + # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + else: + # using pre-cached reference text + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + # Prepare prompt for LLM + input_ids = self.parse_input( + text=target_text, + prompt_text=reference_text, + prompt_speech_tokens=prompt_speech_tokens, + ) + + # Generate semantic tokens with LLM + generated_ids_iter = self.forward_llm(input_ids) + + if self.decoupled: + response_sender = request.get_response_sender() + + semantic_token_ids_arr = [] + llm_is_done_flag = [False] + + llm_thread = threading.Thread( + target=self._llm_gen_thread, + args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag) + ) + + llm_thread.start() + + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + + while True: + pending_num = len(semantic_token_ids_arr) - token_offset + + if llm_is_done_flag[0]: + break + + if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: + this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) + + sub_tts_speech = self.forward_token2wav( + this_tts_speech_token, request_id, prompt_speech_tokens, + prompt_speech_feat, prompt_spk_embedding, token_offset, False + ) + + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + 
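+                    # Decoupled (streaming) mode: push this chunk to the client as soon
+                    # as it is vocoded; the final flush and the COMPLETE_FINAL flag are
+                    # sent after the LLM generation thread has finished.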
response_sender.send(inference_response) + + token_offset += this_token_hop_len + self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") + + if self.dynamic_chunk_strategy == "exponential": + this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "time_based": + # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 + cost_time = time.time() - start_time + duration = token_offset / self.token_frame_rate + if chunk_index > 0 and cost_time > 0: + avg_chunk_processing_time = cost_time / (chunk_index + 1) + if avg_chunk_processing_time > 0: + multiples = (duration - cost_time) / avg_chunk_processing_time + self.logger.log_info(f"multiples: {multiples}") + next_pending_num = len(semantic_token_ids_arr) - token_offset + if multiples > 4: + this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len + elif multiples > 2: + this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len + else: + this_token_hop_len = self.token_hop_len + this_token_hop_len = max(self.token_hop_len, this_token_hop_len) + chunk_index += 1 + else: + time.sleep(0.02) + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + llm_thread.join() + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + generated_ids = next(generated_ids_iter) + generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) + if generated_ids is None or len(generated_ids) == 0: + raise pb_utils.TritonModelException("Generated IDs is None or empty") + + audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + responses.append(inference_response) + + if not self.decoupled: + return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt new file mode 100644 index 0000000..73a9a05 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -0,0 +1,73 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
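+# NOTE: the ${...} placeholders below (triton_max_batch_size, decoupled_mode,
+# llm_tokenizer_dir, model_dir, bls_instance_num, max_queue_delay_microseconds)
+# are not literal pbtxt values; they are meant to be filled in with
+# scripts/fill_template.py when the model repository is assembled (stage 2 of
+# the run script). decoupled_mode=True enables streaming (chunked) responses,
+# while False runs offline TTS.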
+ +name: "cosyvoice2" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +model_transaction_policy { + decoupled: ${decoupled_mode} +} +parameters [ + { + key: "llm_tokenizer_dir", + value: {string_value:"${llm_tokenizer_dir}"} + }, + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + optional: true + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + optional: true + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + optional: true + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: ${bls_instance_num} + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py new file mode 100644 index 0000000..10bc272 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -0,0 +1,278 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
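+# This Triton Python backend turns semantic speech tokens into waveforms: the
+# CosyVoice2 flow decoder produces mel spectrograms and the HiFT vocoder turns
+# them into audio. Two paths are supported:
+#   * offline: the full token sequence is decoded in a single call (no token_offset),
+#   * streaming: chunks are decoded incrementally per request_id using
+#     token_offset/finalize and a per-request HiFT cache for smooth chunk joins.
+# The prompt tokens/feats/embedding inputs are optional; when absent, the
+# pre-cached speaker "001" from spk2info.pt is used.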
+ +import json +import os + +import logging +from typing import List, Dict + +import torch +from torch.utils.dlpack import to_dlpack +from torch.nn import functional as F + +import triton_python_backend_utils as pb_utils + +from hyperpyyaml import load_hyperpyyaml +from cosyvoice.utils.common import fade_in_out +from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm +from cosyvoice.utils.common import TrtContextWrapper +from collections import defaultdict +import numpy as np + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + + +class CosyVoice2: + + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'): + + self.model_dir = model_dir + self.fp16 = fp16 + + hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: + configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) + self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device) + self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + if load_jit: + self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) + if load_trt: + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + trt_concurrent, + self.fp16) + + +class CosyVoice2Model: + + def __init__(self, + flow: torch.nn.Module, + hift: torch.nn.Module, + fp16: bool = False, + device: str = 'cuda'): + self.device = device + self.flow = flow + self.hift = hift + self.fp16 = fp16 + if self.fp16 is True: + self.flow.half() + + # streaming tts config + self.token_hop_len = 25 + self.mel_cache_len = 8 + self.source_cache_len = int(self.mel_cache_len * 480) + self.speech_window = np.hamming(2 * self.source_cache_len) + self.hift_cache_dict = defaultdict(lambda: None) + + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder + + def load(self, flow_model, hift_model): + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) + self.flow.to(self.device).eval() + # in case hift_model is a hifigan model + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
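+        # Build the TensorRT engine from the ONNX estimator on first use and cache it
+        # as a .plan file; subsequent runs simply deserialize the cached engine.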
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs(self): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] + opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] + max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] + input_names = ["x", "mask", "mu", "cond"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): + with torch.cuda.amp.autocast(self.fp16): + tts_mel, _ = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + streaming=stream, + finalize=finalize) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] + # append hift cache + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + # keep overlap mel and hift cache + if finalize is False: + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], + 'source': tts_source[:, :, -self.source_cache_len:], + 'speech': tts_speech[:, -self.source_cache_len:]} + tts_speech = tts_speech[:, :-self.source_cache_len] + else: + if speed != 1.0: + assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' + tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + return tts_speech + + +class TritonPythonModel: + """Triton Python model for vocoder. + + This model takes global and semantic tokens as input and generates audio waveforms + using the BiCodec vocoder. + """ + + def initialize(self, args): + """Initialize the model. 
+ + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {key: value["string_value"] for key, value in parameters.items()} + model_dir = model_params["model_dir"] + + # Initialize device and vocoder + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + logger.info(f"Initializing vocoder from {model_dir} on {self.device}") + + self.token2wav_model = CosyVoice2( + model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device + ) + + spk_info_path = os.path.join(model_dir, "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_dir}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + self.default_spk_info = spk_info["001"] + + logger.info("Token2Wav initialized successfully") + + def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated waveforms + """ + responses = [] + # Process each request in batch + for request in requests: + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) + + prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens") + if prompt_speech_tokens_tensor is not None: + prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy() + prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() + prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() + prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) + prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) + prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) + prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE + else: + prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device) + prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device) + prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device) + + # shift the speech tokens according to the original vocab size + target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + + # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. 
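+            # When token_offset is present, "finalize" must be sent as well: non-final
+            # chunks run the flow decoder in streaming mode and keep a HiFT cache keyed
+            # by request_id, and the finalize=True chunk flushes and pops that cache.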
+ token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset") + if token_offset is not None: + token_offset = token_offset.as_numpy().item() + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() + if not finalize: + stream = True + else: + stream = False + request_id = request.request_id() + audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, + prompt_token=prompt_speech_tokens, + prompt_feat=prompt_speech_feat, + embedding=prompt_spk_embedding, + token_offset=token_offset, + uuid=request_id, + stream=stream, + finalize=finalize) + if finalize: + self.token2wav_model.model.hift_cache_dict.pop(request_id) + + else: + tts_mel, _ = self.token2wav_model.model.flow.inference( + token=target_speech_tokens, + token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_token=prompt_speech_tokens, + prompt_token_len=torch.tensor( + [prompt_speech_tokens.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_speech_feat, + prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=prompt_spk_embedding, + streaming=False, + finalize=True, + ) + + audio_hat, _ = self.token2wav_model.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + generated_wave = audio_hat.squeeze(0).cpu().numpy() + + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) + inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) + responses.append(inference_response) + + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt new file mode 100644 index 0000000..c33a85f --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -0,0 +1,80 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
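+# prompt_speech_tokens / prompt_speech_feat / prompt_spk_embedding are optional:
+# when they are omitted the backend falls back to the pre-cached speaker "001"
+# from spk2info.pt. token_offset and finalize are likewise optional and are only
+# sent by the BLS model in streaming mode; offline requests leave them unset.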
+ +name: "token2wav" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "target_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "prompt_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + optional: true + }, + { + name: "prompt_speech_feat" + data_type: TYPE_FP16 + dims: [-1, 80] + optional: true + }, + { + name: "prompt_spk_embedding" + data_type: TYPE_FP16 + dims: [-1] + optional: true + }, + { + name: "token_offset" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "finalize" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh new file mode 100644 index 0000000..7c7f3cd --- /dev/null +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -0,0 +1,142 @@ +#!/bin/bash +# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) +export CUDA_VISIBLE_DEVICES=0 +cosyvoice_path=/workspace/CosyVoice +export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH +export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH +stage=$1 +stop_stage=$2 + +huggingface_model_local_dir=./cosyvoice2_llm +model_scope_model_local_dir=./CosyVoice2-0.5B +trt_dtype=bfloat16 +trt_weights_dir=./trt_weights_${trt_dtype} +trt_engines_dir=./trt_engines_${trt_dtype} + +model_repo=./model_repo_cosyvoice2 + +use_spk2info_cache=False + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + echo "Cloning CosyVoice" + git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path + cd $cosyvoice_path + git submodule update --init --recursive + cd runtime/triton_trtllm +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Downloading CosyVoice2-0.5B" + # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py + huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir + # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings + wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt +fi + + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "Converting checkpoint to TensorRT weights" + python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \ + --output_dir $trt_weights_dir \ + --dtype $trt_dtype || exit 1 + + echo "Building TensorRT engines" + trtllm-build --checkpoint_dir $trt_weights_dir \ + --output_dir $trt_engines_dir \ + --max_batch_size 16 \ + --max_num_tokens 32768 \ + --gemm_plugin $trt_dtype || exit 1 + + echo "Testing TensorRT engines" + python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" 
\ + --tokenizer_dir $huggingface_model_local_dir \ + --top_k 50 --top_p 0.95 --temperature 0.8 \ + --engine_dir=$trt_engines_dir || exit 1 +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "Creating model repository" + rm -rf $model_repo + mkdir -p $model_repo + cosyvoice2_dir="cosyvoice2" + + cp -r ./model_repo/${cosyvoice2_dir} $model_repo + cp -r ./model_repo/tensorrt_llm $model_repo + cp -r ./model_repo/token2wav $model_repo + if [ $use_spk2info_cache == "False" ]; then + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/speaker_embedding $model_repo + fi + + ENGINE_PATH=$trt_engines_dir + MAX_QUEUE_DELAY_MICROSECONDS=0 + MODEL_DIR=$model_scope_model_local_dir + LLM_TOKENIZER_DIR=$huggingface_model_local_dir + BLS_INSTANCE_NUM=4 + TRITON_MAX_BATCH_SIZE=16 + DECOUPLED_MODE=True # True for streaming, False for offline + + python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 + if [ $use_spk2info_cache == "False" ]; then + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Starting Triton server" + tritonserver --model-repository $model_repo +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + echo "Single request test http, only work for offline TTS mode" + python3 client_http.py \ + --reference-audio ./assets/prompt_audio.wav \ + --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ + --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ + --model-name cosyvoice2 +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Running benchmark client grpc" + num_task=4 + + mode=streaming + BLS_INSTANCE_NUM=4 + + python3 client_grpc.py \ + --server-addr localhost \ + --model-name cosyvoice2 \ + --num-tasks $num_task \ + --mode $mode \ + --use-spk2info-cache $use_spk2info_cache \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache} +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "stage 6: Offline inference benchmark" + n_gpus=1 + datasets=(wenetspeech4tts) # 
wenetspeech4tts, test_zh, zero_shot_zh + backend=trtllm # hf, trtllm, vllm + + batch_sizes=(16 8 4 2 1) + token2wav_batch_size=1 + for batch_size in ${batch_sizes[@]}; do + for dataset in ${datasets[@]}; do + output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} + CUDA_VISIBLE_DEVICES=0 \ + python3 offline_inference.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $huggingface_model_local_dir \ + --token2wav-path $model_scope_model_local_dir \ + --backend $backend \ + --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ + --engine-dir $trt_engines_dir \ + --split-name ${dataset} || exit 1 + done + done +fi diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/token2wav_dit.py new file mode 100644 index 0000000..69db946 --- /dev/null +++ b/runtime/triton_trtllm/token2wav_dit.py @@ -0,0 +1,496 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py --enable-trt || exit 1 +""" +import torch +# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec +from flashcosyvoice.modules.hifigan import HiFTGenerator +from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time +import numpy as np +from hyperpyyaml import load_hyperpyyaml + + +def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): + """perform fade_in_out in tensor style + """ + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel = fade_in_mel.clone() + fade_in_mel[..., :mel_overlap_len] = \ + fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel + +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if dtype == torch.float16: + config.set_flag(trt.BuilderFlag.FP16) + elif dtype == torch.bfloat16: + config.set_flag(trt.BuilderFlag.BF16) + elif dtype == torch.float32: + config.set_flag(trt.BuilderFlag.FP32) + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise 
ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + if dtype == torch.float16: + tensor_dtype = trt.DataType.HALF + elif dtype == torch.bfloat16: + tensor_dtype = trt.DataType.BF16 + elif dtype == torch.float32: + tensor_dtype = trt.DataType.FLOAT + else: + raise ValueError('invalid dtype {}'.format(dtype)) + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + +class CosyVoice2_Token2Wav(torch.nn.Module): + def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + with open(f"{model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + self.flow = configs['flow'] + + self.dtype = dtype + self.flow.to(self.dtype) + + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.to(self.device).eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) + + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() + + gpu="l20" + if enable_trt: + if streaming: + self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + 1, + self.dtype, streaming) + else: + self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + 
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype) + self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False) + + + self.streaming_flow_cache = {} + self.speaker_cache = {} + + self.mel_cache_len = 8 # hard-coded, 160ms + self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave + self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda() + + # hifigan cache for streaming tts + self.hift_cache_dict = {} + + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
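+        # The engine is built with dynamic-batch optimization profiles (see
+        # get_trt_kwargs_dynamic_batch); streaming engines additionally expose
+        # cnn_cache/att_cache bindings and are limited to batch size 1.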
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + opt_batch_size = 2 + max_batch_size = 16 + if streaming: + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): + if streaming: + min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)] + input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] + else: + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + if self.dtype != 
torch.float32: + spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, prompt_mels_lens_for_flow + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow.inference( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10 + ) + + return generated_mels, generated_mels_lens + + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + + @torch.inference_mode() + def forward( + self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) + + generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + + return generated_wavs + + def prepare_prompt_audio( + self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in 
prompt_audios_sample_rate) + + + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) + + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) + + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + + return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + + + def get_prompt_audio_cache_for_streaming_tts( + self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ): + assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts" + for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list): + prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3]) + prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0) + + cache = self.flow.setup_cache( + prompt_speech_tokens_tensor.to(self.device), + prompt_mels_for_flow.to(self.device), + spk_emb_for_flow.to(self.device), + n_timesteps=10 + ) + + # cache dict's tensor batch dim is 1 for now + return cache + + + @torch.inference_mode() + def forward_streaming( + self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 + ): + + if speaker_id not in self.speaker_cache: + assert prompt_audio is not None, "prompt_audio is required for new speaker" + assert prompt_audio_sample_rate == 16000 + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) + + token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) + prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() + prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] + + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + + self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + + if request_id not in self.streaming_flow_cache: + self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + self.hift_cache_dict[request_id] = dict( + mel = torch.zeros(1, 80, 0, device='cuda'), + source = torch.zeros(1, 1, 0, device='cuda'), + speech = torch.zeros(1, 0, device='cuda'), + ) + + current_request_cache = self.streaming_flow_cache[request_id] + prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] + generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + + chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( + token=generated_speech_tokens, + spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device), + cache=current_request_cache, + last_chunk=last_chunk, + n_timesteps=10, + ) + + self.streaming_flow_cache[request_id] = new_streaming_flow_cache + + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, 
:prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], + ], dim=4) + + + + hift_cache_mel = self.hift_cache_dict[request_id]['mel'] + hift_cache_source = self.hift_cache_dict[request_id]['source'] + hift_cache_speech = self.hift_cache_dict[request_id]['speech'] + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + + speech, source = self.hift(mel, hift_cache_source) + + # overlap speech smooth + if hift_cache_speech.shape[-1] > 0: + speech = fade_in_out(speech, hift_cache_speech, self.speech_window) + + # update vocoder cache + self.hift_cache_dict[request_id] = dict( + mel = mel[..., -self.mel_cache_len:].clone().detach(), + source = source[:, :, -self.source_cache_len:].clone().detach(), + speech = speech[:, -self.source_cache_len:].clone().detach(), + ) + if not last_chunk: + speech = speech[:, :-self.source_cache_len] + + if last_chunk: + assert request_id in self.streaming_flow_cache + self.streaming_flow_cache.pop(request_id) + self.hift_cache_dict.pop(request_id) + + return speech + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") + return parser.parse_args() + +if __name__ == "__main__": + args = get_args() + model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) + # mkdir output_dir if not exists + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + dataset_name = "yuekai/seed_tts_cosy2" + + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + + + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + + for epoch in range(args.warmup): + start_time = time.time() + + for batch in data_loader: + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch + + generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) + + + for id, wav in zip(ids, generated_wavs): + torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) + + end_time = time.time() + epoch_time = end_time - start_time + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file From 444b7ff5dfafb0c98eecab5e7db461f997843a48 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 19 Sep 2025 13:48:32 +0800 Subject: [PATCH 02/15] fix cache shallow copy --- .../run_stepaudio2_dit_token2wav.sh | 11 +++++++++++ 
runtime/triton_trtllm/token2wav_dit.py | 19 +++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index 7c7f3cd..c0034c2 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -2,6 +2,9 @@ # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) export CUDA_VISIBLE_DEVICES=0 cosyvoice_path=/workspace/CosyVoice +cosyvoice_path=/workspace_yuekai/tts/CosyVoice +stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 +export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH stage=$1 @@ -140,3 +143,11 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then done done fi + + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + + python3 benchmark_streaming_token2wav.py --enable-trt + + +fi \ No newline at end of file diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/token2wav_dit.py index 69db946..fdc1a12 100644 --- a/runtime/triton_trtllm/token2wav_dit.py +++ b/runtime/triton_trtllm/token2wav_dit.py @@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module): spk_emb_for_flow.to(self.device), n_timesteps=10 ) - - # cache dict's tensor batch dim is 1 for now + # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] + cache['estimator_att_cache'] = cache['estimator_att_cache'].clone() + cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone() return cache @@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 ): - if speaker_id not in self.speaker_cache: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} if request_id not in self.streaming_flow_cache: - self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( mel = torch.zeros(1, 80, 0, device='cuda'), source = torch.zeros(1, 1, 0, device='cuda'), @@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module): ) current_request_cache = self.streaming_flow_cache[request_id] - prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] + + current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( token=generated_speech_tokens, - spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device), + spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), cache=current_request_cache, last_chunk=last_chunk, n_timesteps=10, @@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module): self.streaming_flow_cache[request_id] = new_streaming_flow_cache - if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > 
(prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ - self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) From 482464ea275835b7bcd30c311a0f54b535d4c84b Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Wed, 24 Sep 2025 15:18:01 +0800 Subject: [PATCH 03/15] add streaming dit --- runtime/triton_trtllm/client_grpc.py | 14 +- .../model_repo/cosyvoice2_dit/1/model.py | 59 +-- .../model_repo/cosyvoice2_dit/3/model.py | 438 ++++++++++++++++++ .../model_repo/cosyvoice2_dit/config.pbtxt | 2 +- .../model_repo/token2wav_dit/1/model.py | 209 ++------- .../token2wav_dit/1}/token2wav_dit.py | 58 ++- .../model_repo/token2wav_dit/config.pbtxt | 27 +- runtime/triton_trtllm/offline_inference.py | 99 +++- .../run_stepaudio2_dit_token2wav.sh | 98 +++- runtime/triton_trtllm/streaming_inference.py | 115 +++++ 10 files changed, 850 insertions(+), 269 deletions(-) create mode 100644 runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py rename runtime/triton_trtllm/{ => model_repo/token2wav_dit/1}/token2wav_dit.py (91%) create mode 100644 runtime/triton_trtllm/streaming_inference.py diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 994a401..afbab68 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -209,7 +209,8 @@ def get_args(): choices=[ "f5_tts", "spark_tts", - "cosyvoice2"], + "cosyvoice2", + "cosyvoice2_dit"], help="triton model_repo module name to request", ) @@ -260,8 +261,8 @@ def get_args(): parser.add_argument( "--use-spk2info-cache", - type=bool, - default=False, + type=str, + default="False", help="Use spk2info cache for reference audio.", ) @@ -490,6 +491,7 @@ async def send_streaming( padding_duration=padding_duration, use_spk2info_cache=use_spk2info_cache ) + request_id = str(uuid.uuid4()) user_data = UserData() @@ -670,11 +672,15 @@ async def main(): trust_remote_code=True, ) manifest_item_list = [] + tmp_audio_path="./asset_zero_shot_prompt.wav" + tmp_audio_text="希望你以后能够做的比我还好呦。" for i in range(len(dataset)): manifest_item_list.append( { "audio_filepath": dataset[i]["prompt_audio"], "reference_text": dataset[i]["prompt_text"], + # "audio_filepath": tmp_audio_path, + # "reference_text": tmp_audio_text, "target_audio_path": dataset[i]["id"], "target_text": dataset[i]["target_text"], } @@ -686,7 +692,7 @@ async def main(): manifest_item_list = split_data(manifest_item_list, num_tasks) os.makedirs(args.log_dir, exist_ok=True) - + args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true" tasks = [] start_time = time.time() for i in range(num_tasks): diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index 97659ad..d0977c5 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -227,12 +227,11 @@ class TritonPythonModel: def forward_token2wav( self, + index: int, target_speech_tokens: 
torch.Tensor, request_id: str, - prompt_speech_tokens: torch.Tensor = None, - prompt_speech_feat: torch.Tensor = None, - prompt_spk_embedding: torch.Tensor = None, - token_offset: int = None, + reference_wav: object, + reference_wav_len: object, finalize: bool = None) -> torch.Tensor: """Forward pass through the vocoder component. @@ -246,29 +245,16 @@ class TritonPythonModel: Generated waveform tensor """ target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) - - inputs_tensor = [target_speech_tokens_tensor] - - if token_offset is not None: - assert finalize is not None - token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)) - finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) - inputs_tensor.append(token_offset_tensor) - inputs_tensor.append(finalize_tensor) - - if prompt_spk_embedding is not None: - assert prompt_speech_feat is not None - prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) - prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) - prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) - inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor]) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] # Create and execute inference request inference_request = pb_utils.InferenceRequest( - model_name='token2wav', + model_name='token2wav_dit', requested_output_names=['waveform'], inputs=inputs_tensor, request_id=request_id, + parameters={"priority": index+1}, ) inference_response = inference_request.exec() @@ -346,8 +332,15 @@ class TritonPythonModel: reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + + # reference_text = self.default_spk_info["prompt_text"] + # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + # prompt_speech_feat = None + # prompt_spk_embedding = None + else: + assert False, "wav is None" # using pre-cached reference text reference_text = self.default_spk_info["prompt_text"] prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE @@ -391,12 +384,12 @@ class TritonPythonModel: break if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: - this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) sub_tts_speech = self.forward_token2wav( - this_tts_speech_token, request_id, prompt_speech_tokens, - prompt_speech_feat, prompt_spk_embedding, token_offset, False + chunk_index, + this_tts_speech_token, request_id, wav, wav_len, False ) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) @@ -429,8 +422,8 @@ class TritonPythonModel: else: time.sleep(0.02) - 
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -439,17 +432,7 @@ class TritonPythonModel: response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: - generated_ids = next(generated_ids_iter) - generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) - if generated_ids is None or len(generated_ids) == 0: - raise pb_utils.TritonModelException("Generated IDs is None or empty") - - audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) - - # Prepare response - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - responses.append(inference_response) + raise NotImplementedError("Decoupled mode is not supported") if not self.decoupled: return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py new file mode 100644 index 0000000..b4a6348 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py @@ -0,0 +1,438 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import math +import os +import re +import time +from typing import Dict, List, Tuple, Optional, Union +import asyncio +import httpx + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer + +import torchaudio + + +from matcha.utils.audio import mel_spectrogram + +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + + +def parse_speech_token_string(response_text: str) -> List[int]: + """ + Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. + """ + speech_tokens = response_text.strip().split('><') + if len(speech_tokens) > 1: + # Add back the missing '<' and '>' for proper parsing + speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] + speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] + + speech_ids = [] + for token_str in speech_tokens: + match = re.match(r'<\|s_(\d+)\|>', token_str) + if match: + speech_ids.append(int(match.group(1))) + return speech_ids + + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + self.logger = pb_utils.Logger + # Parse model parameters + self.model_config = json.loads(args['model_config']) + parameters = self.model_config['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + self.logger.log_info(f"model_params:{model_params}") + self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.prompt_template = "<|sos|>{input_text}<|task_id|>" + self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") + + self.device = torch.device("cuda") + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) + + self.token_frame_rate = 25 + self.flow_pre_lookahead_len = 3 + self.token_hop_len = 15 + + spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + # self.default_spk_info = spk_info["001"] + + def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: + """Converts a tensor or list of speech token IDs to a string representation.""" + if isinstance(speech_tokens, torch.Tensor): + # Ensure tensor is on CPU and flattened + speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() + + speech_id_str = "" + for token_id in speech_tokens: + # Convert token ID back to the speech number N + token_num = token_id - ORIGINAL_VOCAB_SIZE + speech_id_str += f"<|s_{token_num}|>" + return speech_id_str + + async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): + """ + Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response. 
+ """ + full_text = f"{reference_text}{target_text}" + prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) + + chat = [ + {"role": "user", "content": full_text}, + {"role": "assistant", "content": prompt_speech_tokens_str} + ] + print(chat) + + payload = { + "model": "trt_engines_bfloat16", + "messages": chat, + "max_tokens": 750, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": True, + } + + api_base = "http://localhost:8000/v1/chat/completions" + + buffer = "" + async with httpx.AsyncClient() as client: + async with client.stream("POST", api_base, json=payload, timeout=None) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + line_data = line[len("data: "):].strip() + if line_data == "[DONE]": + break + try: + json_data = json.loads(line_data) + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") + if content: + buffer += content + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break + + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + except json.JSONDecodeError: + self.logger.log_info(f"Skipping non-JSON line: {line_data}") + continue + + # Process any remaining complete tokens in the buffer after the stream ends + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. + + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['prompt_speech_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') + prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() + + return prompt_speech_tokens + + def forward_speaker_embedding(self, wav): + """Forward pass through the speaker embedding component. 
+ + Args: + wav: Input waveform tensor + + Returns: + Prompt speaker embedding tensor + """ + inference_request = pb_utils.InferenceRequest( + model_name='speaker_embedding', + requested_output_names=['prompt_spk_embedding'], + inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') + prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) + + return prompt_spk_embedding + + def forward_token2wav( + self, + index: int, + target_speech_tokens: torch.Tensor, + request_id: str, + reference_wav: object, + reference_wav_len: object, + finalize: bool = None) -> torch.Tensor: + """Forward pass through the vocoder component. + + Args: + prompt_speech_tokens: Prompt speech tokens tensor + prompt_speech_feat: Prompt speech feat tensor + prompt_spk_embedding: Prompt spk embedding tensor + target_speech_tokens: Target speech tokens tensor + + Returns: + Generated waveform tensor + """ + target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='token2wav_dit', + requested_output_names=['waveform'], + inputs=inputs_tensor, + request_id=request_id, + parameters={"priority": index+1}, + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def _extract_speech_feat(self, speech): + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 1).to( + self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + return speech_feat + + async def _process_request(self, request): + request_id = request.request_id() + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + + # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + 
reference_text = reference_text[0][0].decode('utf-8') + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + + # reference_text = self.default_spk_info["prompt_text"] + # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + # prompt_speech_feat = None + # prompt_spk_embedding = None + + else: + # using pre-cached reference text + assert False, "using pre-cached reference text is not supported" + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + if self.decoupled: + response_sender = request.get_response_sender() + + semantic_token_ids_arr = [] + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + + async for generated_ids in self.forward_llm_async( + target_text=target_text, + reference_text=reference_text, + prompt_speech_tokens=prompt_speech_tokens, + ): + if not generated_ids: + break + semantic_token_ids_arr.append(generated_ids) + + while True: + pending_num = len(semantic_token_ids_arr) - token_offset + if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: + this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) + + sub_tts_speech = self.forward_token2wav( + chunk_index, + this_tts_speech_token, request_id, wav, wav_len, False + ) + + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + token_offset += this_token_hop_len + self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") + + if self.dynamic_chunk_strategy == "exponential": + this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "time_based": + # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 + cost_time = time.time() - start_time + duration = token_offset / self.token_frame_rate + if chunk_index > 0 and cost_time > 0: + avg_chunk_processing_time = cost_time / (chunk_index + 1) + if avg_chunk_processing_time > 0: + multiples = (duration - cost_time) / avg_chunk_processing_time + self.logger.log_info(f"multiples: {multiples}") + next_pending_num = len(semantic_token_ids_arr) - token_offset + if multiples > 4: + this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len + elif multiples > 2: + this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len + else: + this_token_hop_len = self.token_hop_len + this_token_hop_len = max(self.token_hop_len, this_token_hop_len) + chunk_index += 1 + else: + break + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + 
## debug + ## save semantic_token_ids_arr and reference_text, target_text to a single json file + # save into a torch .pt + # for i, item in enumerate(semantic_token_ids_arr): + # semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE + # import json + # data = { + # "semantic_token_ids_arr": semantic_token_ids_arr, + # "reference_text": reference_text, + # "target_text": target_text + # } + # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f: + # torch.save(data, f) + # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f: + # json.dump(data, f) + + # ## + + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + raise NotImplementedError("Decoupled mode is not supported") + + async def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + tasks = [ + asyncio.create_task(self._process_request(request)) + for request in requests + ] + await asyncio.gather(*tasks) + return None diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt index 73a9a05..e64647e 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: "cosyvoice2" +name: "cosyvoice2_dit" backend: "python" max_batch_size: ${triton_max_batch_size} dynamic_batching { diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 10bc272..8f9ffba 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -42,6 +42,8 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vl from cosyvoice.utils.common import TrtContextWrapper from collections import defaultdict import numpy as np +from .token2wav_dit import CosyVoice2_Token2Wav +import hashlib logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -49,117 +51,19 @@ logger = logging.getLogger(__name__) ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) +def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str: + """ + Generates a unique ID for a torch.Tensor. + Tensors with the same elements and properties will have the same ID. 
+ """ + # Convert tensor to a byte string + tensor_bytes = tensor.numpy().tobytes() -class CosyVoice2: - - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'): - - self.model_dir = model_dir - self.fp16 = fp16 - - hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) - if not os.path.exists(hyper_yaml_path): - raise ValueError('{} not found!'.format(hyper_yaml_path)) - with open(hyper_yaml_path, 'r') as f: - configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) - self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device) - self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) - if load_jit: - self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) - if load_trt: - self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), - '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), - trt_concurrent, - self.fp16) - - -class CosyVoice2Model: - - def __init__(self, - flow: torch.nn.Module, - hift: torch.nn.Module, - fp16: bool = False, - device: str = 'cuda'): - self.device = device - self.flow = flow - self.hift = hift - self.fp16 = fp16 - if self.fp16 is True: - self.flow.half() - - # streaming tts config - self.token_hop_len = 25 - self.mel_cache_len = 8 - self.source_cache_len = int(self.mel_cache_len * 480) - self.speech_window = np.hamming(2 * self.source_cache_len) - self.hift_cache_dict = defaultdict(lambda: None) - - def load_jit(self, flow_encoder_model): - flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) - self.flow.encoder = flow_encoder - - def load(self, flow_model, hift_model): - self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) - self.flow.to(self.device).eval() - # in case hift_model is a hifigan model - hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} - self.hift.load_state_dict(hift_state_dict, strict=True) - self.hift.to(self.device).eval() - - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): - assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
- if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: - convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) - del self.flow.decoder.estimator - import tensorrt as trt - with open(flow_decoder_estimator_model, 'rb') as f: - estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) - assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) - - def get_trt_kwargs(self): - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] - opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] - max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] - input_names = ["x", "mask", "mu", "cond"] - return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - - def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16): - tts_mel, _ = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - streaming=stream, - finalize=finalize) - tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] - # append hift cache - if self.hift_cache_dict[uuid] is not None: - hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] - tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) - else: - hift_cache_source = torch.zeros(1, 1, 0) - # keep overlap mel and hift cache - if finalize is False: - tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) - if self.hift_cache_dict[uuid] is not None: - tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) - self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], - 'source': tts_source[:, :, -self.source_cache_len:], - 'speech': tts_speech[:, -self.source_cache_len:]} - tts_speech = tts_speech[:, :-self.source_cache_len] - else: - if speed != 1.0: - assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' - tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') - tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) - if self.hift_cache_dict[uuid] is not None: - tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) - return tts_speech - + # Create a SHA-256 hash of the byte string + hasher = hashlib.sha256() + hasher.update(tensor_bytes) + + return hasher.hexdigest() class TritonPythonModel: """Triton Python model for vocoder. 
@@ -183,16 +87,10 @@ class TritonPythonModel: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") logger.info(f"Initializing vocoder from {model_dir} on {self.device}") - self.token2wav_model = CosyVoice2( - model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device + # FIXME: device id settings + self.token2wav_model = CosyVoice2_Token2Wav( + model_dir, enable_trt=True, streaming=True ) - - spk_info_path = os.path.join(model_dir, "spk2info.pt") - if not os.path.exists(spk_info_path): - raise ValueError(f"spk2info.pt not found in {model_dir}") - spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - self.default_spk_info = spk_info["001"] - logger.info("Token2Wav initialized successfully") def execute(self, requests): @@ -208,66 +106,31 @@ class TritonPythonModel: # Process each request in batch for request in requests: target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() - target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) - - prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens") - if prompt_speech_tokens_tensor is not None: - prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy() - prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() - prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() - prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) - prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) - prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) - prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE - else: - prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device) - prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device) - prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device) - + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) # shift the speech tokens according to the original vocab size target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + target_speech_tokens = target_speech_tokens.squeeze().tolist() # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. 
- token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset") - if token_offset is not None: - token_offset = token_offset.as_numpy().item() - finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - if not finalize: - stream = True - else: - stream = False - request_id = request.request_id() - audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, - prompt_token=prompt_speech_tokens, - prompt_feat=prompt_speech_feat, - embedding=prompt_spk_embedding, - token_offset=token_offset, - uuid=request_id, - stream=stream, - finalize=finalize) - if finalize: - self.token2wav_model.model.hift_cache_dict.pop(request_id) + + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() + + request_id = request.request_id() + - else: - tts_mel, _ = self.token2wav_model.model.flow.inference( - token=target_speech_tokens, - token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( - self.device - ), - prompt_token=prompt_speech_tokens, - prompt_token_len=torch.tensor( - [prompt_speech_tokens.shape[1]], dtype=torch.int32 - ).to(self.device), - prompt_feat=prompt_speech_feat, - prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=prompt_spk_embedding, - streaming=False, - finalize=True, - ) + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() - audio_hat, _ = self.token2wav_model.model.hift.inference( - speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) - ) + wav_array = torch.from_numpy(wav_array) + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + + spk_id = get_spk_id_from_prompt_audio(wav) + # wav = wav.to(self.device) + + audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) generated_wave = audio_hat.squeeze(0).cpu().numpy() diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py similarity index 91% rename from runtime/triton_trtllm/token2wav_dit.py rename to runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index fdc1a12..3b696e9 100644 --- a/runtime/triton_trtllm/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -362,17 +362,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module): spk_emb_for_flow.to(self.device), n_timesteps=10 ) + new_cache = {k: v.clone() for k, v in cache.items()} # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] - cache['estimator_att_cache'] = cache['estimator_att_cache'].clone() - cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone() - return cache + return new_cache @torch.inference_mode() def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 - ): + ): if speaker_id not in self.speaker_cache: + # if 1: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -382,10 +382,21 @@ class CosyVoice2_Token2Wav(torch.nn.Module): prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() prompt_speech_tokens_list[0] = 
prompt_speech_tokens_list[0][:token_len] - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + + # if speaker_id not in self.speaker_cache: + # if 1: + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + print(f"speaker_id {speaker_id} added to cache") + + # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change + att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone() + cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone() + conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone() + conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone() + if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} @@ -409,6 +420,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module): n_timesteps=10, ) + # get the original att_cache + original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'] + original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'] + original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'] + original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'] + if not torch.allclose(original_att_cache, att_cache_clone): + print("att_cache changed") + # print the last 10 elements of original_att_cache and att_cache_clone + print(original_att_cache[:, :, :, -10:]) + print(att_cache_clone[:, :, :, -10:]) + breakpoint() + if not torch.allclose(original_cnn_cache, cnn_cache_clone): + print("cnn_cache changed") + print(original_cnn_cache[..., -10:]) + print(cnn_cache_clone[..., -10:]) + breakpoint() + if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone): + print("conformer_cnn_cache changed") + print(original_conformer_cnn_cache[..., -10:]) + print(conformer_cnn_cache_clone[..., -10:]) + breakpoint() + if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone): + print("conformer_att_cache changed") + print(original_conformer_att_cache[..., -10:]) + print(conformer_att_cache_clone[..., -10:]) + breakpoint() + self.streaming_flow_cache[request_id] = new_streaming_flow_cache @@ -420,10 +458,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module): - hift_cache_mel = self.hift_cache_dict[request_id]['mel'] - hift_cache_source = self.hift_cache_dict[request_id]['source'] - hift_cache_speech = self.hift_cache_dict[request_id]['speech'] - mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() + hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() + hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone() speech, source = self.hift(mel, hift_cache_source) @@ -444,7 +482,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): assert request_id in 
self.streaming_flow_cache self.streaming_flow_cache.pop(request_id) self.hift_cache_dict.pop(request_id) - + # breakpoint() return speech def collate_fn(batch): diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt index c33a85f..2040cfe 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: "token2wav" +name: "token2wav_dit" backend: "python" max_batch_size: ${triton_max_batch_size} dynamic_batching { max_queue_delay_microseconds: ${max_queue_delay_microseconds} + priority_levels: 10 + default_priority_level: 10 } parameters [ { @@ -32,29 +34,14 @@ input [ dims: [-1] }, { - name: "prompt_speech_tokens" - data_type: TYPE_INT32 + name: "reference_wav" + data_type: TYPE_FP32 dims: [-1] - optional: true }, { - name: "prompt_speech_feat" - data_type: TYPE_FP16 - dims: [-1, 80] - optional: true - }, - { - name: "prompt_spk_embedding" - data_type: TYPE_FP16 - dims: [-1] - optional: true - }, - { - name: "token_offset" + name: "reference_wav_len" data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true + dims: [1] }, { name: "finalize" diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py index 6f1a836..30c3b3b 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -43,6 +43,9 @@ import soundfile as sf import s3tokenizer from functools import partial import time +import requests +import asyncio +import httpx from token2wav import CosyVoice2_Token2Wav @@ -53,6 +56,32 @@ except RuntimeError: pass +async def send_request_async(client, url, payload): + response = await client.post(url, json=payload, timeout=None) + response.raise_for_status() + response_json = response.json() + return response_json['choices'][0]['message']['content'] + + +async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k): + async with httpx.AsyncClient() as client: + tasks = [] + for chat in chats: + payload = { + "model": model_name, + "messages": chat, + "max_tokens": 2048, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": False, + } + tasks.append(send_request_async(client, api_base, payload)) + return await asyncio.gather(*tasks) + + def extract_speech_ids(speech_tokens_str): """Extract speech IDs from token strings like <|s_23456|>""" speech_ids = [] @@ -149,7 +178,7 @@ def get_args(): "--backend", type=str, default="hf", - choices=["hf", "trtllm", "vllm"], + choices=["hf", "trtllm", "vllm", "trtllm-serve"], help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM", ) parser.add_argument( @@ -164,6 +193,18 @@ def get_args(): default=0.6, help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)", ) + parser.add_argument( + "--openai-api-base", + type=str, + default="http://localhost:8000/v1/chat/completions", + help="OpenAI API base URL (for trtllm-serve backend)", + ) + parser.add_argument( + "--openai-model-name", + type=str, + default="trt_engines_bfloat16", + help="Model name to use with OpenAI API (for trtllm-serve backend)", + ) args = parser.parse_args() return args @@ -180,6 +221,7 @@ def data_collator(batch, tokenizer, 
s3_tokenizer): input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] prompt_text_after_apply_template_list = [] mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], [] + chat_list = [] for _, item in enumerate(batch): audio_processing_start_time = time.time() prompt_text, target_text = ( @@ -237,6 +279,7 @@ def data_collator(batch, tokenizer, s3_tokenizer): {"role": "user", "content": full_text_list[i]}, {"role": "assistant", "content": prompt_audio_cosy2_id_str} ] + chat_list.append(chat) assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template" @@ -265,6 +308,7 @@ def data_collator(batch, tokenizer, s3_tokenizer): "audio_processing_time": total_audio_processing_time, "speech_tokenization_time": total_speech_tokenization_time, "text_tokenization_time": total_text_tokenization_time, + "chat_list": chat_list } @@ -318,6 +362,9 @@ def main(args): elif args.backend == "vllm": model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4) runner = None + elif args.backend == "trtllm-serve": + model = None + runner = None else: raise ValueError(f"Unsupported backend: {args.backend}") @@ -452,6 +499,35 @@ def main(args): print(outputs) for j, output in enumerate(outputs): outputs[j] = input_ids_list[j] + output.outputs[0].token_ids + elif args.backend == "trtllm-serve": + if args.batch_size > 1: + outputs = asyncio.run(send_batch_requests_async( + args.openai_api_base, + args.openai_model_name, + batch["chat_list"], + args.temperature, + args.top_p, + args.top_k, + )) + else: + outputs = [] + for i, chat in enumerate(batch["chat_list"]): + payload = { + "model": args.openai_model_name, + "messages": chat, + "max_tokens": 2048, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": False, + } + response = requests.post(args.openai_api_base, json=payload) + response.raise_for_status() + response_json = response.json() + generated_content = response_json['choices'][0]['message']['content'] + outputs.append(generated_content) llm_end_time = time.time() total_llm_time += (llm_end_time - llm_start_time) @@ -459,10 +535,21 @@ def main(args): items_for_token_2wav = [] for i in range(len(batch["ids"])): llm_post_processing_start_time = time.time() - input_length = len(batch["input_ids"][i]) - generated_ids = outputs[i][input_length:] - speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - speech_ids = extract_speech_ids(speech_tokens_str) + if args.backend == "trtllm-serve": + speech_tokens_str = outputs[i].strip().split('><') + if len(speech_tokens_str) > 1: + speech_tokens_str = [ + t if t.startswith('<') else '<' + t for t in speech_tokens_str + ] + speech_tokens_str = [ + t if t.endswith('>') else t + '>' for t in speech_tokens_str + ] + speech_ids = extract_speech_ids(speech_tokens_str) + else: + input_length = len(batch["input_ids"][i]) + generated_ids = outputs[i][input_length:] + speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + speech_ids = extract_speech_ids(speech_tokens_str) print(i, speech_ids) if len(speech_ids) == 0: print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping") @@ -558,6 +645,8 @@ if __name__ == "__main__": from tensorrt_llm.runtime import ModelRunnerCpp elif args.backend == "hf": from transformers import AutoModelForCausalLM + elif args.backend == "trtllm-serve": + pass else: raise ValueError(f"Unsupported backend: 
{args.backend}") main(args) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index c0034c2..ad3407e 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,6 +1,6 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=1 cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 @@ -16,7 +16,7 @@ trt_dtype=bfloat16 trt_weights_dir=./trt_weights_${trt_dtype} trt_engines_dir=./trt_engines_${trt_dtype} -model_repo=./model_repo_cosyvoice2 +model_repo=./model_repo_cosyvoice2_dit use_spk2info_cache=False @@ -58,40 +58,78 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --engine_dir=$trt_engines_dir || exit 1 fi + +# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then +# echo "Creating model repository" +# rm -rf $model_repo +# mkdir -p $model_repo +# cosyvoice2_dir="cosyvoice2_dit" +# token2wav_dir="token2wav_dit" + +# cp -r ./model_repo/${cosyvoice2_dir} $model_repo +# cp -r ./model_repo/tensorrt_llm $model_repo +# cp -r ./model_repo/${token2wav_dir} $model_repo +# #if [ $use_spk2info_cache == "False" ]; then +# cp -r ./model_repo/audio_tokenizer $model_repo +# cp -r ./model_repo/speaker_embedding $model_repo +# #fi + +# ENGINE_PATH=$trt_engines_dir +# MAX_QUEUE_DELAY_MICROSECONDS=0 +# MODEL_DIR=$model_scope_model_local_dir +# LLM_TOKENIZER_DIR=$huggingface_model_local_dir +# BLS_INSTANCE_NUM=1 +# TRITON_MAX_BATCH_SIZE=16 +# DECOUPLED_MODE=True # True for streaming, False for offline +# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav + +# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 +# #if [ $use_spk2info_cache == "False" ]; then +# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# #fi +# fi + if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - echo "Creating model repository" + echo "Creating model repository async mode" rm -rf 
$model_repo mkdir -p $model_repo - cosyvoice2_dir="cosyvoice2" + cosyvoice2_dir="cosyvoice2_dit" + token2wav_dir="token2wav_dit" cp -r ./model_repo/${cosyvoice2_dir} $model_repo cp -r ./model_repo/tensorrt_llm $model_repo - cp -r ./model_repo/token2wav $model_repo - if [ $use_spk2info_cache == "False" ]; then + cp -r ./model_repo/${token2wav_dir} $model_repo + #if [ $use_spk2info_cache == "False" ]; then cp -r ./model_repo/audio_tokenizer $model_repo cp -r ./model_repo/speaker_embedding $model_repo - fi + #fi ENGINE_PATH=$trt_engines_dir MAX_QUEUE_DELAY_MICROSECONDS=0 MODEL_DIR=$model_scope_model_local_dir LLM_TOKENIZER_DIR=$huggingface_model_local_dir BLS_INSTANCE_NUM=4 - TRITON_MAX_BATCH_SIZE=16 + TRITON_MAX_BATCH_SIZE=32 DECOUPLED_MODE=True # True for streaming, False for offline + STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav - python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 - if [ $use_spk2info_cache == "False" ]; then + #if [ $use_spk2info_cache == "False" ]; then python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - fi + #fi + rm -rf $model_repo/tensorrt_llm + # mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4 fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then echo "Starting Triton server" - tritonserver --model-repository $model_repo + tritonserver --model-repository $model_repo --http-port 18000 fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then @@ -112,26 +150,26 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then python3 client_grpc.py \ --server-addr localhost \ - --model-name cosyvoice2 \ + --model-name cosyvoice2_dit \ --num-tasks $num_task \ --mode $mode \ - --use-spk2info-cache $use_spk2info_cache \ --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache} + --log-dir 
./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then echo "stage 6: Offline inference benchmark" n_gpus=1 datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh - backend=trtllm # hf, trtllm, vllm + backend=trtllm-serve # hf, trtllm, vllm batch_sizes=(16 8 4 2 1) + batch_sizes=(16 8 4 2) token2wav_batch_size=1 for batch_size in ${batch_sizes[@]}; do for dataset in ${datasets[@]}; do output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} - CUDA_VISIBLE_DEVICES=0 \ + CUDA_VISIBLE_DEVICES=1 \ python3 offline_inference.py \ --output-dir $output_dir \ --llm-model-name-or-path $huggingface_model_local_dir \ @@ -147,7 +185,31 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - python3 benchmark_streaming_token2wav.py --enable-trt + python3 streaming_inference.py +fi + + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 + +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + #! /usr/bin/env bash + curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "trt_engines_bfloat16", + "messages":[{"role": "user", "content": "Where is New York?"}, + {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], + "max_tokens": 512, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "stop": ["<|eos1|>"], + "repetition_penalty": 1.2, + "stream": false + }' fi \ No newline at end of file diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py new file mode 100644 index 0000000..863358c --- /dev/null +++ b/runtime/triton_trtllm/streaming_inference.py @@ -0,0 +1,115 @@ +import torch +import os +import argparse +from datasets import load_dataset +from torch.utils.data import DataLoader +import numpy as np +import torchaudio +import time +from token2wav_dit import CosyVoice2_Token2Wav +import soundfile as sf + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + prompt_speech_tokens_list, prompt_text_list = [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens']) + prompt_text_list.append(item['prompt_text']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2") + return parser.parse_args() + + +def fake_generated_id_iter(generated_speech_tokens_list): + for i in range(len(generated_speech_tokens_list)): + yield 
generated_speech_tokens_list[i] + + + +if __name__ == "__main__": + args = get_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + dataset_name = args.dataset_name + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) + + flow_pre_lookahead_len = 3 + CHUNK_SIZE = 25 + OVERLAP_SIZE = 0 + + warmup_times = 3 + for _ in range(warmup_times): + start_time = time.time() + for batch in data_loader: + tts_speech_list = [] + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch + + id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] + # if id != "unseen3_text5": + # continue + # else: + # a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt") + # generated_speech_tokens = a["semantic_token_ids_arr"] + # print(generated_speech_tokens) + assert prompt_audio_sample_rate == 16000 + + prompt_text = prompt_text_list[0] + prompt_speech_tokens = prompt_speech_tokens_list[0] + + + # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) + + semantic_token_ids_arr, token_offset = [], 0 + flow_prompt_speech_token_len = len(prompt_speech_tokens) + + buffer = generated_speech_tokens + output_wavs = [] + while True: + + if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: + wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] + + output_wavs.append(wavs) + + else: + wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + output_wavs.append(wavs) + break + + for i, wav in enumerate(output_wavs): + output_wavs[i] = wav.cpu().numpy().squeeze() + + + audios = output_wavs + reconstructed_audio = np.concatenate(audios) + # Save reconstructed audio + sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") + + + print(f"Saved {id}") + end_time = time.time() + + if _ == 0: + token2wav_model.speaker_cache = {} + print(f"Warmup time: {end_time - start_time} seconds") + From 31a0adc73d23c786745bb3663df829569d7fe98f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 26 Sep 2025 14:51:41 +0800 Subject: [PATCH 04/15] mark stateless token2wav --- .../model_repo/cosyvoice2_dit/3/model.py | 157 +++++++++++------- .../model_repo/cosyvoice2_dit/config.pbtxt | 2 +- .../model_repo/token2wav_dit/1/model.py | 102 +++++++++--- .../token2wav_dit/1/token2wav_dit.py | 46 +---- .../model_repo/token2wav_dit/config.pbtxt | 80 +++++++++ 5 files changed, 266 insertions(+), 121 deletions(-) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py index b4a6348..c472968 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py @@ -43,6 +43,7 @@ import torchaudio from matcha.utils.audio 
import mel_spectrogram +from datetime import datetime ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) @@ -86,6 +87,7 @@ class TritonPythonModel: model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + # self.dynamic_chunk_strategy = "equal" self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer @@ -105,7 +107,9 @@ class TritonPythonModel: if not os.path.exists(spk_info_path): raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - # self.default_spk_info = spk_info["001"] + self.default_spk_info = spk_info["001"] + self.http_client = httpx.AsyncClient() + self.runtime_cache = {} def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" @@ -131,7 +135,6 @@ class TritonPythonModel: {"role": "user", "content": full_text}, {"role": "assistant", "content": prompt_speech_tokens_str} ] - print(chat) payload = { "model": "trt_engines_bfloat16", @@ -148,31 +151,33 @@ class TritonPythonModel: api_base = "http://localhost:8000/v1/chat/completions" buffer = "" - async with httpx.AsyncClient() as client: - async with client.stream("POST", api_base, json=payload, timeout=None) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if line.startswith("data: "): - line_data = line[len("data: "):].strip() - if line_data == "[DONE]": - break - try: - json_data = json.loads(line_data) - content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") - if content: - buffer += content - while True: - match = re.search(r"<\|s_(\d+)\|>", buffer) - if not match: - break + async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: + print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") + print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + line_data = line[len("data: "):].strip() + if line_data == "[DONE]": + break + try: + json_data = json.loads(line_data) + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") + if content: + buffer += content + print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break - token_num = int(match.group(1)) - final_id = token_num + ORIGINAL_VOCAB_SIZE - yield final_id - buffer = buffer[match.end():] - except json.JSONDecodeError: - self.logger.log_info(f"Skipping non-JSON line: {line_data}") - continue + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + except json.JSONDecodeError: + self.logger.log_info(f"Skipping non-JSON line: {line_data}") + continue # Process any remaining complete tokens in the buffer after the stream ends while True: @@ -236,7 +241,7 @@ class TritonPythonModel: return prompt_spk_embedding - def forward_token2wav( + async def forward_token2wav( self, index: int, target_speech_tokens: torch.Tensor, @@ -258,20 +263,57 @@ class TritonPythonModel: 
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - + + # optional cache inputs + if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None: + # inputs_tensor.extend([ + # pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()), + # pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()), + # pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()), + # pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()), + # pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()), + # pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()), + # pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()), + # ]) + inputs_tensor.extend([ + self.runtime_cache[request_id]["conformer_cnn_cache"], + self.runtime_cache[request_id]["conformer_att_cache"], + self.runtime_cache[request_id]["estimator_cnn_cache"], + self.runtime_cache[request_id]["estimator_att_cache"], + self.runtime_cache[request_id]["mel"], + self.runtime_cache[request_id]["source"], + self.runtime_cache[request_id]["speech"], + ]) # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', - requested_output_names=['waveform'], + requested_output_names=[ + "waveform", + "conformer_cnn_cache", + "conformer_att_cache", + "estimator_cnn_cache", + "estimator_att_cache", + "mel", + "source", + "speech", + ], inputs=inputs_tensor, request_id=request_id, parameters={"priority": index+1}, ) - inference_response = inference_request.exec() + inference_response = await inference_request.async_exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) + self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache") + self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache") + self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache") + self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache") + self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel") + self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source") + self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech") + # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() @@ -297,6 +339,16 @@ class TritonPythonModel: async def _process_request(self, request): request_id = request.request_id() + if request_id not in self.runtime_cache: + self.runtime_cache[request_id] = { + "conformer_cnn_cache": None, + "conformer_att_cache": None, + "estimator_cnn_cache": None, + "estimator_att_cache": None, + "mel": None, + "source": None, + 
"speech": None, + } # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") @@ -308,6 +360,7 @@ class TritonPythonModel: wav_tensor = wav.as_numpy() wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) speech_feat = self._extract_speech_feat(prompt_speech_resample) token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) @@ -316,7 +369,7 @@ class TritonPythonModel: reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) # reference_text = self.default_spk_info["prompt_text"] # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE @@ -333,6 +386,7 @@ class TritonPythonModel: target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') + print(f"target_text: {target_text}, time: {datetime.now()}") if self.decoupled: response_sender = request.get_response_sender() @@ -341,7 +395,7 @@ class TritonPythonModel: token_offset, chunk_index = 0, 0 start_time = time.time() this_token_hop_len = self.token_hop_len - + print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") async for generated_ids in self.forward_llm_async( target_text=target_text, reference_text=reference_text, @@ -350,18 +404,18 @@ class TritonPythonModel: if not generated_ids: break semantic_token_ids_arr.append(generated_ids) - + print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: pending_num = len(semantic_token_ids_arr) - token_offset if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - - sub_tts_speech = self.forward_token2wav( + print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") + sub_tts_speech = await self.forward_token2wav( chunk_index, this_tts_speech_token, request_id, wav, wav_len, False ) - + print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -371,6 +425,8 @@ class TritonPythonModel: if self.dynamic_chunk_strategy == "exponential": this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "equal": + this_token_hop_len = self.token_hop_len elif self.dynamic_chunk_strategy == "time_based": # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 cost_time = time.time() - start_time @@ -393,29 +449,13 @@ class TritonPythonModel: break this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) + sub_tts_speech = await 
self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) - - ## debug - ## save semantic_token_ids_arr and reference_text, target_text to a single json file - # save into a torch .pt - # for i, item in enumerate(semantic_token_ids_arr): - # semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE - # import json - # data = { - # "semantic_token_ids_arr": semantic_token_ids_arr, - # "reference_text": reference_text, - # "target_text": target_text - # } - # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f: - # torch.save(data, f) - # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f: - # json.dump(data, f) - - # ## - + if request_id in self.runtime_cache: + del self.runtime_cache[request_id] + self.logger.log_info(f"Deleted cache for request_id: {request_id}") response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: @@ -436,3 +476,8 @@ class TritonPythonModel: ] await asyncio.gather(*tasks) return None + + def finalize(self): + self.logger.log_info("Finalizing CosyVoice DIT model") + if hasattr(self, "http_client"): + asyncio.run(self.http_client.aclose()) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt index e64647e..b119227 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -31,7 +31,7 @@ parameters [ value: {string_value:"${model_dir}"} } ] - +parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} input [ { name: "reference_wav" diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 8f9ffba..e95ce99 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -103,39 +103,91 @@ class TritonPythonModel: List of inference responses containing generated waveforms """ responses = [] - # Process each request in batch for request in requests: - target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() - target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) - # shift the speech tokens according to the original vocab size - target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + request_id = request.request_id() + + # Get inputs + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens") + target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack()) target_speech_tokens = target_speech_tokens.squeeze().tolist() - # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. 
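For context, the finalize flag handled here is driven by the caller's chunking loop: each non-final request carries CHUNK_SIZE + pre_lookahead_len tokens, and the leftover tail is flushed with finalize=True so the model can emit the last audio and drop its per-request cache. A minimal sketch of that loop, mirroring streaming_inference.py earlier in this series (the 25/3 sizes are the defaults used there, and run_token2wav is a placeholder for the actual forward call, not part of this patch):

CHUNK_SIZE = 25      # tokens consumed per streaming step (streaming_inference.py default)
PRE_LOOKAHEAD = 3    # flow.pre_lookahead_len of the DiT flow model

def drive_streaming(tokens, run_token2wav):
    """Feed speech tokens to the streaming token2wav model chunk by chunk.

    run_token2wav(chunk, finalize) stands in for forward_streaming() or the
    equivalent token2wav_dit Triton request.
    """
    buffer = list(tokens)
    wavs = []
    while len(buffer) >= CHUNK_SIZE + PRE_LOOKAHEAD:
        # non-final chunk: current hop plus look-ahead context
        wavs.append(run_token2wav(buffer[:CHUNK_SIZE + PRE_LOOKAHEAD], False))
        buffer = buffer[CHUNK_SIZE:]  # the look-ahead tokens stay in the buffer and are re-sent next hop
    # final chunk: flush whatever remains with finalize=True
    wavs.append(run_token2wav(buffer, True))
    return wavs
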
- finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - - request_id = request.request_id() - - - wav_array = pb_utils.get_input_tensor_by_name( - request, "reference_wav").as_numpy() - wav_len = pb_utils.get_input_tensor_by_name( - request, "reference_wav_len").as_numpy().item() - - wav_array = torch.from_numpy(wav_array) - # Prepare inputs - wav = wav_array[:, :wav_len].squeeze(0) - + wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item() + wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0) spk_id = get_spk_id_from_prompt_audio(wav) - # wav = wav.to(self.device) - audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) + # Handle cache + conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache") + if conformer_cnn_cache is not None: + self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack()) + + conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache") + self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1) + + estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache") + self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0) - generated_wave = audio_hat.squeeze(0).cpu().numpy() + estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache") + self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0) + mel_np = pb_utils.get_input_tensor_by_name(request, "mel") + self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack()) + + source_np = pb_utils.get_input_tensor_by_name(request, "source") + self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack()) + + speech_np = pb_utils.get_input_tensor_by_name(request, "speech") + self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack()) + + # Forward pass + audio_hat = self.token2wav_model.forward_streaming( + target_speech_tokens, + finalize, + request_id=request_id, + speaker_id=f"{spk_id}", + prompt_audio=wav, + prompt_audio_sample_rate=16000 + ) + + # Prepare outputs + outputs = [] wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) - inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) - responses.append(inference_response) + outputs.append(wav_tensor) + + if request_id in self.token2wav_model.streaming_flow_cache: + cache = self.token2wav_model.streaming_flow_cache[request_id] + hifigan_cache = self.token2wav_model.hift_cache_dict[request_id] + conformer_cnn_cache = cache['conformer_cnn_cache'] + conformer_att_cache = cache['conformer_att_cache'].transpose(0,1) + estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0) + estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0) + mel = hifigan_cache['mel'] + source = 
hifigan_cache['source'] + speech = hifigan_cache['speech'] + outputs.extend([ + pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)), + pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)), + pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)), + pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)), + pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)), + pb_utils.Tensor.from_dlpack("source", to_dlpack(source)), + pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)), + ]) + else: + outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("mel", np.array([], dtype=np.float32)), + pb_utils.Tensor("source", np.array([], dtype=np.float32)), + pb_utils.Tensor("speech", np.array([], dtype=np.float32)), + ]) + + inference_response = pb_utils.InferenceResponse(output_tensors=outputs) + responses.append(inference_response) return responses + + def finalize(self): + self.logger.log_info("Finalizing Token2WavDiT model") diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index 3b696e9..63dce14 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -372,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 ): if speaker_id not in self.speaker_cache: - # if 1: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -384,20 +383,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module): prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} - # if speaker_id not in self.speaker_cache: - # if 1: - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} print(f"speaker_id {speaker_id} added to cache") - # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change - att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone() - cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone() - conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone() - conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone() - - if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( @@ -405,6 +394,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module): source = torch.zeros(1, 1, 0, device='cuda'), speech = torch.zeros(1, 0, device='cuda'), ) + # else: + # for k, v in self.streaming_flow_cache[request_id].items(): + # print(f"k: {k}, v: {v.shape}, dtype: 
{v.dtype}") + # for k, v in self.hift_cache_dict[request_id].items(): + # print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}") + # breakpoint() current_request_cache = self.streaming_flow_cache[request_id] @@ -420,33 +415,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): n_timesteps=10, ) - # get the original att_cache - original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'] - original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'] - original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'] - original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'] - if not torch.allclose(original_att_cache, att_cache_clone): - print("att_cache changed") - # print the last 10 elements of original_att_cache and att_cache_clone - print(original_att_cache[:, :, :, -10:]) - print(att_cache_clone[:, :, :, -10:]) - breakpoint() - if not torch.allclose(original_cnn_cache, cnn_cache_clone): - print("cnn_cache changed") - print(original_cnn_cache[..., -10:]) - print(cnn_cache_clone[..., -10:]) - breakpoint() - if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone): - print("conformer_cnn_cache changed") - print(original_conformer_cnn_cache[..., -10:]) - print(conformer_cnn_cache_clone[..., -10:]) - breakpoint() - if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone): - print("conformer_att_cache changed") - print(original_conformer_att_cache[..., -10:]) - print(conformer_att_cache_clone[..., -10:]) - breakpoint() - self.streaming_flow_cache[request_id] = new_streaming_flow_cache @@ -482,7 +450,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): assert request_id in self.streaming_flow_cache self.streaming_flow_cache.pop(request_id) self.hift_cache_dict.pop(request_id) - # breakpoint() + return speech def collate_fn(batch): diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt index 2040cfe..aed7561 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -15,11 +15,14 @@ name: "token2wav_dit" backend: "python" max_batch_size: ${triton_max_batch_size} + dynamic_batching { max_queue_delay_microseconds: ${max_queue_delay_microseconds} priority_levels: 10 default_priority_level: 10 } + +parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} parameters [ { key: "model_dir", @@ -49,6 +52,48 @@ input [ dims: [ 1 ] reshape: { shape: [ ] } optional: true + }, + { + name: "conformer_cnn_cache" + data_type: TYPE_FP16 + dims: [ 512, -1 ] + optional: true + }, + { + name: "conformer_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 8, -1, 128 ] + optional: true + }, + { + name: "estimator_cnn_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 1024, 2 ] + optional: true + }, + { + name: "estimator_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 8, -1, 128 ] + optional: true + }, + { + name: "mel" + data_type: TYPE_FP32 + dims: [ 80, -1 ] + optional: true + }, + { + name: "source" + data_type: TYPE_FP32 + dims: [ 1, -1 ] + optional: true + }, + { + name: "speech" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true } ] output [ @@ -56,6 +101,41 @@ output [ name: "waveform" data_type: TYPE_FP32 dims: [ -1 ] + }, + { + name: "conformer_cnn_cache" + data_type: TYPE_FP16 + dims: [ 512, -1 ] + }, + { + name: "conformer_att_cache" + data_type: 
TYPE_FP16 + dims: [ 10, 8, -1, 128 ] + }, + { + name: "estimator_cnn_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 1024, 2 ] + }, + { + name: "estimator_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 8, -1, 128 ] + }, + { + name: "mel" + data_type: TYPE_FP32 + dims: [ 80, -1 ] + }, + { + name: "source" + data_type: TYPE_FP32 + dims: [ 1, -1 ] + }, + { + name: "speech" + data_type: TYPE_FP32 + dims: [ -1 ] } ] From 79116ac32e296f6fa3d684b3115342fec67778af Mon Sep 17 00:00:00 2001 From: root Date: Fri, 26 Sep 2025 15:14:31 +0800 Subject: [PATCH 05/15] remove cache router --- runtime/triton_trtllm/client_grpc.py | 194 +++++++++++++++--- .../model_repo/cosyvoice2_dit/3/model.py | 52 +---- .../model_repo/cosyvoice2_dit/config.pbtxt | 2 +- .../model_repo/token2wav_dit/1/model.py | 104 +++------- .../model_repo/token2wav_dit/config.pbtxt | 78 ------- .../run_stepaudio2_dit_token2wav.sh | 10 +- runtime/triton_trtllm/streaming_inference.py | 22 +- 7 files changed, 219 insertions(+), 243 deletions(-) diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index afbab68..7aa8d7d 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -59,12 +59,14 @@ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import import tritonclient.grpc as grpcclient_sync # Added sync client import from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException +from datetime import datetime # --- Added UserData and callback --- class UserData: def __init__(self): self._completed_requests = queue.Queue() self._first_chunk_time = None + self._second_chunk_time = None self._start_time = None def record_start_time(self): @@ -75,14 +77,44 @@ class UserData: return self._first_chunk_time - self._start_time return None + def get_second_chunk_latency(self): + if self._first_chunk_time and self._second_chunk_time: + return self._second_chunk_time - self._first_chunk_time + return None + def callback(user_data, result, error): - if user_data._first_chunk_time is None and not error: - user_data._first_chunk_time = time.time() # Record time of first successful chunk + if not error: + if user_data._first_chunk_time is None: + user_data._first_chunk_time = time.time() # Record time of first successful chunk + elif user_data._second_chunk_time is None: + user_data._second_chunk_time = time.time() + if error: user_data._completed_requests.put(error) else: user_data._completed_requests.put(result) + + +def stream_callback(user_data_map, result, error): + request_id = None + if error: + # Note: InferenceServerException doesn't have a public request_id() method in all versions. + # This part might need adjustment depending on the tritonclient library version. + # A more robust way would be to wrap the error with the request_id if possible. + # For now, we assume we can't get request_id from error and it will timeout on the client side. 
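A note on naming: get_first_chunk_latency() above measures time from request start to the first audio chunk (time-to-first-audio), while get_second_chunk_latency() is the gap between the first and second chunks, i.e. a proxy for steady-state chunk cadence rather than a latency from the start. A toy illustration with simulated timestamps (not taken from the real client):

import time

start = time.time()                      # user_data.record_start_time()
time.sleep(0.20); first = time.time()    # first chunk arrives in the stream callback
time.sleep(0.05); second = time.time()   # second chunk arrives

ttfa = first - start                     # "first chunk latency"
inter_chunk_gap = second - first         # "second chunk latency" as computed above
print(f"TTFA {ttfa * 1000:.0f} ms, first->second chunk gap {inter_chunk_gap * 1000:.0f} ms")
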
+ print(f"An error occurred in the stream callback: {error}") + else: + request_id = result.get_response().id + + if request_id: + user_data = user_data_map.get(request_id) + if user_data: + callback(user_data, result, error) + else: + print(f"Warning: Could not find user_data for request_id {request_id}") + + # --- End Added UserData and callback --- @@ -142,6 +174,68 @@ def write_triton_stats(stats, summary_file): ) +def subtract_stats(stats_after, stats_before): + """Subtracts two Triton inference statistics objects.""" + # Deep copy to avoid modifying the original stats_after + stats_diff = json.loads(json.dumps(stats_after)) + + model_stats_before_map = { + s["name"]: { + "version": s["version"], + "last_inference": s.get("last_inference", 0), + "inference_count": s.get("inference_count", 0), + "execution_count": s.get("execution_count", 0), + "inference_stats": s.get("inference_stats", {}), + "batch_stats": s.get("batch_stats", []), + } + for s in stats_before["model_stats"] + } + + for model_stat_after in stats_diff["model_stats"]: + model_name = model_stat_after["name"] + if model_name in model_stats_before_map: + model_stat_before = model_stats_before_map[model_name] + + # Subtract counts + model_stat_after["inference_count"] = str( + int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0)) + ) + model_stat_after["execution_count"] = str( + int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0)) + ) + + # Subtract aggregate stats (like queue, compute times) + if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before: + for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]: + if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]: + if "ns" in model_stat_after["inference_stats"][key]: + ns_after = int(model_stat_after["inference_stats"][key]["ns"]) + ns_before = int(model_stat_before["inference_stats"][key]["ns"]) + model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before) + if "count" in model_stat_after["inference_stats"][key]: + count_after = int(model_stat_after["inference_stats"][key]["count"]) + count_before = int(model_stat_before["inference_stats"][key]["count"]) + model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before) + + # Subtract batch execution stats + if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before: + batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]} + for batch_stat_after in model_stat_after["batch_stats"]: + bs = batch_stat_after["batch_size"] + if bs in batch_stats_before_map: + batch_stat_before = batch_stats_before_map[bs] + for key in ["compute_input", "compute_infer", "compute_output"]: + if key in batch_stat_after and key in batch_stat_before: + count_after = int(batch_stat_after[key]["count"]) + count_before = int(batch_stat_before[key]["count"]) + batch_stat_after[key]["count"] = str(count_after - count_before) + + ns_after = int(batch_stat_after[key]["ns"]) + ns_before = int(batch_stat_before[key]["ns"]) + batch_stat_after[key]["ns"] = str(ns_after - ns_before) + return stats_diff + + def get_args(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -357,10 +451,10 @@ def run_sync_streaming_inference( """Helper function to run the blocking sync streaming call.""" start_time_total = time.time() 
user_data.record_start_time() # Record start time for first chunk latency calculation + # e.g. 08:47:34.827758 - # Establish stream - sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) - + print(f"Record start time in human readable: {datetime.now()}") + # input() # Send request sync_triton_client.async_stream_infer( model_name, @@ -374,11 +468,11 @@ def run_sync_streaming_inference( audios = [] while True: try: - result = user_data._completed_requests.get() # Add timeout + result = user_data._completed_requests.get(timeout=20) # Add timeout if isinstance(result, InferenceServerException): print(f"Received InferenceServerException: {result}") - sync_triton_client.stop_stream() - return None, None, None # Indicate error + # Don't stop the stream here, just return error + return None, None, None, None # Get response metadata response = result.get_response() final = response.parameters["triton_final_response"].bool_param @@ -393,13 +487,13 @@ def run_sync_streaming_inference( except queue.Empty: print(f"Timeout waiting for response for request id {request_id}") - sync_triton_client.stop_stream() - return None, None, None # Indicate error + # Don't stop stream here, just return error + return None, None, None, None - sync_triton_client.stop_stream() end_time_total = time.time() total_request_latency = end_time_total - start_time_total first_chunk_latency = user_data.get_first_chunk_latency() + second_chunk_latency = user_data.get_second_chunk_latency() # Reconstruct audio using cross-fade (from client_grpc_streaming.py) actual_duration = 0 @@ -448,7 +542,7 @@ def run_sync_streaming_inference( print("Warning: No audio chunks received.") actual_duration = 0 - return total_request_latency, first_chunk_latency, actual_duration + return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration async def send_streaming( @@ -468,10 +562,12 @@ async def send_streaming( latency_data = [] task_id = int(name[5:]) sync_triton_client = None # Initialize client variable + user_data_map = {} try: # Wrap in try...finally to ensure client closing print(f"{name}: Initializing sync client for streaming...") sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map)) print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") for i, item in enumerate(manifest_item_list): @@ -494,10 +590,11 @@ async def send_streaming( request_id = str(uuid.uuid4()) user_data = UserData() + user_data_map[request_id] = user_data audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") - - total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread( + print("target_text: ", target_text, "time: ", datetime.now()) + total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread( run_sync_streaming_inference, sync_triton_client, model_name, @@ -511,12 +608,18 @@ async def send_streaming( ) if total_request_latency is not None: - print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s") - latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) + print( + f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, " + f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is 
not None else 'N/A'}, " + f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s" + ) + latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration)) total_duration += actual_duration else: print(f"{name}: Item {i} failed.") + del user_data_map[request_id] + except FileNotFoundError: print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") except Exception as e: @@ -527,7 +630,8 @@ async def send_streaming( finally: # Ensure client is closed if sync_triton_client: try: - print(f"{name}: Closing sync client...") + print(f"{name}: Closing stream and sync client...") + sync_triton_client.stop_stream() sync_triton_client.close() except Exception as e: print(f"{name}: Error closing sync client: {e}") @@ -685,9 +789,22 @@ async def main(): "target_text": dataset[i]["target_text"], } ) + # manifest_item_list = manifest_item_list[:4] else: manifest_item_list = load_manifests(args.manifest_path) + # --- Statistics Fetching (Before) --- + stats_client = None + stats_before = None + try: + print("Initializing temporary async client for fetching stats...") + stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) + print("Fetching inference statistics before running tasks...") + stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True) + except Exception as e: + print(f"Could not retrieve statistics before running tasks: {e}") + # --- End Statistics Fetching (Before) --- + num_tasks = min(args.num_tasks, len(manifest_item_list)) manifest_item_list = split_data(manifest_item_list, num_tasks) @@ -776,8 +893,9 @@ async def main(): elif args.mode == "streaming": # Calculate stats for total request latency and first chunk latency - total_latency_list = [total for (total, first, duration) in latency_data if total is not None] - first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None] + total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None] + first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None] + second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None] s += "\n--- Total Request Latency ---\n" if total_latency_list: @@ -804,6 +922,19 @@ async def main(): s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" else: s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" + + s += "\n--- Second Chunk Latency ---\n" + if second_chunk_latency_list: + avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0 + variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0 + s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n" + s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n" + s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n" + s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n" + s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n" + s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n" + else: + s += "No second chunk latency data collected (check for 
errors or if all requests failed before second chunk).\n" else: s += "No latency data collected.\n" # --- End Statistics Reporting --- @@ -822,20 +953,23 @@ async def main(): # --- Statistics Fetching using temporary Async Client --- # Use a separate async client for fetching stats regardless of mode - stats_client = None try: - print("Initializing temporary async client for fetching stats...") - stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) - print("Fetching inference statistics...") - # Fetching for all models, filtering might be needed depending on server setup - stats = await stats_client.get_inference_statistics(model_name="", as_json=True) - print("Fetching model config...") - metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) + if stats_client and stats_before: + print("Fetching inference statistics after running tasks...") + stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True) - write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + print("Calculating statistics difference...") + stats = subtract_stats(stats_after, stats_before) - with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: - json.dump(metadata, f, indent=4) + print("Fetching model config...") + metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) + + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) + else: + print("Stats client not available or initial stats were not fetched. Skipping stats reporting.") except Exception as e: print(f"Could not retrieve statistics or config: {e}") diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py index c472968..2f81786 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py @@ -109,7 +109,6 @@ class TritonPythonModel: spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) self.default_spk_info = spk_info["001"] self.http_client = httpx.AsyncClient() - self.runtime_cache = {} def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" @@ -264,38 +263,11 @@ class TritonPythonModel: finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - # optional cache inputs - if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None: - # inputs_tensor.extend([ - # pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()), - # pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()), - # pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()), - # pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()), - # pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()), - # pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()), - # pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()), - # ]) - inputs_tensor.extend([ - 
self.runtime_cache[request_id]["conformer_cnn_cache"], - self.runtime_cache[request_id]["conformer_att_cache"], - self.runtime_cache[request_id]["estimator_cnn_cache"], - self.runtime_cache[request_id]["estimator_att_cache"], - self.runtime_cache[request_id]["mel"], - self.runtime_cache[request_id]["source"], - self.runtime_cache[request_id]["speech"], - ]) # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', requested_output_names=[ "waveform", - "conformer_cnn_cache", - "conformer_att_cache", - "estimator_cnn_cache", - "estimator_att_cache", - "mel", - "source", - "speech", ], inputs=inputs_tensor, request_id=request_id, @@ -306,14 +278,6 @@ class TritonPythonModel: if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) - self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache") - self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache") - self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache") - self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache") - self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel") - self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source") - self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech") - # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() @@ -339,16 +303,6 @@ class TritonPythonModel: async def _process_request(self, request): request_id = request.request_id() - if request_id not in self.runtime_cache: - self.runtime_cache[request_id] = { - "conformer_cnn_cache": None, - "conformer_att_cache": None, - "estimator_cnn_cache": None, - "estimator_att_cache": None, - "mel": None, - "source": None, - "speech": None, - } # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") @@ -369,7 +323,7 @@ class TritonPythonModel: reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) # reference_text = self.default_spk_info["prompt_text"] # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE @@ -453,9 +407,7 @@ class TritonPythonModel: audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) - if request_id in self.runtime_cache: - del self.runtime_cache[request_id] - self.logger.log_info(f"Deleted cache for request_id: {request_id}") + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt index b119227..e64647e 
100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -31,7 +31,7 @@ parameters [ value: {string_value:"${model_dir}"} } ] -parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} + input [ { name: "reference_wav" diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index e95ce99..230bad0 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -103,91 +103,47 @@ class TritonPythonModel: List of inference responses containing generated waveforms """ responses = [] + # Process each request in batch for request in requests: - request_id = request.request_id() - - # Get inputs - target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens") - target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack()) + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) + # shift the speech tokens according to the original vocab size + target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens.squeeze().tolist() + # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy() - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item() - wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0) + + request_id = request.request_id() + + + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() + + wav_array = torch.from_numpy(wav_array) + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + spk_id = get_spk_id_from_prompt_audio(wav) + # wav = wav.to(self.device) - # Handle cache - conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache") - if conformer_cnn_cache is not None: - self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack()) - - conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache") - self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1) - - estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache") - self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0) + # update cache before forward + # self.token2wav_model.streaming_flow_cache[request_id] + # self.token2wav_model.hift_cache_dict[request_id] - estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache") - self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0) + audio_hat = 
self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) - mel_np = pb_utils.get_input_tensor_by_name(request, "mel") - self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack()) - - source_np = pb_utils.get_input_tensor_by_name(request, "source") - self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack()) - - speech_np = pb_utils.get_input_tensor_by_name(request, "speech") - self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack()) - - # Forward pass - audio_hat = self.token2wav_model.forward_streaming( - target_speech_tokens, - finalize, - request_id=request_id, - speaker_id=f"{spk_id}", - prompt_audio=wav, - prompt_audio_sample_rate=16000 - ) - - # Prepare outputs + # get the cache after forward outputs = [] + + generated_wave = audio_hat.squeeze(0).cpu().numpy() + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) outputs.append(wav_tensor) - - if request_id in self.token2wav_model.streaming_flow_cache: - cache = self.token2wav_model.streaming_flow_cache[request_id] - hifigan_cache = self.token2wav_model.hift_cache_dict[request_id] - conformer_cnn_cache = cache['conformer_cnn_cache'] - conformer_att_cache = cache['conformer_att_cache'].transpose(0,1) - estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0) - estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0) - mel = hifigan_cache['mel'] - source = hifigan_cache['source'] - speech = hifigan_cache['speech'] - - outputs.extend([ - pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)), - pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)), - pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)), - pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)), - pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)), - pb_utils.Tensor.from_dlpack("source", to_dlpack(source)), - pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)), - ]) - else: - outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("mel", np.array([], dtype=np.float32)), - pb_utils.Tensor("source", np.array([], dtype=np.float32)), - pb_utils.Tensor("speech", np.array([], dtype=np.float32)), - ]) - inference_response = pb_utils.InferenceResponse(output_tensors=outputs) responses.append(inference_response) - return responses - def finalize(self): - self.logger.log_info("Finalizing Token2WavDiT model") + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt index aed7561..3f579aa 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -22,7 +22,6 @@ dynamic_batching { default_priority_level: 10 } -parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} parameters [ { key: "model_dir", @@ -52,48 +51,6 @@ input [ dims: [ 1 ] reshape: { shape: [ ] } optional: true - }, - { - name: 
"conformer_cnn_cache" - data_type: TYPE_FP16 - dims: [ 512, -1 ] - optional: true - }, - { - name: "conformer_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 8, -1, 128 ] - optional: true - }, - { - name: "estimator_cnn_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 1024, 2 ] - optional: true - }, - { - name: "estimator_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 8, -1, 128 ] - optional: true - }, - { - name: "mel" - data_type: TYPE_FP32 - dims: [ 80, -1 ] - optional: true - }, - { - name: "source" - data_type: TYPE_FP32 - dims: [ 1, -1 ] - optional: true - }, - { - name: "speech" - data_type: TYPE_FP32 - dims: [ -1 ] - optional: true } ] output [ @@ -101,41 +58,6 @@ output [ name: "waveform" data_type: TYPE_FP32 dims: [ -1 ] - }, - { - name: "conformer_cnn_cache" - data_type: TYPE_FP16 - dims: [ 512, -1 ] - }, - { - name: "conformer_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 8, -1, 128 ] - }, - { - name: "estimator_cnn_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 1024, 2 ] - }, - { - name: "estimator_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 8, -1, 128 ] - }, - { - name: "mel" - data_type: TYPE_FP32 - dims: [ 80, -1 ] - }, - { - name: "source" - data_type: TYPE_FP32 - dims: [ 1, -1 ] - }, - { - name: "speech" - data_type: TYPE_FP32 - dims: [ -1 ] } ] diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index ad3407e..2eabcf4 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,6 +1,6 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=0 cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 @@ -112,7 +112,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then MODEL_DIR=$model_scope_model_local_dir LLM_TOKENIZER_DIR=$huggingface_model_local_dir BLS_INSTANCE_NUM=4 - TRITON_MAX_BATCH_SIZE=32 + TRITON_MAX_BATCH_SIZE=1 DECOUPLED_MODE=True # True for streaming, False for offline STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav @@ -154,7 +154,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --num-tasks $num_task \ --mode $mode \ --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new + --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then @@ -185,14 +185,14 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - python3 streaming_inference.py + CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 + CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 fi diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index 863358c..93c6758 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -31,6 +31,7 @@ def get_args(): 
parser.add_argument("--output-dir", type=str, default="generated_wavs") parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2") + parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"]) return parser.parse_args() @@ -53,12 +54,14 @@ if __name__ == "__main__": token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) flow_pre_lookahead_len = 3 - CHUNK_SIZE = 25 + CHUNK_SIZE = 15 + token_frame_rate = 25 OVERLAP_SIZE = 0 warmup_times = 3 for _ in range(warmup_times): start_time = time.time() + total_forward_count = 0 for batch in data_loader: tts_speech_list = [] ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch @@ -83,17 +86,26 @@ if __name__ == "__main__": buffer = generated_speech_tokens output_wavs = [] + chunk_index = 0 while True: + if args.strategy == "equal": + this_chunk_size = CHUNK_SIZE + elif args.strategy == "exponential": + this_chunk_size = token_frame_rate * (2 ** chunk_index) - if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: - wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) - buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] + if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len: + wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + buffer = buffer[this_chunk_size - OVERLAP_SIZE:] output_wavs.append(wavs) + total_forward_count += 1 + chunk_index += 1 else: wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) output_wavs.append(wavs) + total_forward_count += 1 + # chunk_index += 1 break for i, wav in enumerate(output_wavs): @@ -112,4 +124,4 @@ if __name__ == "__main__": if _ == 0: token2wav_model.speaker_cache = {} print(f"Warmup time: {end_time - start_time} seconds") - + print(f"Total forward count: {total_forward_count}") From 988d395162a7a717598ef25caa1d6733369a0cd3 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Oct 2025 14:06:19 +0800 Subject: [PATCH 06/15] mark multi client --- .../run_stepaudio2_dit_token2wav.sh | 74 ++++++++++++++++--- 1 file changed, 63 insertions(+), 11 deletions(-) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index 2eabcf4..2c19a1d 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -9,6 +9,8 @@ export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH stage=$1 stop_stage=$2 +N_GPUS=2 # set the number of GPUs to use + huggingface_model_local_dir=./cosyvoice2_llm model_scope_model_local_dir=./CosyVoice2-0.5B @@ -128,8 +130,32 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - echo "Starting Triton server" - tritonserver --model-repository $model_repo --http-port 18000 + echo "Starting Triton server 
on $N_GPUS GPUs" + for i in $(seq 0 $(($N_GPUS - 1))); do + echo "Starting server on GPU $i" + http_port=$((19000 + $i)) + grpc_port=$((18000 + $i)) + metrics_port=$((17000 + $i)) + CUDA_VISIBLE_DEVICES=$i tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & + done + + echo "Servers are running in the background. Press Ctrl+C to stop them and the script." + wait +fi + +if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then + echo "Starting Triton server on $N_GPUS GPUs" + N_GPUS=1 + for i in $(seq 0 $(($N_GPUS - 1))); do + echo "Starting server on GPU $i" + http_port=$((19000 + $i)) + grpc_port=$((18000 + $i)) + metrics_port=$((17000 + $i)) + CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & + done + + echo "Servers are running in the background. Press Ctrl+C to stop them and the script." + wait fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then @@ -142,21 +168,47 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - echo "Running benchmark client grpc" - num_task=4 + echo "Running benchmark client grpc on $N_GPUS GPUs" + num_task=1 mode=streaming BLS_INSTANCE_NUM=4 - python3 client_grpc.py \ - --server-addr localhost \ - --model-name cosyvoice2_dit \ - --num-tasks $num_task \ - --mode $mode \ - --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} + for i in $(seq 0 $(($N_GPUS - 1))); do + grpc_port=$((18000 + $i)) + echo "Running client for server on localhost:$grpc_port" + python3 client_grpc.py \ + --server-addr localhost \ + --server-port $grpc_port \ + --model-name cosyvoice2_dit \ + --num-tasks $num_task \ + --mode $mode \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_gpu${i} & + done + wait fi +if [ $stage -le 50 ] && [ $stop_stage -ge 50 ]; then + echo "Running benchmark client grpc on $N_GPUS GPUs" + num_task=4 + N_GPUS=1 + mode=streaming + BLS_INSTANCE_NUM=4 + for i in $(seq 0 $(($N_GPUS - 1))); do + grpc_port=$((18000 + $i)) + echo "Running client for server on localhost:$grpc_port" + python3 client_grpc.py \ + --server-addr localhost \ + --server-port $grpc_port \ + --model-name cosyvoice2_dit \ + --num-tasks $num_task \ + --mode $mode \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_single_card_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} & + done + wait +fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then echo "stage 6: Offline inference benchmark" n_gpus=1 From f186ec33381544e78af6115d636c5590c78e4569 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Oct 2025 15:21:52 +0800 Subject: [PATCH 07/15] clean code --- .../model_repo/cosyvoice2_dit/1/model.py | 365 ++++++++------- .../model_repo/cosyvoice2_dit/3/model.py | 435 ------------------ runtime/triton_trtllm/offline_inference.py | 10 +- .../run_stepaudio2_dit_token2wav.sh | 232 +++------- runtime/triton_trtllm/streaming_inference.py | 24 +- 5 files changed, 266 insertions(+), 800 deletions(-) delete mode 100644 runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index d0977c5..2f81786 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ 
b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -28,9 +28,10 @@ import json import math import os import re -import threading import time from typing import Dict, List, Tuple, Optional, Union +import asyncio +import httpx import numpy as np import torch @@ -42,11 +43,30 @@ import torchaudio from matcha.utils.audio import mel_spectrogram +from datetime import datetime ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) +def parse_speech_token_string(response_text: str) -> List[int]: + """ + Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. + """ + speech_tokens = response_text.strip().split('><') + if len(speech_tokens) > 1: + # Add back the missing '<' and '>' for proper parsing + speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] + speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] + + speech_ids = [] + for token_str in speech_tokens: + match = re.match(r'<\|s_(\d+)\|>', token_str) + if match: + speech_ids.append(int(match.group(1))) + return speech_ids + + class TritonPythonModel: """Triton Python model for Spark TTS. @@ -67,6 +87,7 @@ class TritonPythonModel: model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + # self.dynamic_chunk_strategy = "equal" self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer @@ -87,92 +108,86 @@ class TritonPythonModel: raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) self.default_spk_info = spk_info["001"] + self.http_client = httpx.AsyncClient() - def forward_llm(self, input_ids): + def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: + """Converts a tensor or list of speech token IDs to a string representation.""" + if isinstance(speech_tokens, torch.Tensor): + # Ensure tensor is on CPU and flattened + speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() + + speech_id_str = "" + for token_id in speech_tokens: + # Convert token ID back to the speech number N + token_num = token_id - ORIGINAL_VOCAB_SIZE + speech_id_str += f"<|s_{token_num}|>" + return speech_id_str + + async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): """ - Prepares the response from the language model based on the provided - inputs. Creates a `pb_utils.InferenceRequest` object with passed - `llm_request_inputs` to send to a decoupled TensorRTLLM model. - For each response from the language model: - - Checks for errors and raise an exception if any are found. - - Extracts the "output_ids" tensor from the response. - - Determines the finish reason based on the presence of the - end-of-sequence token or reaching the maximum length. - - Appends the generated token IDs to `output_ids`. - - If the finish reason is determined, decodes the output IDs to text - and prepares the final response. - - The final response includes the generated text, finish reason, - completion tokens, prompt tokens, and total tokens. - - Parameters - ---------- - - llm_request_inputs (dict): A dictionary containing the inputs for the language model. 
- - Returns - ------- - - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. + Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response. """ - # convert input_ids to numpy, with shape [1, sequence_length] - input_ids = input_ids.cpu().numpy() - max_tokens = 750 - input_dict = { - "request_output_len": np.array([[max_tokens]], dtype=np.int32), - "end_id": np.array([[self.eos_token_id]], dtype=np.int32), - "pad_id": np.array([[self.eos_token_id]], dtype=np.int32), - "streaming": np.array([[self.decoupled]], dtype=np.bool_), - "runtime_top_p": np.array([[0.95]], dtype=np.float32), - "runtime_top_k": np.array([[50]], dtype=np.int32), - "temperature": np.array([[0.8]], dtype=np.float32), - "repetition_penalty": np.array([[1.1]], dtype=np.float32), - "random_seed": np.array([[42]], dtype=np.uint64), - "input_ids": input_ids, - "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), - } + full_text = f"{reference_text}{target_text}" + prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) - # Convert inputs to Triton tensors - input_tensor_list = [ - pb_utils.Tensor(k, v) for k, v in input_dict.items() + chat = [ + {"role": "user", "content": full_text}, + {"role": "assistant", "content": prompt_speech_tokens_str} ] - # Create and execute inference request - llm_request = pb_utils.InferenceRequest( - model_name="tensorrt_llm", - requested_output_names=["output_ids", "sequence_length"], - inputs=input_tensor_list, - ) + payload = { + "model": "trt_engines_bfloat16", + "messages": chat, + "max_tokens": 750, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": True, + } - llm_responses = llm_request.exec(decoupled=self.decoupled) - if self.decoupled: - for llm_response in llm_responses: - if llm_response.has_error(): - raise pb_utils.TritonModelException(llm_response.error().message()) + api_base = "http://localhost:8000/v1/chat/completions" - # Extract and process output - output_ids = pb_utils.get_output_tensor_by_name( - llm_response, "output_ids").as_numpy() - seq_lens = pb_utils.get_output_tensor_by_name( - llm_response, "sequence_length").as_numpy() + buffer = "" + async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: + print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") + print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + line_data = line[len("data: "):].strip() + if line_data == "[DONE]": + break + try: + json_data = json.loads(line_data) + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") + if content: + buffer += content + print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break - # Get actual output IDs up to the sequence length - actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + except json.JSONDecodeError: + self.logger.log_info(f"Skipping non-JSON line: {line_data}") + continue - yield actual_output_ids - else: - llm_response = llm_responses - if llm_response.has_error(): - raise 
pb_utils.TritonModelException(llm_response.error().message()) + # Process any remaining complete tokens in the buffer after the stream ends + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] - # Extract and process output - output_ids = pb_utils.get_output_tensor_by_name( - llm_response, "output_ids").as_numpy() - seq_lens = pb_utils.get_output_tensor_by_name( - llm_response, "sequence_length").as_numpy() - - # Get actual output IDs up to the sequence length - actual_output_ids = output_ids[0][0][:seq_lens[0][0]] - - yield actual_output_ids def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. @@ -225,7 +240,7 @@ class TritonPythonModel: return prompt_spk_embedding - def forward_token2wav( + async def forward_token2wav( self, index: int, target_speech_tokens: torch.Tensor, @@ -247,17 +262,19 @@ class TritonPythonModel: target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', - requested_output_names=['waveform'], + requested_output_names=[ + "waveform", + ], inputs=inputs_tensor, request_id=request_id, parameters={"priority": index+1}, ) - inference_response = inference_request.exec() + inference_response = await inference_request.async_exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) @@ -267,14 +284,6 @@ class TritonPythonModel: return waveform - def parse_input(self, text, prompt_text, prompt_speech_tokens): - total_text = f"{prompt_text}{text}" - prompt = self.prompt_template.format(input_text=total_text) - input_ids = self.tokenizer.encode(prompt) - input_ids = torch.tensor([input_ids], dtype=torch.int32) - input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) - return input_ids - def _extract_speech_feat(self, speech): speech_feat = mel_spectrogram( speech, @@ -292,106 +301,75 @@ class TritonPythonModel: speech_feat = speech_feat.unsqueeze(dim=0) return speech_feat - def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag): - for generated_ids in generated_ids_iter: - generated_ids = generated_ids.tolist() - if len(generated_ids) == 0: - break - semantic_token_ids_arr.extend(generated_ids) - llm_is_done_flag[0] = True + async def _process_request(self, request): + request_id = request.request_id() + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") - def execute(self, requests): - """Execute inference on the batched requests. 
+ # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - Args: - requests: List of inference requests + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - Returns: - List of inference responses containing generated audio - """ - responses = [] + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) - for request in requests: - request_id = request.request_id() - # Extract input tensors - wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + # reference_text = self.default_spk_info["prompt_text"] + # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + # prompt_speech_feat = None + # prompt_spk_embedding = None - # Process reference audio through audio tokenizer - if wav is not None: - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + else: + # using pre-cached reference text + assert False, "using pre-cached reference text is not supported" + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + print(f"target_text: {target_text}, time: {datetime.now()}") - reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + if self.decoupled: + response_sender = request.get_response_sender() - # reference_text = self.default_spk_info["prompt_text"] - # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - # prompt_speech_feat = None - # prompt_spk_embedding = None - - else: - assert False, "wav is None" - # using pre-cached reference text - reference_text = self.default_spk_info["prompt_text"] - prompt_speech_tokens = self.default_spk_info["speech_token"] + 
ORIGINAL_VOCAB_SIZE - prompt_speech_feat = None - prompt_spk_embedding = None - - target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() - target_text = target_text[0][0].decode('utf-8') - - # Prepare prompt for LLM - input_ids = self.parse_input( - text=target_text, - prompt_text=reference_text, + semantic_token_ids_arr = [] + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") + async for generated_ids in self.forward_llm_async( + target_text=target_text, + reference_text=reference_text, prompt_speech_tokens=prompt_speech_tokens, - ) - - # Generate semantic tokens with LLM - generated_ids_iter = self.forward_llm(input_ids) - - if self.decoupled: - response_sender = request.get_response_sender() - - semantic_token_ids_arr = [] - llm_is_done_flag = [False] - - llm_thread = threading.Thread( - target=self._llm_gen_thread, - args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag) - ) - - llm_thread.start() - - token_offset, chunk_index = 0, 0 - start_time = time.time() - this_token_hop_len = self.token_hop_len - + ): + if not generated_ids: + break + semantic_token_ids_arr.append(generated_ids) + print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: pending_num = len(semantic_token_ids_arr) - token_offset - - if llm_is_done_flag[0]: - break - if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - - sub_tts_speech = self.forward_token2wav( + print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") + sub_tts_speech = await self.forward_token2wav( chunk_index, this_tts_speech_token, request_id, wav, wav_len, False ) - + print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -401,6 +379,8 @@ class TritonPythonModel: if self.dynamic_chunk_strategy == "exponential": this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "equal": + this_token_hop_len = self.token_hop_len elif self.dynamic_chunk_strategy == "time_based": # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 cost_time = time.time() - start_time @@ -420,19 +400,36 @@ class TritonPythonModel: this_token_hop_len = max(self.token_hop_len, this_token_hop_len) chunk_index += 1 else: - time.sleep(0.02) + break + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) - this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(chunk_index, 
this_tts_speech_token, request_id, wav, wav_len, True) - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - response_sender.send(inference_response) + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + raise NotImplementedError("Decoupled mode is not supported") - llm_thread.join() - response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info("send tritonserver_response_complete_final to end") - else: - raise NotImplementedError("Decoupled mode is not supported") + async def execute(self, requests): + """Execute inference on the batched requests. - if not self.decoupled: - return responses + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + tasks = [ + asyncio.create_task(self._process_request(request)) + for request in requests + ] + await asyncio.gather(*tasks) + return None + + def finalize(self): + self.logger.log_info("Finalizing CosyVoice DIT model") + if hasattr(self, "http_client"): + asyncio.run(self.http_client.aclose()) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py deleted file mode 100644 index 2f81786..0000000 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py +++ /dev/null @@ -1,435 +0,0 @@ -# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import json -import math -import os -import re -import time -from typing import Dict, List, Tuple, Optional, Union -import asyncio -import httpx - -import numpy as np -import torch -from torch.utils.dlpack import from_dlpack, to_dlpack -import triton_python_backend_utils as pb_utils -from transformers import AutoTokenizer - -import torchaudio - - -from matcha.utils.audio import mel_spectrogram -from datetime import datetime - -ORIGINAL_VOCAB_SIZE = 151663 -torch.set_num_threads(1) - - -def parse_speech_token_string(response_text: str) -> List[int]: - """ - Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. - """ - speech_tokens = response_text.strip().split('><') - if len(speech_tokens) > 1: - # Add back the missing '<' and '>' for proper parsing - speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] - speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] - - speech_ids = [] - for token_str in speech_tokens: - match = re.match(r'<\|s_(\d+)\|>', token_str) - if match: - speech_ids.append(int(match.group(1))) - return speech_ids - - -class TritonPythonModel: - """Triton Python model for Spark TTS. - - This model orchestrates the end-to-end TTS pipeline by coordinating - between audio tokenizer, LLM, and vocoder components. - """ - - def initialize(self, args): - """Initialize the model. - - Args: - args: Dictionary containing model configuration - """ - self.logger = pb_utils.Logger - # Parse model parameters - self.model_config = json.loads(args['model_config']) - parameters = self.model_config['parameters'] - model_params = {k: v["string_value"] for k, v in parameters.items()} - self.logger.log_info(f"model_params:{model_params}") - self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" - # self.dynamic_chunk_strategy = "equal" - self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") - - # Initialize tokenizer - llm_tokenizer_dir = model_params["llm_tokenizer_dir"] - self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) - self.prompt_template = "<|sos|>{input_text}<|task_id|>" - self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") - - self.device = torch.device("cuda") - self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) - - self.token_frame_rate = 25 - self.flow_pre_lookahead_len = 3 - self.token_hop_len = 15 - - spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") - if not os.path.exists(spk_info_path): - raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") - spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - self.default_spk_info = spk_info["001"] - self.http_client = httpx.AsyncClient() - - def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: - """Converts a tensor or list of speech token IDs to a string representation.""" - if isinstance(speech_tokens, torch.Tensor): - # Ensure tensor is on CPU and flattened - speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() - - speech_id_str = "" - for token_id in speech_tokens: - # Convert token ID back to the speech number N - token_num = token_id - ORIGINAL_VOCAB_SIZE - speech_id_str += f"<|s_{token_num}|>" - return speech_id_str - - async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): - """ - Asynchronously sends 
a request to the TRTLLM-serve endpoint and processes the streaming response. - """ - full_text = f"{reference_text}{target_text}" - prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) - - chat = [ - {"role": "user", "content": full_text}, - {"role": "assistant", "content": prompt_speech_tokens_str} - ] - - payload = { - "model": "trt_engines_bfloat16", - "messages": chat, - "max_tokens": 750, - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "repetition_penalty": 1.1, - "stop": ["<|eos1|>", "<|eos|>"], - "stream": True, - } - - api_base = "http://localhost:8000/v1/chat/completions" - - buffer = "" - async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: - print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") - print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") - response.raise_for_status() - async for line in response.aiter_lines(): - if line.startswith("data: "): - line_data = line[len("data: "):].strip() - if line_data == "[DONE]": - break - try: - json_data = json.loads(line_data) - content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") - if content: - buffer += content - print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") - while True: - match = re.search(r"<\|s_(\d+)\|>", buffer) - if not match: - break - - token_num = int(match.group(1)) - final_id = token_num + ORIGINAL_VOCAB_SIZE - yield final_id - buffer = buffer[match.end():] - except json.JSONDecodeError: - self.logger.log_info(f"Skipping non-JSON line: {line_data}") - continue - - # Process any remaining complete tokens in the buffer after the stream ends - while True: - match = re.search(r"<\|s_(\d+)\|>", buffer) - if not match: - break - token_num = int(match.group(1)) - final_id = token_num + ORIGINAL_VOCAB_SIZE - yield final_id - buffer = buffer[match.end():] - - - def forward_audio_tokenizer(self, wav, wav_len): - """Forward pass through the audio tokenizer component. - - Args: - wav: Input waveform tensor - wav_len: Waveform length tensor - - Returns: - Tuple of global and semantic tokens - """ - inference_request = pb_utils.InferenceRequest( - model_name='audio_tokenizer', - requested_output_names=['prompt_speech_tokens'], - inputs=[wav, wav_len] - ) - - inference_response = inference_request.exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - - # Extract and convert output tensors - prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') - prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() - - return prompt_speech_tokens - - def forward_speaker_embedding(self, wav): - """Forward pass through the speaker embedding component. 
- - Args: - wav: Input waveform tensor - - Returns: - Prompt speaker embedding tensor - """ - inference_request = pb_utils.InferenceRequest( - model_name='speaker_embedding', - requested_output_names=['prompt_spk_embedding'], - inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] - ) - - inference_response = inference_request.exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - - # Extract and convert output tensors - prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') - prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) - - return prompt_spk_embedding - - async def forward_token2wav( - self, - index: int, - target_speech_tokens: torch.Tensor, - request_id: str, - reference_wav: object, - reference_wav_len: object, - finalize: bool = None) -> torch.Tensor: - """Forward pass through the vocoder component. - - Args: - prompt_speech_tokens: Prompt speech tokens tensor - prompt_speech_feat: Prompt speech feat tensor - prompt_spk_embedding: Prompt spk embedding tensor - target_speech_tokens: Target speech tokens tensor - - Returns: - Generated waveform tensor - """ - target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) - finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) - inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - - # Create and execute inference request - inference_request = pb_utils.InferenceRequest( - model_name='token2wav_dit', - requested_output_names=[ - "waveform", - ], - inputs=inputs_tensor, - request_id=request_id, - parameters={"priority": index+1}, - ) - - inference_response = await inference_request.async_exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - - # Extract and convert output waveform - waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') - waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() - - return waveform - - def _extract_speech_feat(self, speech): - speech_feat = mel_spectrogram( - speech, - n_fft=1920, - num_mels=80, - sampling_rate=24000, - hop_size=480, - win_size=1920, - fmin=0, - fmax=8000).squeeze( - dim=0).transpose( - 0, - 1).to( - self.device) - speech_feat = speech_feat.unsqueeze(dim=0) - return speech_feat - - async def _process_request(self, request): - request_id = request.request_id() - # Extract input tensors - wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") - - # Process reference audio through audio tokenizer - if wav is not None: - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - - 
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) - - # reference_text = self.default_spk_info["prompt_text"] - # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - # prompt_speech_feat = None - # prompt_spk_embedding = None - - else: - # using pre-cached reference text - assert False, "using pre-cached reference text is not supported" - reference_text = self.default_spk_info["prompt_text"] - prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - prompt_speech_feat = None - prompt_spk_embedding = None - - target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() - target_text = target_text[0][0].decode('utf-8') - print(f"target_text: {target_text}, time: {datetime.now()}") - - if self.decoupled: - response_sender = request.get_response_sender() - - semantic_token_ids_arr = [] - token_offset, chunk_index = 0, 0 - start_time = time.time() - this_token_hop_len = self.token_hop_len - print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") - async for generated_ids in self.forward_llm_async( - target_text=target_text, - reference_text=reference_text, - prompt_speech_tokens=prompt_speech_tokens, - ): - if not generated_ids: - break - semantic_token_ids_arr.append(generated_ids) - print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") - while True: - pending_num = len(semantic_token_ids_arr) - token_offset - if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: - this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] - this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") - sub_tts_speech = await self.forward_token2wav( - chunk_index, - this_tts_speech_token, request_id, wav, wav_len, False - ) - print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - response_sender.send(inference_response) - - token_offset += this_token_hop_len - self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") - - if self.dynamic_chunk_strategy == "exponential": - this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) - elif self.dynamic_chunk_strategy == "equal": - this_token_hop_len = self.token_hop_len - elif self.dynamic_chunk_strategy == "time_based": - # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 - cost_time = time.time() - start_time - duration = token_offset / self.token_frame_rate - if chunk_index > 0 and cost_time > 0: - avg_chunk_processing_time = cost_time / (chunk_index + 1) - if avg_chunk_processing_time > 0: - multiples = (duration - cost_time) / avg_chunk_processing_time - self.logger.log_info(f"multiples: {multiples}") - next_pending_num = len(semantic_token_ids_arr) - token_offset - if multiples > 4: - this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len - elif multiples > 2: - this_token_hop_len = (next_pending_num // self.token_hop_len) * 
self.token_hop_len - else: - this_token_hop_len = self.token_hop_len - this_token_hop_len = max(self.token_hop_len, this_token_hop_len) - chunk_index += 1 - else: - break - - this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - response_sender.send(inference_response) - - response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info("send tritonserver_response_complete_final to end") - else: - raise NotImplementedError("Decoupled mode is not supported") - - async def execute(self, requests): - """Execute inference on the batched requests. - - Args: - requests: List of inference requests - - Returns: - List of inference responses containing generated audio - """ - tasks = [ - asyncio.create_task(self._process_request(request)) - for request in requests - ] - await asyncio.gather(*tasks) - return None - - def finalize(self): - self.logger.log_info("Finalizing CosyVoice DIT model") - if hasattr(self, "http_client"): - asyncio.run(self.http_client.aclose()) diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py index 30c3b3b..d309d18 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -47,8 +47,6 @@ import requests import asyncio import httpx -from token2wav import CosyVoice2_Token2Wav - sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") try: torch.multiprocessing.set_start_method("spawn") @@ -367,7 +365,12 @@ def main(args): runner = None else: raise ValueError(f"Unsupported backend: {args.backend}") - + + if 'Step-Audio-2-mini' in args.token2wav_path: + from token2wav_dit import CosyVoice2_Token2Wav + else: + assert 'CosyVoice2-0.5B' in args.token2wav_path + from token2wav import CosyVoice2_Token2Wav token2wav_model = CosyVoice2_Token2Wav( model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank ) @@ -589,7 +592,6 @@ def main(args): t2w_prompt_audios_list, t2w_prompt_audios_sample_rate, ) - torch.cuda.synchronize() token2wav_end_time = time.time() total_token2wav_time += (token2wav_end_time - token2wav_start_time) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index 2c19a1d..463e490 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,28 +1,33 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) export CUDA_VISIBLE_DEVICES=0 -cosyvoice_path=/workspace/CosyVoice +# cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 + export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH + stage=$1 stop_stage=$2 -N_GPUS=2 # set the number of GPUs to use - huggingface_model_local_dir=./cosyvoice2_llm model_scope_model_local_dir=./CosyVoice2-0.5B +step_audio_model_dir=./Step-Audio-2-mini + trt_dtype=bfloat16 trt_weights_dir=./trt_weights_${trt_dtype} trt_engines_dir=./trt_engines_${trt_dtype} model_repo=./model_repo_cosyvoice2_dit - 
-use_spk2info_cache=False +bls_instance_num=4 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + + echo "Cloning Step-Audio2-mini" + git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path + echo "Cloning CosyVoice" git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path cd $cosyvoice_path @@ -35,8 +40,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir - # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings - wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt + + echo "Step-Audio2-mini" + huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini + cd $stepaudio2_path/token2wav + wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx + wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx + cd - fi @@ -60,40 +70,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --engine_dir=$trt_engines_dir || exit 1 fi - -# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then -# echo "Creating model repository" -# rm -rf $model_repo -# mkdir -p $model_repo -# cosyvoice2_dir="cosyvoice2_dit" -# token2wav_dir="token2wav_dit" - -# cp -r ./model_repo/${cosyvoice2_dir} $model_repo -# cp -r ./model_repo/tensorrt_llm $model_repo -# cp -r ./model_repo/${token2wav_dir} $model_repo -# #if [ $use_spk2info_cache == "False" ]; then -# cp -r ./model_repo/audio_tokenizer $model_repo -# cp -r ./model_repo/speaker_embedding $model_repo -# #fi - -# ENGINE_PATH=$trt_engines_dir -# MAX_QUEUE_DELAY_MICROSECONDS=0 -# MODEL_DIR=$model_scope_model_local_dir -# LLM_TOKENIZER_DIR=$huggingface_model_local_dir -# BLS_INSTANCE_NUM=1 -# TRITON_MAX_BATCH_SIZE=16 -# DECOUPLED_MODE=True # True for streaming, False for offline -# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav - -# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt 
triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 -# #if [ $use_spk2info_cache == "False" ]; then -# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# #fi -# fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then echo "Creating model repository async mode" rm -rf $model_repo @@ -102,122 +78,75 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then token2wav_dir="token2wav_dit" cp -r ./model_repo/${cosyvoice2_dir} $model_repo - cp -r ./model_repo/tensorrt_llm $model_repo cp -r ./model_repo/${token2wav_dir} $model_repo - #if [ $use_spk2info_cache == "False" ]; then - cp -r ./model_repo/audio_tokenizer $model_repo - cp -r ./model_repo/speaker_embedding $model_repo - #fi + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/speaker_embedding $model_repo + ENGINE_PATH=$trt_engines_dir MAX_QUEUE_DELAY_MICROSECONDS=0 MODEL_DIR=$model_scope_model_local_dir LLM_TOKENIZER_DIR=$huggingface_model_local_dir - BLS_INSTANCE_NUM=4 + BLS_INSTANCE_NUM=$bls_instance_num TRITON_MAX_BATCH_SIZE=1 - DECOUPLED_MODE=True # True for streaming, False for offline - STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav + DECOUPLED_MODE=True + STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 - #if [ $use_spk2info_cache == "False" ]; then - python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt 
model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - #fi - rm -rf $model_repo/tensorrt_llm - # mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4 + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - echo "Starting Triton server on $N_GPUS GPUs" - for i in $(seq 0 $(($N_GPUS - 1))); do - echo "Starting server on GPU $i" - http_port=$((19000 + $i)) - grpc_port=$((18000 + $i)) - metrics_port=$((17000 + $i)) - CUDA_VISIBLE_DEVICES=$i tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & - done - - echo "Servers are running in the background. Press Ctrl+C to stop them and the script." - wait -fi - -if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then - echo "Starting Triton server on $N_GPUS GPUs" - N_GPUS=1 - for i in $(seq 0 $(($N_GPUS - 1))); do - echo "Starting server on GPU $i" - http_port=$((19000 + $i)) - grpc_port=$((18000 + $i)) - metrics_port=$((17000 + $i)) - CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & - done - - echo "Servers are running in the background. Press Ctrl+C to stop them and the script." + echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve" + tritonserver --model-repository $model_repo --http-port 18000 & + mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 & wait + # Test using curl + # curl http://localhost:8000/v1/chat/completions \ + # -H "Content-Type: application/json" \ + # -d '{ + # "model": "trt_engines_bfloat16", + # "messages":[{"role": "user", "content": "Where is New York?"}, + # {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], + # "max_tokens": 512, + # "temperature": 0.8, + # "top_p": 0.95, + # "top_k": 50, + # "stop": ["<|eos1|>"], + # "repetition_penalty": 1.2, + # "stream": false + # }' fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - echo "Single request test http, only work for offline TTS mode" - python3 client_http.py \ - --reference-audio ./assets/prompt_audio.wav \ - --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ - --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ - --model-name cosyvoice2 + echo "Running benchmark client" + num_task=4 + mode=streaming + BLS_INSTANCE_NUM=$bls_instance_num + + python3 client_grpc.py \ + --server-addr localhost \ + --server-port 8001 \ + --model-name cosyvoice2_dit \ + --num-tasks $num_task \ + --mode $mode \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} + fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - echo "Running benchmark client grpc on $N_GPUS GPUs" - num_task=1 + echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script" - mode=streaming - BLS_INSTANCE_NUM=4 - - 
for i in $(seq 0 $(($N_GPUS - 1))); do - grpc_port=$((18000 + $i)) - echo "Running client for server on localhost:$grpc_port" - python3 client_grpc.py \ - --server-addr localhost \ - --server-port $grpc_port \ - --model-name cosyvoice2_dit \ - --num-tasks $num_task \ - --mode $mode \ - --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_gpu${i} & - done - wait -fi -if [ $stage -le 50 ] && [ $stop_stage -ge 50 ]; then - echo "Running benchmark client grpc on $N_GPUS GPUs" - num_task=4 - N_GPUS=1 - mode=streaming - BLS_INSTANCE_NUM=4 - - for i in $(seq 0 $(($N_GPUS - 1))); do - grpc_port=$((18000 + $i)) - echo "Running client for server on localhost:$grpc_port" - python3 client_grpc.py \ - --server-addr localhost \ - --server-port $grpc_port \ - --model-name cosyvoice2_dit \ - --num-tasks $num_task \ - --mode $mode \ - --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_single_card_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} & - done - wait -fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - echo "stage 6: Offline inference benchmark" - n_gpus=1 datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh - backend=trtllm-serve # hf, trtllm, vllm + backend=trtllm # hf, trtllm, vllm, trtllm-serve - batch_sizes=(16 8 4 2 1) - batch_sizes=(16 8 4 2) + batch_sizes=(16) token2wav_batch_size=1 + for batch_size in ${batch_sizes[@]}; do for dataset in ${datasets[@]}; do output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} @@ -225,7 +154,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then python3 offline_inference.py \ --output-dir $output_dir \ --llm-model-name-or-path $huggingface_model_local_dir \ - --token2wav-path $model_scope_model_local_dir \ + --token2wav-path $step_audio_model_dir/token2wav \ --backend $backend \ --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ --engine-dir $trt_engines_dir \ @@ -234,34 +163,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then done fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - - CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential - - +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script" + export CUDA_VISIBLE_DEVICES=1 + # Note: Using pre-computed cosyvoice2 tokens + python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential + # Offline Token2wav inference + # python3 token2wav_dit.py --enable-trt fi -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 - -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - #! 
/usr/bin/env bash - curl http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "trt_engines_bfloat16", - "messages":[{"role": "user", "content": "Where is New York?"}, - {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], - "max_tokens": 512, - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "stop": ["<|eos1|>"], - "repetition_penalty": 1.2, - "stream": false - }' -fi \ No newline at end of file diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index 93c6758..026feb5 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -54,7 +54,7 @@ if __name__ == "__main__": token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) flow_pre_lookahead_len = 3 - CHUNK_SIZE = 15 + CHUNK_SIZE = 25 token_frame_rate = 25 OVERLAP_SIZE = 0 @@ -67,20 +67,12 @@ if __name__ == "__main__": ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] - # if id != "unseen3_text5": - # continue - # else: - # a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt") - # generated_speech_tokens = a["semantic_token_ids_arr"] - # print(generated_speech_tokens) + assert prompt_audio_sample_rate == 16000 prompt_text = prompt_text_list[0] prompt_speech_tokens = prompt_speech_tokens_list[0] - - # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) - semantic_token_ids_arr, token_offset = [], 0 flow_prompt_speech_token_len = len(prompt_speech_tokens) @@ -114,14 +106,16 @@ if __name__ == "__main__": audios = output_wavs reconstructed_audio = np.concatenate(audios) - # Save reconstructed audio sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") - - print(f"Saved {id}") end_time = time.time() if _ == 0: token2wav_model.speaker_cache = {} - print(f"Warmup time: {end_time - start_time} seconds") - print(f"Total forward count: {total_forward_count}") + print(f"Warmup time: {end_time - start_time} seconds") + print("clear speaker cache") + elif _ == 1: + print(f"Cost time without speaker cache: {end_time - start_time} seconds") + else: + print(f"Cost time with speaker cache: {end_time - start_time} seconds") + print(f"Total flow matching forward calls: {total_forward_count}") \ No newline at end of file From a019a2504ea4dce4b5c27ee13b3368796b0a1eb0 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Oct 2025 16:48:00 +0800 Subject: [PATCH 08/15] clean code --- runtime/triton_trtllm/client_grpc.py | 142 ++++-------------- .../model_repo/cosyvoice2_dit/1/model.py | 66 ++------ .../model_repo/token2wav_dit/1/model.py | 14 +- .../token2wav_dit/1/token2wav_dit.py | 10 -- runtime/triton_trtllm/streaming_inference.py | 7 - 5 files changed, 46 insertions(+), 193 deletions(-) diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 7aa8d7d..718fe86 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -43,9 +43,9 @@ python3 client_grpc.py \ import argparse import asyncio import json -import queue # Added -import uuid # Added -import functools # Added +import queue +import uuid +import functools import os import time @@ 
-55,13 +55,11 @@ from pathlib import Path import numpy as np import soundfile as sf import tritonclient -import tritonclient.grpc.aio as grpcclient_aio # Renamed original import -import tritonclient.grpc as grpcclient_sync # Added sync client import -from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException +import tritonclient.grpc.aio as grpcclient_aio +import tritonclient.grpc as grpcclient_sync +from tritonclient.utils import np_to_triton_dtype, InferenceServerException -from datetime import datetime -# --- Added UserData and callback --- class UserData: def __init__(self): self._completed_requests = queue.Queue() @@ -86,7 +84,7 @@ class UserData: def callback(user_data, result, error): if not error: if user_data._first_chunk_time is None: - user_data._first_chunk_time = time.time() # Record time of first successful chunk + user_data._first_chunk_time = time.time() elif user_data._second_chunk_time is None: user_data._second_chunk_time = time.time() @@ -99,10 +97,6 @@ def callback(user_data, result, error): def stream_callback(user_data_map, result, error): request_id = None if error: - # Note: InferenceServerException doesn't have a public request_id() method in all versions. - # This part might need adjustment depending on the tritonclient library version. - # A more robust way would be to wrap the error with the request_id if possible. - # For now, we assume we can't get request_id from error and it will timeout on the client side. print(f"An error occurred in the stream callback: {error}") else: request_id = result.get_response().id @@ -115,31 +109,9 @@ def stream_callback(user_data_map, result, error): print(f"Warning: Could not find user_data for request_id {request_id}") -# --- End Added UserData and callback --- - - def write_triton_stats(stats, summary_file): with open(summary_file, "w") as summary_f: model_stats = stats["model_stats"] - # write a note, the log is from triton_client.get_inference_statistics(), to better human readability - summary_f.write( - "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n" - ) - summary_f.write("To learn more about the log, please refer to: \n") - summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") - summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") - summary_f.write( - "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" - ) - summary_f.write( - "However, there is a trade-off between the increased queue time and the increased batch size. \n" - ) - summary_f.write( - "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" - ) - summary_f.write( - "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. 
\n\n" - ) for model_state in model_stats: if "last_inference" not in model_state: continue @@ -150,7 +122,7 @@ def write_triton_stats(stats, summary_file): total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 summary_f.write( - f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa + f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" ) model_batch_stats = model_state["batch_stats"] for batch in model_batch_stats: @@ -164,19 +136,18 @@ def write_triton_stats(stats, summary_file): compute_input_time_ms = int(compute_input["ns"]) / 1e6 compute_output_time_ms = int(compute_output["ns"]) / 1e6 summary_f.write( - f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" ) summary_f.write( - f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa + f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " ) summary_f.write( - f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa + f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" ) def subtract_stats(stats_after, stats_before): """Subtracts two Triton inference statistics objects.""" - # Deep copy to avoid modifying the original stats_after stats_diff = json.loads(json.dumps(stats_after)) model_stats_before_map = { @@ -196,7 +167,6 @@ def subtract_stats(stats_after, stats_before): if model_name in model_stats_before_map: model_stat_before = model_stats_before_map[model_name] - # Subtract counts model_stat_after["inference_count"] = str( int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0)) ) @@ -204,7 +174,6 @@ def subtract_stats(stats_after, stats_before): int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0)) ) - # Subtract aggregate stats (like queue, compute times) if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before: for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]: if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]: @@ -217,7 +186,6 @@ def subtract_stats(stats_after, stats_before): count_before = int(model_stat_before["inference_stats"][key]["count"]) model_stat_after["inference_stats"][key]["count"] = str(count_after - 
count_before) - # Subtract batch execution stats if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before: batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]} for batch_stat_after in model_stat_after["batch_stats"]: @@ -338,7 +306,6 @@ def get_args(): help="log directory", ) - # --- Added arguments --- parser.add_argument( "--mode", type=str, @@ -379,39 +346,33 @@ def load_audio(wav_path, target_sample_rate=16000): def prepare_request_input_output( - protocol_client, # Can be grpcclient_aio or grpcclient_sync + protocol_client, waveform, reference_text, target_text, sample_rate=16000, - padding_duration: int = None, # Optional padding for offline mode + padding_duration: int = None, use_spk2info_cache: bool = False ): """Prepares inputs for Triton inference (offline or streaming).""" assert len(waveform.shape) == 1, "waveform should be 1D" lengths = np.array([[len(waveform)]], dtype=np.int32) - # Apply padding only if padding_duration is provided (for offline) if padding_duration: duration = len(waveform) / sample_rate - # Estimate target duration based on text length ratio (crude estimation) - # Avoid division by zero if reference_text is empty if reference_text: estimated_target_duration = duration / len(reference_text) * len(target_text) else: - estimated_target_duration = duration # Assume target duration similar to reference if no text + estimated_target_duration = duration - # Calculate required samples based on estimated total duration required_total_samples = padding_duration * sample_rate * ( (int(estimated_target_duration + duration) // padding_duration) + 1 ) samples = np.zeros((1, required_total_samples), dtype=np.float32) samples[0, : len(waveform)] = waveform else: - # No padding for streaming or if padding_duration is None samples = waveform.reshape(1, -1).astype(np.float32) - # Common input creation logic inputs = [ protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), protocol_client.InferInput( @@ -450,12 +411,8 @@ def run_sync_streaming_inference( ): """Helper function to run the blocking sync streaming call.""" start_time_total = time.time() - user_data.record_start_time() # Record start time for first chunk latency calculation - # e.g. 
08:47:34.827758 + user_data.record_start_time() - print(f"Record start time in human readable: {datetime.now()}") - # input() - # Send request sync_triton_client.async_stream_infer( model_name, inputs, @@ -464,30 +421,26 @@ def run_sync_streaming_inference( enable_empty_final_response=True, ) - # Process results audios = [] while True: try: - result = user_data._completed_requests.get(timeout=20) # Add timeout + result = user_data._completed_requests.get(timeout=20) if isinstance(result, InferenceServerException): print(f"Received InferenceServerException: {result}") - # Don't stop the stream here, just return error return None, None, None, None - # Get response metadata response = result.get_response() final = response.parameters["triton_final_response"].bool_param if final is True: break audio_chunk = result.as_numpy("waveform").reshape(-1) - if audio_chunk.size > 0: # Only append non-empty chunks + if audio_chunk.size > 0: audios.append(audio_chunk) else: print("Warning: received empty audio chunk.") except queue.Empty: print(f"Timeout waiting for response for request id {request_id}") - # Don't stop stream here, just return error return None, None, None, None end_time_total = time.time() @@ -495,47 +448,36 @@ def run_sync_streaming_inference( first_chunk_latency = user_data.get_first_chunk_latency() second_chunk_latency = user_data.get_second_chunk_latency() - # Reconstruct audio using cross-fade (from client_grpc_streaming.py) - actual_duration = 0 if audios: - # Only spark_tts model uses cross-fade if model_name == "spark_tts": cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) fade_out = np.linspace(1, 0, cross_fade_samples) fade_in = np.linspace(0, 1, cross_fade_samples) reconstructed_audio = None - # Simplified reconstruction based on client_grpc_streaming.py if not audios: print("Warning: No audio chunks received.") - reconstructed_audio = np.array([], dtype=np.float32) # Empty array + reconstructed_audio = np.array([], dtype=np.float32) elif len(audios) == 1: reconstructed_audio = audios[0] else: - reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + reconstructed_audio = audios[0][:-cross_fade_samples] for i in range(1, len(audios)): - # Cross-fade section cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out) - # Middle section of the current chunk middle_part = audios[i][cross_fade_samples:-cross_fade_samples] - # Concatenate reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) - # Add the last part of the final chunk reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) if reconstructed_audio is not None and reconstructed_audio.size > 0: actual_duration = len(reconstructed_audio) / save_sample_rate - # Save reconstructed audio sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") else: print("Warning: No audio chunks received or reconstructed.") - actual_duration = 0 # Set duration to 0 if no audio + actual_duration = 0 else: reconstructed_audio = np.concatenate(audios) - print(f"reconstructed_audio: {reconstructed_audio.shape}") actual_duration = len(reconstructed_audio) / save_sample_rate - # Save reconstructed audio sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") else: @@ -548,7 +490,7 @@ def run_sync_streaming_inference( async def send_streaming( manifest_item_list: list, name: str, - server_url: str, # Changed from 
sync_triton_client + server_url: str, protocol_client: types.ModuleType, log_interval: int, model_name: str, @@ -561,12 +503,12 @@ async def send_streaming( total_duration = 0.0 latency_data = [] task_id = int(name[5:]) - sync_triton_client = None # Initialize client variable + sync_triton_client = None user_data_map = {} - try: # Wrap in try...finally to ensure client closing + try: print(f"{name}: Initializing sync client for streaming...") - sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map)) print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") @@ -593,7 +535,6 @@ async def send_streaming( user_data_map[request_id] = user_data audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") - print("target_text: ", target_text, "time: ", datetime.now()) total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread( run_sync_streaming_inference, sync_triton_client, @@ -627,7 +568,7 @@ async def send_streaming( import traceback traceback.print_exc() - finally: # Ensure client is closed + finally: if sync_triton_client: try: print(f"{name}: Closing stream and sync client...") @@ -656,7 +597,6 @@ async def send( latency_data = [] task_id = int(name[5:]) - print(f"manifest_item_list: {manifest_item_list}") for i, item in enumerate(manifest_item_list): if i % log_interval == 0: print(f"{name}: {i}/{len(manifest_item_list)}") @@ -697,7 +637,6 @@ def load_manifests(manifest_path): assert len(line.strip().split("|")) == 4 utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") utt = Path(utt).stem - # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") if not os.path.isabs(prompt_wav): prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) manifest_list.append( @@ -738,23 +677,17 @@ async def main(): args = get_args() url = f"{args.server_addr}:{args.server_port}" - # --- Client Initialization based on mode --- triton_client = None protocol_client = None if args.mode == "offline": print("Initializing gRPC client for offline mode...") - # Use the async client for offline tasks triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) protocol_client = grpcclient_aio elif args.mode == "streaming": print("Initializing gRPC client for streaming mode...") - # Use the sync client for streaming tasks, handled via asyncio.to_thread - # We will create one sync client instance PER TASK inside send_streaming. 
- # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now - protocol_client = grpcclient_sync # protocol client for input prep + protocol_client = grpcclient_sync else: raise ValueError(f"Invalid mode: {args.mode}") - # --- End Client Initialization --- if args.reference_audio: args.num_tasks = 1 @@ -776,24 +709,18 @@ async def main(): trust_remote_code=True, ) manifest_item_list = [] - tmp_audio_path="./asset_zero_shot_prompt.wav" - tmp_audio_text="希望你以后能够做的比我还好呦。" for i in range(len(dataset)): manifest_item_list.append( { "audio_filepath": dataset[i]["prompt_audio"], "reference_text": dataset[i]["prompt_text"], - # "audio_filepath": tmp_audio_path, - # "reference_text": tmp_audio_text, "target_audio_path": dataset[i]["id"], "target_text": dataset[i]["target_text"], } ) - # manifest_item_list = manifest_item_list[:4] else: manifest_item_list = load_manifests(args.manifest_path) - # --- Statistics Fetching (Before) --- stats_client = None stats_before = None try: @@ -803,7 +730,6 @@ async def main(): stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True) except Exception as e: print(f"Could not retrieve statistics before running tasks: {e}") - # --- End Statistics Fetching (Before) --- num_tasks = min(args.num_tasks, len(manifest_item_list)) manifest_item_list = split_data(manifest_item_list, num_tasks) @@ -813,7 +739,6 @@ async def main(): tasks = [] start_time = time.time() for i in range(num_tasks): - # --- Task Creation based on mode --- if args.mode == "offline": task = asyncio.create_task( send( @@ -834,7 +759,7 @@ async def main(): send_streaming( manifest_item_list[i], name=f"task-{i}", - server_url=url, # Pass URL instead of client + server_url=url, protocol_client=protocol_client, log_interval=args.log_interval, model_name=args.model_name, @@ -845,7 +770,6 @@ async def main(): use_spk2info_cache=args.use_spk2info_cache, ) ) - # --- End Task Creation --- tasks.append(task) ans_list = await asyncio.gather(*tasks) @@ -858,7 +782,7 @@ async def main(): for ans in ans_list: if ans: total_duration += ans[0] - latency_data.extend(ans[1]) # Use extend for list of lists + latency_data.extend(ans[1]) else: print("Warning: A task returned None, possibly due to an error.") @@ -874,10 +798,8 @@ async def main(): s += f"({total_duration / 3600:.2f} hours)\n" s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" - # --- Statistics Reporting based on mode --- if latency_data: if args.mode == "offline": - # Original offline latency calculation latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] if latency_list: latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 @@ -892,7 +814,6 @@ async def main(): s += "No latency data collected for offline mode.\n" elif args.mode == "streaming": - # Calculate stats for total request latency and first chunk latency total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None] first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None] second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None] @@ -937,7 +858,6 @@ async def main(): s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n" else: s += "No latency data collected.\n" - # --- End Statistics Reporting --- print(s) if args.manifest_path: @@ -947,12 
+867,10 @@ async def main(): elif args.reference_audio: name = Path(args.reference_audio).stem else: - name = "results" # Default name if no manifest/split/audio provided + name = "results" with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: f.write(s) - # --- Statistics Fetching using temporary Async Client --- - # Use a separate async client for fetching stats regardless of mode try: if stats_client and stats_before: print("Fetching inference statistics after running tasks...") @@ -980,11 +898,9 @@ async def main(): await stats_client.close() except Exception as e: print(f"Error closing async stats client: {e}") - # --- End Statistics Fetching --- if __name__ == "__main__": - # asyncio.run(main()) # Use TaskGroup for better exception handling if needed async def run_main(): try: await main() diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index 2f81786..827925c 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -43,7 +43,7 @@ import torchaudio from matcha.utils.audio import mel_spectrogram -from datetime import datetime + ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) @@ -85,9 +85,7 @@ class TritonPythonModel: self.model_config = json.loads(args['model_config']) parameters = self.model_config['parameters'] model_params = {k: v["string_value"] for k, v in parameters.items()} - self.logger.log_info(f"model_params:{model_params}") self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" - # self.dynamic_chunk_strategy = "equal" self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer @@ -103,12 +101,8 @@ class TritonPythonModel: self.flow_pre_lookahead_len = 3 self.token_hop_len = 15 - spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") - if not os.path.exists(spk_info_path): - raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") - spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - self.default_spk_info = spk_info["001"] self.http_client = httpx.AsyncClient() + self.api_base = "http://localhost:8000/v1/chat/completions" def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" @@ -147,12 +141,8 @@ class TritonPythonModel: "stream": True, } - api_base = "http://localhost:8000/v1/chat/completions" - buffer = "" - async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: - print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") - print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") + async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): @@ -164,7 +154,6 @@ class TritonPythonModel: content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") if content: buffer += content - print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: match = re.search(r"<\|s_(\d+)\|>", buffer) if not match: @@ -307,40 +296,24 @@ class TritonPythonModel: wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") # Process reference audio 
through audio tokenizer - if wav is not None: - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - # reference_text = self.default_spk_info["prompt_text"] - # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - # prompt_speech_feat = None - # prompt_spk_embedding = None - - else: - # using pre-cached reference text - assert False, "using pre-cached reference text is not supported" - reference_text = self.default_spk_info["prompt_text"] - prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - prompt_speech_feat = None - prompt_spk_embedding = None + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') - print(f"target_text: {target_text}, time: {datetime.now()}") if self.decoupled: response_sender = request.get_response_sender() @@ -349,7 +322,6 @@ class TritonPythonModel: token_offset, chunk_index = 0, 0 start_time = time.time() this_token_hop_len = self.token_hop_len - print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") async for generated_ids in self.forward_llm_async( target_text=target_text, reference_text=reference_text, @@ -358,24 +330,20 @@ class TritonPythonModel: if not generated_ids: break semantic_token_ids_arr.append(generated_ids) - print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: pending_num = len(semantic_token_ids_arr) - token_offset if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = 
torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") sub_tts_speech = await self.forward_token2wav( chunk_index, this_tts_speech_token, request_id, wav, wav_len, False ) - print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) token_offset += this_token_hop_len - self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") if self.dynamic_chunk_strategy == "exponential": this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) @@ -389,7 +357,6 @@ class TritonPythonModel: avg_chunk_processing_time = cost_time / (chunk_index + 1) if avg_chunk_processing_time > 0: multiples = (duration - cost_time) / avg_chunk_processing_time - self.logger.log_info(f"multiples: {multiples}") next_pending_num = len(semantic_token_ids_arr) - token_offset if multiples > 4: this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len @@ -409,9 +376,8 @@ class TritonPythonModel: response_sender.send(inference_response) response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info("send tritonserver_response_complete_final to end") else: - raise NotImplementedError("Decoupled mode is not supported") + raise NotImplementedError("Offline TTS mode is not supported") async def execute(self, requests): """Execute inference on the batched requests. diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 230bad0..1f6b591 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -106,13 +106,10 @@ class TritonPythonModel: # Process each request in batch for request in requests: target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() - target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) - # shift the speech tokens according to the original vocab size + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor) target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens.squeeze().tolist() - # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. 
- finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() request_id = request.request_id() @@ -124,23 +121,14 @@ class TritonPythonModel: request, "reference_wav_len").as_numpy().item() wav_array = torch.from_numpy(wav_array) - # Prepare inputs wav = wav_array[:, :wav_len].squeeze(0) spk_id = get_spk_id_from_prompt_audio(wav) - # wav = wav.to(self.device) - - # update cache before forward - # self.token2wav_model.streaming_flow_cache[request_id] - # self.token2wav_model.hift_cache_dict[request_id] audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) - # get the cache after forward outputs = [] - generated_wave = audio_hat.squeeze(0).cpu().numpy() - wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) outputs.append(wav_tensor) inference_response = pb_utils.InferenceResponse(output_tensors=outputs) diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index 63dce14..bda4cb1 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -320,7 +320,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): def forward( self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] ): - # assert all item in prompt_audios_sample_rate is 16000 assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) @@ -335,7 +334,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): def prepare_prompt_audio( self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] ): - # assert all item in prompt_audios_sample_rate is 16000 assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) @@ -385,7 +383,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} - print(f"speaker_id {speaker_id} added to cache") if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} @@ -394,12 +391,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): source = torch.zeros(1, 1, 0, device='cuda'), speech = torch.zeros(1, 0, device='cuda'), ) - # else: - # for k, v in self.streaming_flow_cache[request_id].items(): - # print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}") - # for k, v in self.hift_cache_dict[request_id].items(): - # print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}") - # breakpoint() current_request_cache = self.streaming_flow_cache[request_id] @@ -477,7 +468,6 @@ def get_args(): if __name__ == "__main__": args = get_args() model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) - # mkdir output_dir if not exists if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) dataset_name = "yuekai/seed_tts_cosy2" diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index 026feb5..a5404e2 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -35,12 +35,6 @@ def 
get_args(): return parser.parse_args() -def fake_generated_id_iter(generated_speech_tokens_list): - for i in range(len(generated_speech_tokens_list)): - yield generated_speech_tokens_list[i] - - - if __name__ == "__main__": args = get_args() @@ -53,7 +47,6 @@ if __name__ == "__main__": token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) - flow_pre_lookahead_len = 3 CHUNK_SIZE = 25 token_frame_rate = 25 OVERLAP_SIZE = 0 From 7cbd4902534177e4eb1369fbaa42d465d6eb9737 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Oct 2025 17:20:04 +0800 Subject: [PATCH 09/15] add docker compose for streaming tts --- runtime/triton_trtllm/docker-compose.dit.yml | 20 +++++++++++++++++++ .../run_stepaudio2_dit_token2wav.sh | 7 +++---- 2 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 runtime/triton_trtllm/docker-compose.dit.yml diff --git a/runtime/triton_trtllm/docker-compose.dit.yml b/runtime/triton_trtllm/docker-compose.dit.yml new file mode 100644 index 0000000..1f97f7c --- /dev/null +++ b/runtime/triton_trtllm/docker-compose.dit.yml @@ -0,0 +1,20 @@ +services: + tts: + image: soar97/triton-cosyvoice:25.06 + shm_size: '1gb' + ports: + - "8000:8000" + - "8001:8001" + - "8002:8002" + environment: + - PYTHONIOENCODING=utf-8 + - MODEL_ID=${MODEL_ID} + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0'] + capabilities: [gpu] + command: > + /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3" \ No newline at end of file diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index 463e490..c401793 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,9 +1,8 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) export CUDA_VISIBLE_DEVICES=0 -# cosyvoice_path=/workspace/CosyVoice -cosyvoice_path=/workspace_yuekai/tts/CosyVoice -stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 +cosyvoice_path=/workspace/CosyVoice +stepaudio2_path=/workspace/Step-Audio2 export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH @@ -89,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then LLM_TOKENIZER_DIR=$huggingface_model_local_dir BLS_INSTANCE_NUM=$bls_instance_num TRITON_MAX_BATCH_SIZE=1 - DECOUPLED_MODE=True + DECOUPLED_MODE=True # Only streaming TTS mode is supported using Nvidia Triton for now STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} From aceede59ba5fe97b65a4cb7f36b76f19de29b4f9 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Oct 2025 18:13:09 +0800 Subject: [PATCH 10/15] fix bug --- runtime/triton_trtllm/client_grpc.py | 2 +- runtime/triton_trtllm/docker-compose.dit.yml | 2 +- .../model_repo/cosyvoice2_dit/1/model.py | 32 ++++++++----------- .../token2wav_dit/1/token2wav_dit.py | 7 ++-- .../run_stepaudio2_dit_token2wav.sh | 6 ++-- 5 files changed, 20 insertions(+), 29 deletions(-) diff --git a/runtime/triton_trtllm/client_grpc.py 
b/runtime/triton_trtllm/client_grpc.py index 718fe86..840390d 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -424,7 +424,7 @@ def run_sync_streaming_inference( audios = [] while True: try: - result = user_data._completed_requests.get(timeout=20) + result = user_data._completed_requests.get(timeout=200) if isinstance(result, InferenceServerException): print(f"Received InferenceServerException: {result}") return None, None, None, None diff --git a/runtime/triton_trtllm/docker-compose.dit.yml b/runtime/triton_trtllm/docker-compose.dit.yml index 1f97f7c..35312a1 100644 --- a/runtime/triton_trtllm/docker-compose.dit.yml +++ b/runtime/triton_trtllm/docker-compose.dit.yml @@ -17,4 +17,4 @@ services: device_ids: ['0'] capabilities: [gpu] command: > - /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3" \ No newline at end of file + /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3" \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index 827925c..523a5b8 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -103,6 +103,7 @@ class TritonPythonModel: self.http_client = httpx.AsyncClient() self.api_base = "http://localhost:8000/v1/chat/completions" + self.speaker_cache = {} def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" @@ -240,10 +241,12 @@ class TritonPythonModel: """Forward pass through the vocoder component. 
Args: - prompt_speech_tokens: Prompt speech tokens tensor - prompt_speech_feat: Prompt speech feat tensor - prompt_spk_embedding: Prompt spk embedding tensor + index: Index of the request target_speech_tokens: Target speech tokens tensor + request_id: Request ID + reference_wav: Reference waveform tensor + reference_wav_len: Reference waveform length tensor + finalize: Whether to finalize the request Returns: Generated waveform tensor @@ -292,26 +295,17 @@ class TritonPythonModel: async def _process_request(self, request): request_id = request.request_id() - # Extract input tensors - wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") - - # Process reference audio through audio tokenizer - - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + + if reference_text not in self.speaker_cache: + self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0) + prompt_speech_tokens = self.speaker_cache[reference_text] + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index bda4cb1..3d50325 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -57,10 +57,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB if dtype == torch.float16: config.set_flag(trt.BuilderFlag.FP16) - elif dtype == torch.bfloat16: - config.set_flag(trt.BuilderFlag.BF16) - elif dtype == torch.float32: - config.set_flag(trt.BuilderFlag.FP32) + profile = builder.create_optimization_profile() # load onnx model with open(onnx_model, "rb") as f: @@ -199,7 +196,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: trt_kwargs = self.get_spk_trt_kwargs() - convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32) import tensorrt as trt with open(spk_model, 'rb') as f: spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index c401793..5881b44 100644 --- 
a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
+++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
@@ -42,7 +42,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
 
     echo "Step-Audio2-mini"
     huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
-    cd $stepaudio2_path/token2wav
+    cd $step_audio_model_dir/token2wav
     wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
     wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
     cd -
@@ -100,8 +100,8 @@ fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
-    tritonserver --model-repository $model_repo --http-port 18000 &
     mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 &
+    tritonserver --model-repository $model_repo --http-port 18000 &
     wait
     # Test using curl
@@ -168,7 +168,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     # Note: Using pre-computed cosyvoice2 tokens
     python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
     # Offline Token2wav inference
-    # python3 token2wav_dit.py --enable-trt
+    python3 token2wav_dit.py --enable-trt
 fi

From 807bb6ee0b0fa1270566f1b96ab85208a21dde96 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 8 Oct 2025 22:01:24 +0800
Subject: [PATCH 11/15] add dit results

---
 runtime/triton_trtllm/README.DIT.md    | 106 +++++++++++++++++++++++++
 runtime/triton_trtllm/token2wav_dit.py |   1 +
 2 files changed, 107 insertions(+)
 create mode 100644 runtime/triton_trtllm/README.DIT.md
 create mode 120000 runtime/triton_trtllm/token2wav_dit.py

diff --git a/runtime/triton_trtllm/README.DIT.md b/runtime/triton_trtllm/README.DIT.md
new file mode 100644
index 0000000..3c130b3
--- /dev/null
+++ b/runtime/triton_trtllm/README.DIT.md
@@ -0,0 +1,106 @@
+## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
+
+Contributed by Yuekai Zhang (NVIDIA).
+
+This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
+
+### Quick Start
+
+Launch the service directly with Docker Compose:
+```sh
+docker compose -f docker-compose.dit.yml up
+```
+
+### Build the Docker Image
+
+To build the image from scratch:
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
+```
+
+### Run a Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
+```
+
+### Understanding `run_stepaudio2_dit_token2wav.sh`
+
+The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
+
+You can run a subset of stages with:
+```sh
+bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
+```
+- `<start_stage>`: The stage to start from.
+- `<stop_stage>`: The stage to stop after.
+
+**Stages:**
+
+- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
+- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
+- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines. +- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`. +- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM. +- **Stage 4**: Runs the gRPC benchmark client for performance testing. +- **Stage 5**: Runs the offline TTS inference benchmark test. +- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model. + +### Export Models and Launch Server + +Inside the Docker container, prepare the models and start the Triton server by running stages 0-3: +```sh +# This command runs stages 0, 1, 2, and 3 +bash run_stepaudio2_dit_token2wav.sh 0 3 +``` + +### Benchmark with client-server mode + +To benchmark the running Triton server, run stage 4: +```sh +bash run_stepaudio2_dit_token2wav.sh 4 4 + +# You can customize parameters such as the number of tasks inside the script. +``` +The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset. + +#### Total Request Latency + +| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) | +| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- | +| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 | +| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 | +| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 | +| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 | +| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 | +| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 | + +#### First Chunk Latency + +| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) | +| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- | +| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 | +| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 | +| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 | +| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 | +| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 | +| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 | + +### Benchmark with offline inference mode +For offline inference mode benchmark, please run stage 5: +```sh +bash run_stepaudio2_dit_token2wav.sh 5 5 +``` + +The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset. + +#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav) +| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF | +|---------|------------|------------------|-----------------------|--| +| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 | + + + +### Acknowledgements + +This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub). 
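(Not part of the patch above.) For readers who want to call the Step-Audio2 DiT Token2Wav module directly, rather than through the stage 6 benchmark entry point `token2wav_dit.py`, a minimal offline sketch is shown below. The class name and `forward` signature come from `model_repo/token2wav_dit/1/token2wav_dit.py`; the model directory, prompt path, and speech-token IDs are illustrative placeholders (the IDs are the ones used in the run script's curl example), and the prompt audio is assumed to be 16 kHz mono.

```python
# Minimal sketch: offline use of CosyVoice2_Token2Wav from token2wav_dit.py.
# Paths and token IDs are placeholders; prompt audio must be 16 kHz mono.
import torch
import torchaudio

from token2wav_dit import CosyVoice2_Token2Wav

# model_dir should point at the downloaded Step-Audio-2-mini token2wav assets (stage 0)
model = CosyVoice2_Token2Wav(model_dir="Step-Audio-2-mini/token2wav", enable_trt=True)

prompt_wav, sr = torchaudio.load("prompt.wav")   # [channels, samples]
assert sr == 16000, "the module expects 16 kHz prompt audio"

# 25 Hz speech tokens produced by the CosyVoice2 LLM (illustrative values)
generated_tokens = [1708, 2050, 2159]

# Batched call: lists of token sequences, prompt waveforms, and sample rates
wavs = model([generated_tokens], [prompt_wav.squeeze(0)], [sr])

# The Token2Wav module generates 24 kHz audio
torchaudio.save("output.wav", wavs[0].cpu(), 24000)
```

The same module is what the `token2wav_dit` Triton model wraps; the only difference in the streaming path is that `forward_streaming` is called chunk by chunk with per-request caches.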
diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/token2wav_dit.py new file mode 120000 index 0000000..2bd78a5 --- /dev/null +++ b/runtime/triton_trtllm/token2wav_dit.py @@ -0,0 +1 @@ +model_repo/token2wav_dit/1/token2wav_dit.py \ No newline at end of file From 8811e9f33a5e7a14ad308f821b967f394e72bdcc Mon Sep 17 00:00:00 2001 From: root Date: Thu, 9 Oct 2025 14:49:22 +0800 Subject: [PATCH 12/15] fix white space --- examples/grpo/cosyvoice2/README.md | 4 +-- examples/grpo/cosyvoice2/run.sh | 6 ++-- .../model_repo/cosyvoice2_dit/1/model.py | 4 +-- .../model_repo/token2wav_dit/1/model.py | 6 ++-- .../token2wav_dit/1/token2wav_dit.py | 29 ++++++++----------- runtime/triton_trtllm/offline_inference.py | 1 - runtime/triton_trtllm/streaming_inference.py | 10 +++---- 7 files changed, 27 insertions(+), 33 deletions(-) diff --git a/examples/grpo/cosyvoice2/README.md b/examples/grpo/cosyvoice2/README.md index 8783aa1..1f5c6a0 100644 --- a/examples/grpo/cosyvoice2/README.md +++ b/examples/grpo/cosyvoice2/README.md @@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL: ```bash bash run.sh 0 0 ``` -Create two JSONL files—`train.jsonl` and `test.jsonl`. +Create two JSONL files—`train.jsonl` and `test.jsonl`. The script will then generate two Parquet files: ``` @@ -111,7 +111,7 @@ bash run.sh 5 5 The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository. > [!TIP] -> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. +> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. ## Results diff --git a/examples/grpo/cosyvoice2/run.sh b/examples/grpo/cosyvoice2/run.sh index ce97ab3..b1658e2 100644 --- a/examples/grpo/cosyvoice2/run.sh +++ b/examples/grpo/cosyvoice2/run.sh @@ -33,7 +33,7 @@ fi if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint" - modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path python3 pretrained_to_huggingface.py \ --pretrained-cosyvoice2-path $model_scope_model_path \ --save-path $sft_model_path @@ -61,7 +61,7 @@ fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "stage 1: start token2wav asr server for reward function" python3 token2wav_asr_server.py --number-of-devices 8 -fi +fi exp_name=official_llm_aishell3_grpo if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then --backend fsdp \ --local_dir $llm_path/actor \ --target_dir $llm_path/merged_hf_model || exit 1 -fi +fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "stage 4: Test the model" diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index 523a5b8..8e2b28b 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -254,7 +254,7 @@ class TritonPythonModel: target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, 
reference_wav, reference_wav_len, finalize_tensor] - + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', @@ -362,7 +362,7 @@ class TritonPythonModel: chunk_index += 1 else: break - + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 1f6b591..1f90644 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -62,7 +62,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str: # Create a SHA-256 hash of the byte string hasher = hashlib.sha256() hasher.update(tensor_bytes) - + return hasher.hexdigest() class TritonPythonModel: @@ -111,9 +111,9 @@ class TritonPythonModel: target_speech_tokens = target_speech_tokens.squeeze().tolist() finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - + request_id = request.request_id() - + wav_array = pb_utils.get_input_tensor_by_name( request, "reference_wav").as_numpy() diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index 3d50325..d413003 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -133,7 +133,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): option.intra_op_num_threads = 1 self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"]) - self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() gpu="l20" @@ -253,7 +252,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() prompt_speech_tokens_list.append(speech_tokens_i) return prompt_speech_tokens_list - + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: spk_emb_for_flow = [] for audio in prompt_audios_list: @@ -263,11 +262,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module): spk_emb = self.forward_spk_embedding(spk_feat) spk_emb_for_flow.append(spk_emb) - spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) if self.dtype != torch.float32: spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) return spk_emb_for_flow - + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): prompt_mels_for_flow = [] prompt_mels_lens_for_flow = [] @@ -283,7 +282,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) return prompt_mels_for_flow, prompt_mels_lens_for_flow - + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): batch_size = 
prompt_mels_for_flow.shape[0] flow_inputs = [] @@ -318,28 +317,24 @@ class CosyVoice2_Token2Wav(torch.nn.Module): self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] ): assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) - return generated_wavs def prepare_prompt_audio( self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] ): assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) - prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) - return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow @@ -365,7 +360,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): @torch.inference_mode() def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 - ): + ): if speaker_id not in self.speaker_cache: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -384,7 +379,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( - mel = torch.zeros(1, 80, 0, device='cuda'), + mel = torch.zeros(1, 80, 0, device='cuda'), source = torch.zeros(1, 1, 0, device='cuda'), speech = torch.zeros(1, 0, device='cuda'), ) @@ -445,7 +440,7 @@ def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] for i, item in enumerate(batch): generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) - audio = torch.from_numpy(item['prompt_audio']['array']).float() + audio = torch.from_numpy(item['prompt_audio']['array']).float() prompt_audios_list.append(audio) prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) ids.append(item['id']) @@ -473,20 +468,20 @@ if __name__ == "__main__": data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) - - + + for epoch in range(args.warmup): start_time = time.time() - + for batch in data_loader: ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) - + for id, wav in zip(ids, generated_wavs): torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) - + end_time = time.time() epoch_time = end_time - start_time print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file diff --git a/runtime/triton_trtllm/offline_inference.py 
b/runtime/triton_trtllm/offline_inference.py index d309d18..77f2915 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -365,7 +365,6 @@ def main(args): runner = None else: raise ValueError(f"Unsupported backend: {args.backend}") - if 'Step-Audio-2-mini' in args.token2wav_path: from token2wav_dit import CosyVoice2_Token2Wav else: diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index a5404e2..e9c2ebb 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -14,7 +14,7 @@ def collate_fn(batch): prompt_speech_tokens_list, prompt_text_list = [], [] for i, item in enumerate(batch): generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) - audio = torch.from_numpy(item['prompt_audio']['array']).float() + audio = torch.from_numpy(item['prompt_audio']['array']).float() prompt_audios_list.append(audio) prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) ids.append(item['id']) @@ -37,7 +37,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - + if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) @@ -46,7 +46,7 @@ if __name__ == "__main__": data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) - + CHUNK_SIZE = 25 token_frame_rate = 25 OVERLAP_SIZE = 0 @@ -68,7 +68,7 @@ if __name__ == "__main__": semantic_token_ids_arr, token_offset = [], 0 flow_prompt_speech_token_len = len(prompt_speech_tokens) - + buffer = generated_speech_tokens output_wavs = [] chunk_index = 0 @@ -97,7 +97,7 @@ if __name__ == "__main__": output_wavs[i] = wav.cpu().numpy().squeeze() - audios = output_wavs + audios = output_wavs reconstructed_audio = np.concatenate(audios) sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") From 33aee03ed5b219c1173f2f71250a078739bdddda Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Thu, 9 Oct 2025 15:13:43 +0800 Subject: [PATCH 13/15] fix lint --- examples/grpo/cosyvoice2/infer_dataset.py | 2 +- .../cosyvoice2/pretrained_to_huggingface.py | 2 - .../scripts/offline-decode-files.py | 4 +- .../grpo/cosyvoice2/token2wav_asr_server.py | 2 +- runtime/triton_trtllm/client_grpc.py | 12 +- runtime/triton_trtllm/client_http.py | 1 - .../model_repo/cosyvoice2/1/model.py | 3 - .../model_repo/cosyvoice2_dit/1/model.py | 3 +- .../model_repo/token2wav/1/model.py | 1 - .../model_repo/token2wav_dit/1/model.py | 9 +- .../token2wav_dit/1/token2wav_dit.py | 111 +++++++++++------- runtime/triton_trtllm/offline_inference.py | 1 - runtime/triton_trtllm/scripts/test_llm.py | 5 - runtime/triton_trtllm/streaming_inference.py | 16 ++- 14 files changed, 100 insertions(+), 72 deletions(-) diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py index 4dcbc96..f0d22d7 100644 --- a/examples/grpo/cosyvoice2/infer_dataset.py +++ b/examples/grpo/cosyvoice2/infer_dataset.py @@ -53,7 +53,7 @@ except RuntimeError: pass -TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- 
endfor %}" +TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501 def audio_decode_cosyvoice2( diff --git a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py index 161a11f..7aaa10b 100644 --- a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py +++ b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # diff --git a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py index 847d434..90c4665 100644 --- a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py +++ b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 -# # Copyright (c) 2023 by manyeyes # Copyright (c) 2023 Xiaomi Corporation @@ -195,7 +193,7 @@ def write_error_stats( hyp = list("".join(hyp)) results[i] = (cut_id, ref, hyp) - for cut_id, ref, hyp in results: + for _cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) for ref_word, hyp_word in ali: if ref_word == ERR: diff --git a/examples/grpo/cosyvoice2/token2wav_asr_server.py b/examples/grpo/cosyvoice2/token2wav_asr_server.py index 8a6cb6e..9f9f80b 100644 --- a/examples/grpo/cosyvoice2/token2wav_asr_server.py +++ b/examples/grpo/cosyvoice2/token2wav_asr_server.py @@ -295,7 +295,7 @@ def main(): metrics_port=8002, ) - device_ids = [i for i in range(args.number_of_devices)] + device_ids = list(range(args.number_of_devices)) device_ids = device_ids * args.number_of_instances_per_device with Triton(config=triton_config) as triton: diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 840390d..b344849 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -122,7 +122,10 @@ def write_triton_stats(stats, summary_file): total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 summary_f.write( - f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" + f"queue time {total_queue_time_s:<5.2f} s, " + f"compute infer time {total_infer_time_s:<5.2f} s, " + f"compute input time {total_input_time_s:<5.2f} s, " + f"compute output time {total_output_time_s:<5.2f} s \n" ) model_batch_stats = model_state["batch_stats"] for batch in model_batch_stats: @@ -136,7 +139,12 @@ def write_triton_stats(stats, summary_file): compute_input_time_ms = int(compute_input["ns"]) / 1e6 compute_output_time_ms = int(compute_output["ns"]) / 1e6 summary_f.write( - f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / 
batch_count / batch_size:.2f} ms \n" + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, " + f"total_infer_time {compute_infer_time_ms:<9.2f} ms, " + f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}=" + f"{compute_infer_time_ms / batch_count:.2f} ms, " + f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}=" + f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" ) summary_f.write( f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py index 4d73e0b..2c30a7a 100644 --- a/runtime/triton_trtllm/client_http.py +++ b/runtime/triton_trtllm/client_http.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import requests import soundfile as sf -import json import numpy as np import argparse diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index 97659ad..a2bfb30 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -25,12 +25,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json -import math import os -import re import threading import time -from typing import Dict, List, Tuple, Optional, Union import numpy as np import torch diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index 8e2b28b..2e2533c 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -178,7 +178,6 @@ class TritonPythonModel: yield final_id buffer = buffer[match.end():] - def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. @@ -263,7 +262,7 @@ class TritonPythonModel: ], inputs=inputs_tensor, request_id=request_id, - parameters={"priority": index+1}, + parameters={"priority": index + 1}, ) inference_response = await inference_request.async_exec() diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index 10bc272..1e38052 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -28,7 +28,6 @@ import json import os import logging -from typing import List, Dict import torch from torch.utils.dlpack import to_dlpack diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 1f90644..f9e461e 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -48,9 +48,11 @@ import hashlib logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) + ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) + def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str: """ Generates a unique ID for a torch.Tensor. @@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str: return hasher.hexdigest() + class TritonPythonModel: """Triton Python model for vocoder. 
@@ -114,7 +117,6 @@ class TritonPythonModel: request_id = request.request_id() - wav_array = pb_utils.get_input_tensor_by_name( request, "reference_wav").as_numpy() wav_len = pb_utils.get_input_tensor_by_name( @@ -125,7 +127,10 @@ class TritonPythonModel: spk_id = get_spk_id_from_prompt_audio(wav) - audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) + audio_hat = self.token2wav_model.forward_streaming( + target_speech_tokens, finalize, request_id=request_id, + speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000 + ) outputs = [] diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index d413003..6bce5cc 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -35,7 +35,7 @@ import numpy as np from hyperpyyaml import load_hyperpyyaml -def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): +def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor): """perform fade_in_out in tensor style """ mel_overlap_len = int(window.shape[0] / 2) @@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] return fade_in_mel + def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): import tensorrt as trt logging.info("Converting onnx to trt...") @@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): f.write(engine_bytes) logging.info("Succesfully convert onnx to trt...") + class TrtContextWrapper: def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) @@ -108,6 +110,7 @@ class TrtContextWrapper: def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) + class CosyVoice2_Token2Wav(torch.nn.Module): def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): super().__init__() @@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module): option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 - self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, - providers=["CPUExecutionProvider"]) + self.spk_model = onnxruntime.InferenceSession( + f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() - gpu="l20" + gpu = "l20" if enable_trt: if streaming: - self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', - f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', - 1, - self.dtype, streaming) + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + 1, + self.dtype, streaming + ) else: - self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', - 
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', - 1, - self.dtype) - self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', - f'{model_dir}/campplus.onnx', - 1, - False) - + self.load_trt( + f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype + ) + self.load_spk_trt( + f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False + ) self.streaming_flow_cache = {} self.speaker_cache = {} @@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): opt_batch_size = 2 max_batch_size = 16 if streaming: - opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) del self.flow.decoder.estimator @@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module): def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): if streaming: min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] - opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)] - max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2), + (16, opt_batch_size * 2, 8, 100, 128) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2), + (16, max_batch_size * 2, 8, 1000, 128) + ] input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] else: min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] - opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] - max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + opt_shape = [ + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500), + (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80) + ] + max_shape = [ + (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), + (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80) + ] input_names = ["x", "mask", "mu", "cond", "t", "spks"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} @@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module): mel_len = mel.shape[0] prompt_mels_for_flow.append(mel) prompt_mels_lens_for_flow.append(mel_len) - prompt_mels_for_flow = 
torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence( + prompt_mels_for_flow, batch_first=True, padding_value=0 + ) # [B, T', num_mels=80] prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) return prompt_mels_for_flow, prompt_mels_lens_for_flow - def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], + generated_speech_tokens_list: list[list[int]], + prompt_mels_for_flow: torch.Tensor, + prompt_mels_lens_for_flow: torch.Tensor, + spk_emb_for_flow: torch.Tensor): batch_size = prompt_mels_for_flow.shape[0] flow_inputs = [] flow_inputs_lens = [] @@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): generated_wavs.append(wav) return generated_wavs - @torch.inference_mode() def forward( self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] @@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module): prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) - generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + generated_mels, generated_mels_lens = self.forward_flow( + prompt_speech_tokens_list, generated_speech_tokens_list, + prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ) generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) return generated_wavs @@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow - def get_prompt_audio_cache_for_streaming_tts( self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow ): @@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] return new_cache - @torch.inference_mode() def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 @@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module): if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( - mel = torch.zeros(1, 80, 0, device='cuda'), - source = torch.zeros(1, 1, 0, device='cuda'), - speech = torch.zeros(1, 0, device='cuda'), + mel=torch.zeros(1, 80, 0, device='cuda'), + source=torch.zeros(1, 1, 0, device='cuda'), + speech=torch.zeros(1, 0, device='cuda'), ) current_request_cache = self.streaming_flow_cache[request_id] @@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] generated_speech_tokens = torch.tensor([generated_speech_tokens], 
dtype=torch.int32, device='cuda') - chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( token=generated_speech_tokens, spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device), @@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module): self.streaming_flow_cache[request_id] = new_streaming_flow_cache - if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]], self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], ], dim=4) - - hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() @@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module): # update vocoder cache self.hift_cache_dict[request_id] = dict( - mel = mel[..., -self.mel_cache_len:].clone().detach(), - source = source[:, :, -self.source_cache_len:].clone().detach(), - speech = speech[:, -self.source_cache_len:].clone().detach(), + mel=mel[..., -self.mel_cache_len:].clone().detach(), + source=source[:, :, -self.source_cache_len:].clone().detach(), + speech=speech[:, -self.source_cache_len:].clone().detach(), ) if not last_chunk: speech = speech[:, :-self.source_cache_len] @@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): return speech + def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] for i, item in enumerate(batch): @@ -447,6 +473,7 @@ def collate_fn(batch): return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--enable-trt", action="store_true") @@ -457,6 +484,7 @@ def get_args(): parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") return parser.parse_args() + if __name__ == "__main__": args = get_args() model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) @@ -466,22 +494,17 @@ if __name__ == "__main__": dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) - data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) - for epoch in range(args.warmup): start_time = time.time() - for batch in data_loader: ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) - for id, wav in zip(ids, generated_wavs): torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) - end_time = time.time() epoch_time = end_time - start_time - print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py index 77f2915..e3eac2f 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -28,7 +28,6 @@ import argparse import json import os import sys -from pathlib 
import Path import torch import torch.distributed as dist diff --git a/runtime/triton_trtllm/scripts/test_llm.py b/runtime/triton_trtllm/scripts/test_llm.py index d52d724..90e6710 100644 --- a/runtime/triton_trtllm/scripts/test_llm.py +++ b/runtime/triton_trtllm/scripts/test_llm.py @@ -15,11 +15,6 @@ # limitations under the License. import argparse -import ast -import csv -import os -from pathlib import Path -from typing import List, Optional import numpy as np import torch diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index e9c2ebb..9c4a2fb 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -9,6 +9,7 @@ import time from token2wav_dit import CosyVoice2_Token2Wav import soundfile as sf + def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] prompt_speech_tokens_list, prompt_text_list = [], [] @@ -23,6 +24,7 @@ def collate_fn(batch): return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--enable-trt", action="store_true") @@ -79,7 +81,11 @@ if __name__ == "__main__": this_chunk_size = token_frame_rate * (2 ** chunk_index) if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len: - wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + wavs = token2wav_model.forward_streaming( + buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], + False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, + prompt_audio_sample_rate=prompt_audio_sample_rate + ) buffer = buffer[this_chunk_size - OVERLAP_SIZE:] output_wavs.append(wavs) @@ -87,7 +93,10 @@ if __name__ == "__main__": chunk_index += 1 else: - wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + wavs = token2wav_model.forward_streaming( + buffer, True, request_id=id, speaker_id=f"{id}", + prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate + ) output_wavs.append(wavs) total_forward_count += 1 # chunk_index += 1 @@ -96,7 +105,6 @@ if __name__ == "__main__": for i, wav in enumerate(output_wavs): output_wavs[i] = wav.cpu().numpy().squeeze() - audios = output_wavs reconstructed_audio = np.concatenate(audios) sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") @@ -111,4 +119,4 @@ if __name__ == "__main__": print(f"Cost time without speaker cache: {end_time - start_time} seconds") else: print(f"Cost time with speaker cache: {end_time - start_time} seconds") - print(f"Total flow matching forward calls: {total_forward_count}") \ No newline at end of file + print(f"Total flow matching forward calls: {total_forward_count}") From a224be6117ee5af206481be566913144db27476a Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Thu, 9 Oct 2025 15:18:09 +0800 Subject: [PATCH 14/15] fix lint --- examples/grpo/cosyvoice2/infer_dataset.py | 2 +- .../triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py | 4 ++-- runtime/triton_trtllm/offline_inference.py | 2 +- runtime/triton_trtllm/streaming_inference.py | 2 +- 4 files changed, 5 insertions(+), 5 
deletions(-) diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py index f0d22d7..f72cd77 100644 --- a/examples/grpo/cosyvoice2/infer_dataset.py +++ b/examples/grpo/cosyvoice2/infer_dataset.py @@ -53,7 +53,7 @@ except RuntimeError: pass -TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501 +TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501 def audio_decode_cosyvoice2( diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index 6bce5cc..1c6c423 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -464,7 +464,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] - for i, item in enumerate(batch): + for item in batch: generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) audio = torch.from_numpy(item['prompt_audio']['array']).float() prompt_audios_list.append(audio) @@ -496,7 +496,7 @@ if __name__ == "__main__": data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) - for epoch in range(args.warmup): + for _ in range(args.warmup): start_time = time.time() for batch in data_loader: ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py index e3eac2f..326fb0e 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -512,7 +512,7 @@ def main(args): )) else: outputs = [] - for i, chat in enumerate(batch["chat_list"]): + for chat in batch["chat_list"]: payload = { "model": args.openai_model_name, "messages": chat, diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index 9c4a2fb..7cfb6f9 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -13,7 +13,7 @@ import soundfile as sf def collate_fn(batch): ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] prompt_speech_tokens_list, prompt_text_list = [], [] - for i, item in enumerate(batch): + for item in batch: generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) audio = torch.from_numpy(item['prompt_audio']['array']).float() prompt_audios_list.append(audio) From 1fc843514689daa61b471f1bc862893b3a5035a7 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Thu, 16 Oct 2025 15:58:22 +0800 Subject: [PATCH 15/15] add disaggregated deployment --- runtime/triton_trtllm/README.DIT.md | 37 +++++++++++- runtime/triton_trtllm/client_grpc.py | 2 + 
.../run_stepaudio2_dit_token2wav.sh | 59 +++++++++++++++++-- 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/runtime/triton_trtllm/README.DIT.md b/runtime/triton_trtllm/README.DIT.md index 3c130b3..2fd9f61 100644 --- a/runtime/triton_trtllm/README.DIT.md +++ b/runtime/triton_trtllm/README.DIT.md @@ -45,7 +45,8 @@ bash run_stepaudio2_dit_token2wav.sh - **Stage 4**: Runs the gRPC benchmark client for performance testing. - **Stage 5**: Runs the offline TTS inference benchmark test. - **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model. - +- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3. +- **Stage 8**: Runs the benchmark client for the disaggregated server configuration. ### Export Models and Launch Server Inside the Docker container, prepare the models and start the Triton server by running stages 0-3: @@ -100,6 +101,40 @@ The following results were obtained by decoding on a single L20 GPU with the `yu | TRTLLM | 16 | 2.01 | 5.03 | 0.0292 | +### Disaggregated Server +When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs. + +The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU. + +| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) | +|---|---|---|---|---|---|---| +| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 | +| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 | +| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 | +| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 | +| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 | +| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 | +| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 | +| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 | +| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 | +| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 | +| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 | +| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 | +| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 | +| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 | +| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 | +| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 | +| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 | +| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 | +| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 | +| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 | +| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 | +| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 | +| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 | +| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 | + +The `concurrent_task_per_gpu` is calculated as: +`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))` ### Acknowledgements diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index b344849..1ceccb2 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -134,6 +134,8 @@ def write_triton_stats(stats, summary_file): compute_output = 
batch["compute_output"] compute_infer = batch["compute_infer"] batch_count = int(compute_infer["count"]) + if batch_count == 0: + continue assert compute_infer["count"] == compute_output["count"] == compute_input["count"] compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 compute_input_time_ms = int(compute_input["ns"]) / 1e6 diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index 5881b44..28ab13f 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -20,7 +20,7 @@ trt_weights_dir=./trt_weights_${trt_dtype} trt_engines_dir=./trt_engines_${trt_dtype} model_repo=./model_repo_cosyvoice2_dit -bls_instance_num=4 +bls_instance_num=10 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then @@ -58,7 +58,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then echo "Building TensorRT engines" trtllm-build --checkpoint_dir $trt_weights_dir \ --output_dir $trt_engines_dir \ - --max_batch_size 16 \ + --max_batch_size 64 \ --max_num_tokens 32768 \ --gemm_plugin $trt_dtype || exit 1 @@ -100,14 +100,14 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve" - mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 & + mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 & tritonserver --model-repository $model_repo --http-port 18000 & wait # Test using curl # curl http://localhost:8000/v1/chat/completions \ # -H "Content-Type: application/json" \ # -d '{ - # "model": "trt_engines_bfloat16", + # "model": "", # "messages":[{"role": "user", "content": "Where is New York?"}, # {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], # "max_tokens": 512, @@ -172,3 +172,54 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + echo "Disaggregated Server: LLM and Token2wav on different GPUs" + echo "Starting LLM server on GPU 0" + export CUDA_VISIBLE_DEVICES=0 + mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 & + echo "Starting Token2wav server on GPUs 1-3" + Token2wav_num_gpus=3 + http_port=17000 + grpc_port=18000 + metrics_port=16000 + for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do + echo "Starting server on GPU $i" + http_port=$((http_port + 1)) + grpc_port=$((grpc_port + 1)) + metrics_port=$((metrics_port + 1)) + # Two instances of Token2wav server on the same GPU + CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & + http_port=$((http_port + 1)) + grpc_port=$((grpc_port + 1)) + metrics_port=$((metrics_port + 1)) + CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & + done + wait +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + echo "Running benchmark client for Disaggregated Server" + per_gpu_instances=2 + mode=streaming + BLS_INSTANCE_NUM=$bls_instance_num + Token2wav_num_gpus=(1 2 3) + concurrent_tasks=(1 
2 3 4 5 6) + for n_gpu in ${Token2wav_num_gpus[@]}; do + echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers" + for concurrent_task in ${concurrent_tasks[@]}; do + num_instances=$((per_gpu_instances * n_gpu)) + for i in $(seq 1 $num_instances); do + port=$(($i + 18000)) + python3 client_grpc.py \ + --server-addr localhost \ + --server-port $port \ + --model-name cosyvoice2_dit \ + --num-tasks $concurrent_task \ + --mode $mode \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} & + done + wait + done + done +fi \ No newline at end of file
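
(Not part of the patches above.) The `concurrent_tasks_per_gpu` column in the disaggregated-server table of `README.DIT.md` follows directly from the formula quoted there; the short sketch below reproduces it, assuming two token2wav instances per GPU (as launched in stage 7) and one dedicated LLM GPU.

```python
# Sketch: reproduce the concurrent_tasks_per_gpu column of the disaggregated
# benchmark table from the formula in README.DIT.md.
NUM_TOKEN2WAV_INSTANCES_PER_GPU = 2   # stage 7 starts two tritonserver instances per GPU
LLM_GPUS = 1                          # one dedicated GPU runs trtllm-serve

def concurrent_tasks_per_gpu(tasks_per_instance: int, token2wav_gpus: int) -> float:
    total_tasks = tasks_per_instance * NUM_TOKEN2WAV_INSTANCES_PER_GPU * token2wav_gpus
    total_gpus = token2wav_gpus + LLM_GPUS
    return total_tasks / total_gpus

for tasks_per_instance in (1, 2, 3, 4, 5, 6):
    for token2wav_gpus in (1, 2, 3, 4):
        load = concurrent_tasks_per_gpu(tasks_per_instance, token2wav_gpus)
        print(f"{token2wav_gpus} token2wav GPU(s), {tasks_per_instance} task(s)/instance "
              f"-> {load:.2f} tasks per GPU")
```

For example, 3 tasks per instance on 4 token2wav GPUs gives 3 * 2 * 4 / (4 + 1) = 4.80, matching the table row reported for that configuration.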