diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 994a401..afbab68 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -209,7 +209,8 @@ def get_args(): choices=[ "f5_tts", "spark_tts", - "cosyvoice2"], + "cosyvoice2", + "cosyvoice2_dit"], help="triton model_repo module name to request", ) @@ -260,8 +261,8 @@ def get_args(): parser.add_argument( "--use-spk2info-cache", - type=bool, - default=False, + type=str, + default="False", help="Use spk2info cache for reference audio.", ) @@ -490,6 +491,7 @@ async def send_streaming( padding_duration=padding_duration, use_spk2info_cache=use_spk2info_cache ) + request_id = str(uuid.uuid4()) user_data = UserData() @@ -670,11 +672,15 @@ async def main(): trust_remote_code=True, ) manifest_item_list = [] + tmp_audio_path="./asset_zero_shot_prompt.wav" + tmp_audio_text="希望你以后能够做的比我还好呦。" for i in range(len(dataset)): manifest_item_list.append( { "audio_filepath": dataset[i]["prompt_audio"], "reference_text": dataset[i]["prompt_text"], + # "audio_filepath": tmp_audio_path, + # "reference_text": tmp_audio_text, "target_audio_path": dataset[i]["id"], "target_text": dataset[i]["target_text"], } @@ -686,7 +692,7 @@ async def main(): manifest_item_list = split_data(manifest_item_list, num_tasks) os.makedirs(args.log_dir, exist_ok=True) - + args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true" tasks = [] start_time = time.time() for i in range(num_tasks): diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index 97659ad..d0977c5 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -227,12 +227,11 @@ class TritonPythonModel: def forward_token2wav( self, + index: int, target_speech_tokens: torch.Tensor, request_id: str, - prompt_speech_tokens: torch.Tensor = None, - prompt_speech_feat: torch.Tensor = None, - prompt_spk_embedding: torch.Tensor = None, - token_offset: int = None, + reference_wav: object, + reference_wav_len: object, finalize: bool = None) -> torch.Tensor: """Forward pass through the vocoder component. @@ -246,29 +245,16 @@ class TritonPythonModel: Generated waveform tensor """ target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) - - inputs_tensor = [target_speech_tokens_tensor] - - if token_offset is not None: - assert finalize is not None - token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)) - finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) - inputs_tensor.append(token_offset_tensor) - inputs_tensor.append(finalize_tensor) - - if prompt_spk_embedding is not None: - assert prompt_speech_feat is not None - prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) - prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) - prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) - inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor]) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] # Create and execute inference request inference_request = pb_utils.InferenceRequest( - model_name='token2wav', + model_name='token2wav_dit', requested_output_names=['waveform'], inputs=inputs_tensor, request_id=request_id, + parameters={"priority": index+1}, ) inference_response = inference_request.exec() @@ -346,8 +332,15 @@ class TritonPythonModel: reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + + # reference_text = self.default_spk_info["prompt_text"] + # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + # prompt_speech_feat = None + # prompt_spk_embedding = None + else: + assert False, "wav is None" # using pre-cached reference text reference_text = self.default_spk_info["prompt_text"] prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE @@ -391,12 +384,12 @@ class TritonPythonModel: break if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: - this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) sub_tts_speech = self.forward_token2wav( - this_tts_speech_token, request_id, prompt_speech_tokens, - prompt_speech_feat, prompt_spk_embedding, token_offset, False + chunk_index, + this_tts_speech_token, request_id, wav, wav_len, False ) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) @@ -429,8 +422,8 @@ class TritonPythonModel: else: time.sleep(0.02) - this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -439,17 +432,7 @@ class TritonPythonModel: response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: - generated_ids = next(generated_ids_iter) - generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) - if generated_ids is None or len(generated_ids) == 0: - raise pb_utils.TritonModelException("Generated IDs is None or empty") - - audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) - - # Prepare response - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - responses.append(inference_response) + raise NotImplementedError("Decoupled mode is not supported") if not self.decoupled: return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py new file mode 100644 index 0000000..b4a6348 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py @@ -0,0 +1,438 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import math +import os +import re +import time +from typing import Dict, List, Tuple, Optional, Union +import asyncio +import httpx + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer + +import torchaudio + + +from matcha.utils.audio import mel_spectrogram + +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + + +def parse_speech_token_string(response_text: str) -> List[int]: + """ + Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. + """ + speech_tokens = response_text.strip().split('><') + if len(speech_tokens) > 1: + # Add back the missing '<' and '>' for proper parsing + speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] + speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] + + speech_ids = [] + for token_str in speech_tokens: + match = re.match(r'<\|s_(\d+)\|>', token_str) + if match: + speech_ids.append(int(match.group(1))) + return speech_ids + + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + self.logger = pb_utils.Logger + # Parse model parameters + self.model_config = json.loads(args['model_config']) + parameters = self.model_config['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + self.logger.log_info(f"model_params:{model_params}") + self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.prompt_template = "<|sos|>{input_text}<|task_id|>" + self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") + + self.device = torch.device("cuda") + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) + + self.token_frame_rate = 25 + self.flow_pre_lookahead_len = 3 + self.token_hop_len = 15 + + spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + # self.default_spk_info = spk_info["001"] + + def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: + """Converts a tensor or list of speech token IDs to a string representation.""" + if isinstance(speech_tokens, torch.Tensor): + # Ensure tensor is on CPU and flattened + speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() + + speech_id_str = "" + for token_id in speech_tokens: + # Convert token ID back to the speech number N + token_num = token_id - ORIGINAL_VOCAB_SIZE + speech_id_str += f"<|s_{token_num}|>" + return speech_id_str + + async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): + """ + Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response. + """ + full_text = f"{reference_text}{target_text}" + prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) + + chat = [ + {"role": "user", "content": full_text}, + {"role": "assistant", "content": prompt_speech_tokens_str} + ] + print(chat) + + payload = { + "model": "trt_engines_bfloat16", + "messages": chat, + "max_tokens": 750, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": True, + } + + api_base = "http://localhost:8000/v1/chat/completions" + + buffer = "" + async with httpx.AsyncClient() as client: + async with client.stream("POST", api_base, json=payload, timeout=None) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + line_data = line[len("data: "):].strip() + if line_data == "[DONE]": + break + try: + json_data = json.loads(line_data) + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") + if content: + buffer += content + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break + + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + except json.JSONDecodeError: + self.logger.log_info(f"Skipping non-JSON line: {line_data}") + continue + + # Process any remaining complete tokens in the buffer after the stream ends + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. + + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['prompt_speech_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') + prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() + + return prompt_speech_tokens + + def forward_speaker_embedding(self, wav): + """Forward pass through the speaker embedding component. + + Args: + wav: Input waveform tensor + + Returns: + Prompt speaker embedding tensor + """ + inference_request = pb_utils.InferenceRequest( + model_name='speaker_embedding', + requested_output_names=['prompt_spk_embedding'], + inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') + prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) + + return prompt_spk_embedding + + def forward_token2wav( + self, + index: int, + target_speech_tokens: torch.Tensor, + request_id: str, + reference_wav: object, + reference_wav_len: object, + finalize: bool = None) -> torch.Tensor: + """Forward pass through the vocoder component. + + Args: + prompt_speech_tokens: Prompt speech tokens tensor + prompt_speech_feat: Prompt speech feat tensor + prompt_spk_embedding: Prompt spk embedding tensor + target_speech_tokens: Target speech tokens tensor + + Returns: + Generated waveform tensor + """ + target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='token2wav_dit', + requested_output_names=['waveform'], + inputs=inputs_tensor, + request_id=request_id, + parameters={"priority": index+1}, + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def _extract_speech_feat(self, speech): + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 1).to( + self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + return speech_feat + + async def _process_request(self, request): + request_id = request.request_id() + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + + # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + + # reference_text = self.default_spk_info["prompt_text"] + # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + # prompt_speech_feat = None + # prompt_spk_embedding = None + + else: + # using pre-cached reference text + assert False, "using pre-cached reference text is not supported" + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + if self.decoupled: + response_sender = request.get_response_sender() + + semantic_token_ids_arr = [] + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + + async for generated_ids in self.forward_llm_async( + target_text=target_text, + reference_text=reference_text, + prompt_speech_tokens=prompt_speech_tokens, + ): + if not generated_ids: + break + semantic_token_ids_arr.append(generated_ids) + + while True: + pending_num = len(semantic_token_ids_arr) - token_offset + if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: + this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) + + sub_tts_speech = self.forward_token2wav( + chunk_index, + this_tts_speech_token, request_id, wav, wav_len, False + ) + + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + token_offset += this_token_hop_len + self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") + + if self.dynamic_chunk_strategy == "exponential": + this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "time_based": + # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 + cost_time = time.time() - start_time + duration = token_offset / self.token_frame_rate + if chunk_index > 0 and cost_time > 0: + avg_chunk_processing_time = cost_time / (chunk_index + 1) + if avg_chunk_processing_time > 0: + multiples = (duration - cost_time) / avg_chunk_processing_time + self.logger.log_info(f"multiples: {multiples}") + next_pending_num = len(semantic_token_ids_arr) - token_offset + if multiples > 4: + this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len + elif multiples > 2: + this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len + else: + this_token_hop_len = self.token_hop_len + this_token_hop_len = max(self.token_hop_len, this_token_hop_len) + chunk_index += 1 + else: + break + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + ## debug + ## save semantic_token_ids_arr and reference_text, target_text to a single json file + # save into a torch .pt + # for i, item in enumerate(semantic_token_ids_arr): + # semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE + # import json + # data = { + # "semantic_token_ids_arr": semantic_token_ids_arr, + # "reference_text": reference_text, + # "target_text": target_text + # } + # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f: + # torch.save(data, f) + # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f: + # json.dump(data, f) + + # ## + + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + raise NotImplementedError("Decoupled mode is not supported") + + async def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + tasks = [ + asyncio.create_task(self._process_request(request)) + for request in requests + ] + await asyncio.gather(*tasks) + return None diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt index 73a9a05..e64647e 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: "cosyvoice2" +name: "cosyvoice2_dit" backend: "python" max_batch_size: ${triton_max_batch_size} dynamic_batching { diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 10bc272..8f9ffba 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -42,6 +42,8 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vl from cosyvoice.utils.common import TrtContextWrapper from collections import defaultdict import numpy as np +from .token2wav_dit import CosyVoice2_Token2Wav +import hashlib logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -49,117 +51,19 @@ logger = logging.getLogger(__name__) ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) +def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str: + """ + Generates a unique ID for a torch.Tensor. + Tensors with the same elements and properties will have the same ID. + """ + # Convert tensor to a byte string + tensor_bytes = tensor.numpy().tobytes() -class CosyVoice2: - - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'): - - self.model_dir = model_dir - self.fp16 = fp16 - - hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) - if not os.path.exists(hyper_yaml_path): - raise ValueError('{} not found!'.format(hyper_yaml_path)) - with open(hyper_yaml_path, 'r') as f: - configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) - self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device) - self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) - if load_jit: - self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) - if load_trt: - self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), - '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), - trt_concurrent, - self.fp16) - - -class CosyVoice2Model: - - def __init__(self, - flow: torch.nn.Module, - hift: torch.nn.Module, - fp16: bool = False, - device: str = 'cuda'): - self.device = device - self.flow = flow - self.hift = hift - self.fp16 = fp16 - if self.fp16 is True: - self.flow.half() - - # streaming tts config - self.token_hop_len = 25 - self.mel_cache_len = 8 - self.source_cache_len = int(self.mel_cache_len * 480) - self.speech_window = np.hamming(2 * self.source_cache_len) - self.hift_cache_dict = defaultdict(lambda: None) - - def load_jit(self, flow_encoder_model): - flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) - self.flow.encoder = flow_encoder - - def load(self, flow_model, hift_model): - self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) - self.flow.to(self.device).eval() - # in case hift_model is a hifigan model - hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} - self.hift.load_state_dict(hift_state_dict, strict=True) - self.hift.to(self.device).eval() - - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): - assert torch.cuda.is_available(), 'tensorrt only supports gpu!' - if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: - convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) - del self.flow.decoder.estimator - import tensorrt as trt - with open(flow_decoder_estimator_model, 'rb') as f: - estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) - assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) - - def get_trt_kwargs(self): - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] - opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] - max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] - input_names = ["x", "mask", "mu", "cond"] - return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - - def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16): - tts_mel, _ = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - streaming=stream, - finalize=finalize) - tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] - # append hift cache - if self.hift_cache_dict[uuid] is not None: - hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] - tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) - else: - hift_cache_source = torch.zeros(1, 1, 0) - # keep overlap mel and hift cache - if finalize is False: - tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) - if self.hift_cache_dict[uuid] is not None: - tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) - self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], - 'source': tts_source[:, :, -self.source_cache_len:], - 'speech': tts_speech[:, -self.source_cache_len:]} - tts_speech = tts_speech[:, :-self.source_cache_len] - else: - if speed != 1.0: - assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' - tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') - tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) - if self.hift_cache_dict[uuid] is not None: - tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) - return tts_speech - + # Create a SHA-256 hash of the byte string + hasher = hashlib.sha256() + hasher.update(tensor_bytes) + + return hasher.hexdigest() class TritonPythonModel: """Triton Python model for vocoder. @@ -183,16 +87,10 @@ class TritonPythonModel: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") logger.info(f"Initializing vocoder from {model_dir} on {self.device}") - self.token2wav_model = CosyVoice2( - model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device + # FIXME: device id settings + self.token2wav_model = CosyVoice2_Token2Wav( + model_dir, enable_trt=True, streaming=True ) - - spk_info_path = os.path.join(model_dir, "spk2info.pt") - if not os.path.exists(spk_info_path): - raise ValueError(f"spk2info.pt not found in {model_dir}") - spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - self.default_spk_info = spk_info["001"] - logger.info("Token2Wav initialized successfully") def execute(self, requests): @@ -208,66 +106,31 @@ class TritonPythonModel: # Process each request in batch for request in requests: target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() - target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) - - prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens") - if prompt_speech_tokens_tensor is not None: - prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy() - prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() - prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() - prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) - prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) - prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) - prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE - else: - prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device) - prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device) - prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device) - + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) # shift the speech tokens according to the original vocab size target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + target_speech_tokens = target_speech_tokens.squeeze().tolist() # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. - token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset") - if token_offset is not None: - token_offset = token_offset.as_numpy().item() - finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - if not finalize: - stream = True - else: - stream = False - request_id = request.request_id() - audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, - prompt_token=prompt_speech_tokens, - prompt_feat=prompt_speech_feat, - embedding=prompt_spk_embedding, - token_offset=token_offset, - uuid=request_id, - stream=stream, - finalize=finalize) - if finalize: - self.token2wav_model.model.hift_cache_dict.pop(request_id) + + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() + + request_id = request.request_id() + - else: - tts_mel, _ = self.token2wav_model.model.flow.inference( - token=target_speech_tokens, - token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( - self.device - ), - prompt_token=prompt_speech_tokens, - prompt_token_len=torch.tensor( - [prompt_speech_tokens.shape[1]], dtype=torch.int32 - ).to(self.device), - prompt_feat=prompt_speech_feat, - prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=prompt_spk_embedding, - streaming=False, - finalize=True, - ) + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() - audio_hat, _ = self.token2wav_model.model.hift.inference( - speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) - ) + wav_array = torch.from_numpy(wav_array) + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + + spk_id = get_spk_id_from_prompt_audio(wav) + # wav = wav.to(self.device) + + audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) generated_wave = audio_hat.squeeze(0).cpu().numpy() diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py similarity index 91% rename from runtime/triton_trtllm/token2wav_dit.py rename to runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index fdc1a12..3b696e9 100644 --- a/runtime/triton_trtllm/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -362,17 +362,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module): spk_emb_for_flow.to(self.device), n_timesteps=10 ) + new_cache = {k: v.clone() for k, v in cache.items()} # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache'] - cache['estimator_att_cache'] = cache['estimator_att_cache'].clone() - cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone() - return cache + return new_cache @torch.inference_mode() def forward_streaming( self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 - ): + ): if speaker_id not in self.speaker_cache: + # if 1: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -382,10 +382,21 @@ class CosyVoice2_Token2Wav(torch.nn.Module): prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + + # if speaker_id not in self.speaker_cache: + # if 1: + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + print(f"speaker_id {speaker_id} added to cache") + + # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change + att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone() + cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone() + conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone() + conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone() + if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} @@ -409,6 +420,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module): n_timesteps=10, ) + # get the original att_cache + original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'] + original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'] + original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'] + original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'] + if not torch.allclose(original_att_cache, att_cache_clone): + print("att_cache changed") + # print the last 10 elements of original_att_cache and att_cache_clone + print(original_att_cache[:, :, :, -10:]) + print(att_cache_clone[:, :, :, -10:]) + breakpoint() + if not torch.allclose(original_cnn_cache, cnn_cache_clone): + print("cnn_cache changed") + print(original_cnn_cache[..., -10:]) + print(cnn_cache_clone[..., -10:]) + breakpoint() + if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone): + print("conformer_cnn_cache changed") + print(original_conformer_cnn_cache[..., -10:]) + print(conformer_cnn_cache_clone[..., -10:]) + breakpoint() + if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone): + print("conformer_att_cache changed") + print(original_conformer_att_cache[..., -10:]) + print(conformer_att_cache_clone[..., -10:]) + breakpoint() + self.streaming_flow_cache[request_id] = new_streaming_flow_cache @@ -420,10 +458,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module): - hift_cache_mel = self.hift_cache_dict[request_id]['mel'] - hift_cache_source = self.hift_cache_dict[request_id]['source'] - hift_cache_speech = self.hift_cache_dict[request_id]['speech'] - mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone() + hift_cache_source = self.hift_cache_dict[request_id]['source'].clone() + hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone() + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone() speech, source = self.hift(mel, hift_cache_source) @@ -444,7 +482,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): assert request_id in self.streaming_flow_cache self.streaming_flow_cache.pop(request_id) self.hift_cache_dict.pop(request_id) - + # breakpoint() return speech def collate_fn(batch): diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt index c33a85f..2040cfe 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: "token2wav" +name: "token2wav_dit" backend: "python" max_batch_size: ${triton_max_batch_size} dynamic_batching { max_queue_delay_microseconds: ${max_queue_delay_microseconds} + priority_levels: 10 + default_priority_level: 10 } parameters [ { @@ -32,29 +34,14 @@ input [ dims: [-1] }, { - name: "prompt_speech_tokens" - data_type: TYPE_INT32 + name: "reference_wav" + data_type: TYPE_FP32 dims: [-1] - optional: true }, { - name: "prompt_speech_feat" - data_type: TYPE_FP16 - dims: [-1, 80] - optional: true - }, - { - name: "prompt_spk_embedding" - data_type: TYPE_FP16 - dims: [-1] - optional: true - }, - { - name: "token_offset" + name: "reference_wav_len" data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true + dims: [1] }, { name: "finalize" diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py index 6f1a836..30c3b3b 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -43,6 +43,9 @@ import soundfile as sf import s3tokenizer from functools import partial import time +import requests +import asyncio +import httpx from token2wav import CosyVoice2_Token2Wav @@ -53,6 +56,32 @@ except RuntimeError: pass +async def send_request_async(client, url, payload): + response = await client.post(url, json=payload, timeout=None) + response.raise_for_status() + response_json = response.json() + return response_json['choices'][0]['message']['content'] + + +async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k): + async with httpx.AsyncClient() as client: + tasks = [] + for chat in chats: + payload = { + "model": model_name, + "messages": chat, + "max_tokens": 2048, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": False, + } + tasks.append(send_request_async(client, api_base, payload)) + return await asyncio.gather(*tasks) + + def extract_speech_ids(speech_tokens_str): """Extract speech IDs from token strings like <|s_23456|>""" speech_ids = [] @@ -149,7 +178,7 @@ def get_args(): "--backend", type=str, default="hf", - choices=["hf", "trtllm", "vllm"], + choices=["hf", "trtllm", "vllm", "trtllm-serve"], help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM", ) parser.add_argument( @@ -164,6 +193,18 @@ def get_args(): default=0.6, help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)", ) + parser.add_argument( + "--openai-api-base", + type=str, + default="http://localhost:8000/v1/chat/completions", + help="OpenAI API base URL (for trtllm-serve backend)", + ) + parser.add_argument( + "--openai-model-name", + type=str, + default="trt_engines_bfloat16", + help="Model name to use with OpenAI API (for trtllm-serve backend)", + ) args = parser.parse_args() return args @@ -180,6 +221,7 @@ def data_collator(batch, tokenizer, s3_tokenizer): input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] prompt_text_after_apply_template_list = [] mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], [] + chat_list = [] for _, item in enumerate(batch): audio_processing_start_time = time.time() prompt_text, target_text = ( @@ -237,6 +279,7 @@ def data_collator(batch, tokenizer, s3_tokenizer): {"role": "user", "content": full_text_list[i]}, {"role": "assistant", "content": prompt_audio_cosy2_id_str} ] + chat_list.append(chat) assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template" @@ -265,6 +308,7 @@ def data_collator(batch, tokenizer, s3_tokenizer): "audio_processing_time": total_audio_processing_time, "speech_tokenization_time": total_speech_tokenization_time, "text_tokenization_time": total_text_tokenization_time, + "chat_list": chat_list } @@ -318,6 +362,9 @@ def main(args): elif args.backend == "vllm": model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4) runner = None + elif args.backend == "trtllm-serve": + model = None + runner = None else: raise ValueError(f"Unsupported backend: {args.backend}") @@ -452,6 +499,35 @@ def main(args): print(outputs) for j, output in enumerate(outputs): outputs[j] = input_ids_list[j] + output.outputs[0].token_ids + elif args.backend == "trtllm-serve": + if args.batch_size > 1: + outputs = asyncio.run(send_batch_requests_async( + args.openai_api_base, + args.openai_model_name, + batch["chat_list"], + args.temperature, + args.top_p, + args.top_k, + )) + else: + outputs = [] + for i, chat in enumerate(batch["chat_list"]): + payload = { + "model": args.openai_model_name, + "messages": chat, + "max_tokens": 2048, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": False, + } + response = requests.post(args.openai_api_base, json=payload) + response.raise_for_status() + response_json = response.json() + generated_content = response_json['choices'][0]['message']['content'] + outputs.append(generated_content) llm_end_time = time.time() total_llm_time += (llm_end_time - llm_start_time) @@ -459,10 +535,21 @@ def main(args): items_for_token_2wav = [] for i in range(len(batch["ids"])): llm_post_processing_start_time = time.time() - input_length = len(batch["input_ids"][i]) - generated_ids = outputs[i][input_length:] - speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - speech_ids = extract_speech_ids(speech_tokens_str) + if args.backend == "trtllm-serve": + speech_tokens_str = outputs[i].strip().split('><') + if len(speech_tokens_str) > 1: + speech_tokens_str = [ + t if t.startswith('<') else '<' + t for t in speech_tokens_str + ] + speech_tokens_str = [ + t if t.endswith('>') else t + '>' for t in speech_tokens_str + ] + speech_ids = extract_speech_ids(speech_tokens_str) + else: + input_length = len(batch["input_ids"][i]) + generated_ids = outputs[i][input_length:] + speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + speech_ids = extract_speech_ids(speech_tokens_str) print(i, speech_ids) if len(speech_ids) == 0: print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping") @@ -558,6 +645,8 @@ if __name__ == "__main__": from tensorrt_llm.runtime import ModelRunnerCpp elif args.backend == "hf": from transformers import AutoModelForCausalLM + elif args.backend == "trtllm-serve": + pass else: raise ValueError(f"Unsupported backend: {args.backend}") main(args) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index c0034c2..ad3407e 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,6 +1,6 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=1 cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 @@ -16,7 +16,7 @@ trt_dtype=bfloat16 trt_weights_dir=./trt_weights_${trt_dtype} trt_engines_dir=./trt_engines_${trt_dtype} -model_repo=./model_repo_cosyvoice2 +model_repo=./model_repo_cosyvoice2_dit use_spk2info_cache=False @@ -58,40 +58,78 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --engine_dir=$trt_engines_dir || exit 1 fi + +# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then +# echo "Creating model repository" +# rm -rf $model_repo +# mkdir -p $model_repo +# cosyvoice2_dir="cosyvoice2_dit" +# token2wav_dir="token2wav_dit" + +# cp -r ./model_repo/${cosyvoice2_dir} $model_repo +# cp -r ./model_repo/tensorrt_llm $model_repo +# cp -r ./model_repo/${token2wav_dir} $model_repo +# #if [ $use_spk2info_cache == "False" ]; then +# cp -r ./model_repo/audio_tokenizer $model_repo +# cp -r ./model_repo/speaker_embedding $model_repo +# #fi + +# ENGINE_PATH=$trt_engines_dir +# MAX_QUEUE_DELAY_MICROSECONDS=0 +# MODEL_DIR=$model_scope_model_local_dir +# LLM_TOKENIZER_DIR=$huggingface_model_local_dir +# BLS_INSTANCE_NUM=1 +# TRITON_MAX_BATCH_SIZE=16 +# DECOUPLED_MODE=True # True for streaming, False for offline +# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav + +# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 +# #if [ $use_spk2info_cache == "False" ]; then +# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} +# #fi +# fi + if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - echo "Creating model repository" + echo "Creating model repository async mode" rm -rf $model_repo mkdir -p $model_repo - cosyvoice2_dir="cosyvoice2" + cosyvoice2_dir="cosyvoice2_dit" + token2wav_dir="token2wav_dit" cp -r ./model_repo/${cosyvoice2_dir} $model_repo cp -r ./model_repo/tensorrt_llm $model_repo - cp -r ./model_repo/token2wav $model_repo - if [ $use_spk2info_cache == "False" ]; then + cp -r ./model_repo/${token2wav_dir} $model_repo + #if [ $use_spk2info_cache == "False" ]; then cp -r ./model_repo/audio_tokenizer $model_repo cp -r ./model_repo/speaker_embedding $model_repo - fi + #fi ENGINE_PATH=$trt_engines_dir MAX_QUEUE_DELAY_MICROSECONDS=0 MODEL_DIR=$model_scope_model_local_dir LLM_TOKENIZER_DIR=$huggingface_model_local_dir BLS_INSTANCE_NUM=4 - TRITON_MAX_BATCH_SIZE=16 + TRITON_MAX_BATCH_SIZE=32 DECOUPLED_MODE=True # True for streaming, False for offline + STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav - python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 - if [ $use_spk2info_cache == "False" ]; then + #if [ $use_spk2info_cache == "False" ]; then python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - fi + #fi + rm -rf $model_repo/tensorrt_llm + # mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4 fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then echo "Starting Triton server" - tritonserver --model-repository $model_repo + tritonserver --model-repository $model_repo --http-port 18000 fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then @@ -112,26 +150,26 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then python3 client_grpc.py \ --server-addr localhost \ - --model-name cosyvoice2 \ + --model-name cosyvoice2_dit \ --num-tasks $num_task \ --mode $mode \ - --use-spk2info-cache $use_spk2info_cache \ --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache} + --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then echo "stage 6: Offline inference benchmark" n_gpus=1 datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh - backend=trtllm # hf, trtllm, vllm + backend=trtllm-serve # hf, trtllm, vllm batch_sizes=(16 8 4 2 1) + batch_sizes=(16 8 4 2) token2wav_batch_size=1 for batch_size in ${batch_sizes[@]}; do for dataset in ${datasets[@]}; do output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} - CUDA_VISIBLE_DEVICES=0 \ + CUDA_VISIBLE_DEVICES=1 \ python3 offline_inference.py \ --output-dir $output_dir \ --llm-model-name-or-path $huggingface_model_local_dir \ @@ -147,7 +185,31 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - python3 benchmark_streaming_token2wav.py --enable-trt + python3 streaming_inference.py +fi + + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 + +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + #! /usr/bin/env bash + curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "trt_engines_bfloat16", + "messages":[{"role": "user", "content": "Where is New York?"}, + {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], + "max_tokens": 512, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "stop": ["<|eos1|>"], + "repetition_penalty": 1.2, + "stream": false + }' fi \ No newline at end of file diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py new file mode 100644 index 0000000..863358c --- /dev/null +++ b/runtime/triton_trtllm/streaming_inference.py @@ -0,0 +1,115 @@ +import torch +import os +import argparse +from datasets import load_dataset +from torch.utils.data import DataLoader +import numpy as np +import torchaudio +import time +from token2wav_dit import CosyVoice2_Token2Wav +import soundfile as sf + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + prompt_speech_tokens_list, prompt_text_list = [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens']) + prompt_text_list.append(item['prompt_text']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2") + return parser.parse_args() + + +def fake_generated_id_iter(generated_speech_tokens_list): + for i in range(len(generated_speech_tokens_list)): + yield generated_speech_tokens_list[i] + + + +if __name__ == "__main__": + args = get_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + dataset_name = args.dataset_name + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) + + flow_pre_lookahead_len = 3 + CHUNK_SIZE = 25 + OVERLAP_SIZE = 0 + + warmup_times = 3 + for _ in range(warmup_times): + start_time = time.time() + for batch in data_loader: + tts_speech_list = [] + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch + + id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] + # if id != "unseen3_text5": + # continue + # else: + # a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt") + # generated_speech_tokens = a["semantic_token_ids_arr"] + # print(generated_speech_tokens) + assert prompt_audio_sample_rate == 16000 + + prompt_text = prompt_text_list[0] + prompt_speech_tokens = prompt_speech_tokens_list[0] + + + # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) + + semantic_token_ids_arr, token_offset = [], 0 + flow_prompt_speech_token_len = len(prompt_speech_tokens) + + buffer = generated_speech_tokens + output_wavs = [] + while True: + + if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: + wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] + + output_wavs.append(wavs) + + else: + wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + output_wavs.append(wavs) + break + + for i, wav in enumerate(output_wavs): + output_wavs[i] = wav.cpu().numpy().squeeze() + + + audios = output_wavs + reconstructed_audio = np.concatenate(audios) + # Save reconstructed audio + sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") + + + print(f"Saved {id}") + end_time = time.time() + + if _ == 0: + token2wav_model.speaker_cache = {} + print(f"Warmup time: {end_time - start_time} seconds") +