diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py index d0977c5..2f81786 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -28,9 +28,10 @@ import json import math import os import re -import threading import time from typing import Dict, List, Tuple, Optional, Union +import asyncio +import httpx import numpy as np import torch @@ -42,11 +43,30 @@ import torchaudio from matcha.utils.audio import mel_spectrogram +from datetime import datetime ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) +def parse_speech_token_string(response_text: str) -> List[int]: + """ + Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. + """ + speech_tokens = response_text.strip().split('><') + if len(speech_tokens) > 1: + # Add back the missing '<' and '>' for proper parsing + speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] + speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] + + speech_ids = [] + for token_str in speech_tokens: + match = re.match(r'<\|s_(\d+)\|>', token_str) + if match: + speech_ids.append(int(match.group(1))) + return speech_ids + + class TritonPythonModel: """Triton Python model for Spark TTS. @@ -67,6 +87,7 @@ class TritonPythonModel: model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + # self.dynamic_chunk_strategy = "equal" self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer @@ -87,92 +108,86 @@ class TritonPythonModel: raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) self.default_spk_info = spk_info["001"] + self.http_client = httpx.AsyncClient() - def forward_llm(self, input_ids): + def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: + """Converts a tensor or list of speech token IDs to a string representation.""" + if isinstance(speech_tokens, torch.Tensor): + # Ensure tensor is on CPU and flattened + speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() + + speech_id_str = "" + for token_id in speech_tokens: + # Convert token ID back to the speech number N + token_num = token_id - ORIGINAL_VOCAB_SIZE + speech_id_str += f"<|s_{token_num}|>" + return speech_id_str + + async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): """ - Prepares the response from the language model based on the provided - inputs. Creates a `pb_utils.InferenceRequest` object with passed - `llm_request_inputs` to send to a decoupled TensorRTLLM model. - For each response from the language model: - - Checks for errors and raise an exception if any are found. - - Extracts the "output_ids" tensor from the response. - - Determines the finish reason based on the presence of the - end-of-sequence token or reaching the maximum length. - - Appends the generated token IDs to `output_ids`. - - If the finish reason is determined, decodes the output IDs to text - and prepares the final response. - - The final response includes the generated text, finish reason, - completion tokens, prompt tokens, and total tokens. - - Parameters - ---------- - - llm_request_inputs (dict): A dictionary containing the inputs for the language model. - - Returns - ------- - - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. + Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response. """ - # convert input_ids to numpy, with shape [1, sequence_length] - input_ids = input_ids.cpu().numpy() - max_tokens = 750 - input_dict = { - "request_output_len": np.array([[max_tokens]], dtype=np.int32), - "end_id": np.array([[self.eos_token_id]], dtype=np.int32), - "pad_id": np.array([[self.eos_token_id]], dtype=np.int32), - "streaming": np.array([[self.decoupled]], dtype=np.bool_), - "runtime_top_p": np.array([[0.95]], dtype=np.float32), - "runtime_top_k": np.array([[50]], dtype=np.int32), - "temperature": np.array([[0.8]], dtype=np.float32), - "repetition_penalty": np.array([[1.1]], dtype=np.float32), - "random_seed": np.array([[42]], dtype=np.uint64), - "input_ids": input_ids, - "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), - } + full_text = f"{reference_text}{target_text}" + prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) - # Convert inputs to Triton tensors - input_tensor_list = [ - pb_utils.Tensor(k, v) for k, v in input_dict.items() + chat = [ + {"role": "user", "content": full_text}, + {"role": "assistant", "content": prompt_speech_tokens_str} ] - # Create and execute inference request - llm_request = pb_utils.InferenceRequest( - model_name="tensorrt_llm", - requested_output_names=["output_ids", "sequence_length"], - inputs=input_tensor_list, - ) + payload = { + "model": "trt_engines_bfloat16", + "messages": chat, + "max_tokens": 750, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.1, + "stop": ["<|eos1|>", "<|eos|>"], + "stream": True, + } - llm_responses = llm_request.exec(decoupled=self.decoupled) - if self.decoupled: - for llm_response in llm_responses: - if llm_response.has_error(): - raise pb_utils.TritonModelException(llm_response.error().message()) + api_base = "http://localhost:8000/v1/chat/completions" - # Extract and process output - output_ids = pb_utils.get_output_tensor_by_name( - llm_response, "output_ids").as_numpy() - seq_lens = pb_utils.get_output_tensor_by_name( - llm_response, "sequence_length").as_numpy() + buffer = "" + async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: + print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") + print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + line_data = line[len("data: "):].strip() + if line_data == "[DONE]": + break + try: + json_data = json.loads(line_data) + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") + if content: + buffer += content + print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break - # Get actual output IDs up to the sequence length - actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + except json.JSONDecodeError: + self.logger.log_info(f"Skipping non-JSON line: {line_data}") + continue - yield actual_output_ids - else: - llm_response = llm_responses - if llm_response.has_error(): - raise pb_utils.TritonModelException(llm_response.error().message()) + # Process any remaining complete tokens in the buffer after the stream ends + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] - # Extract and process output - output_ids = pb_utils.get_output_tensor_by_name( - llm_response, "output_ids").as_numpy() - seq_lens = pb_utils.get_output_tensor_by_name( - llm_response, "sequence_length").as_numpy() - - # Get actual output IDs up to the sequence length - actual_output_ids = output_ids[0][0][:seq_lens[0][0]] - - yield actual_output_ids def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. @@ -225,7 +240,7 @@ class TritonPythonModel: return prompt_spk_embedding - def forward_token2wav( + async def forward_token2wav( self, index: int, target_speech_tokens: torch.Tensor, @@ -247,17 +262,19 @@ class TritonPythonModel: target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', - requested_output_names=['waveform'], + requested_output_names=[ + "waveform", + ], inputs=inputs_tensor, request_id=request_id, parameters={"priority": index+1}, ) - inference_response = inference_request.exec() + inference_response = await inference_request.async_exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) @@ -267,14 +284,6 @@ class TritonPythonModel: return waveform - def parse_input(self, text, prompt_text, prompt_speech_tokens): - total_text = f"{prompt_text}{text}" - prompt = self.prompt_template.format(input_text=total_text) - input_ids = self.tokenizer.encode(prompt) - input_ids = torch.tensor([input_ids], dtype=torch.int32) - input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) - return input_ids - def _extract_speech_feat(self, speech): speech_feat = mel_spectrogram( speech, @@ -292,106 +301,75 @@ class TritonPythonModel: speech_feat = speech_feat.unsqueeze(dim=0) return speech_feat - def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag): - for generated_ids in generated_ids_iter: - generated_ids = generated_ids.tolist() - if len(generated_ids) == 0: - break - semantic_token_ids_arr.extend(generated_ids) - llm_is_done_flag[0] = True + async def _process_request(self, request): + request_id = request.request_id() + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") - def execute(self, requests): - """Execute inference on the batched requests. + # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - Args: - requests: List of inference requests + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - Returns: - List of inference responses containing generated audio - """ - responses = [] + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) - for request in requests: - request_id = request.request_id() - # Extract input tensors - wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + # reference_text = self.default_spk_info["prompt_text"] + # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + # prompt_speech_feat = None + # prompt_spk_embedding = None - # Process reference audio through audio tokenizer - if wav is not None: - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + else: + # using pre-cached reference text + assert False, "using pre-cached reference text is not supported" + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + print(f"target_text: {target_text}, time: {datetime.now()}") - reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + if self.decoupled: + response_sender = request.get_response_sender() - # reference_text = self.default_spk_info["prompt_text"] - # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - # prompt_speech_feat = None - # prompt_spk_embedding = None - - else: - assert False, "wav is None" - # using pre-cached reference text - reference_text = self.default_spk_info["prompt_text"] - prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - prompt_speech_feat = None - prompt_spk_embedding = None - - target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() - target_text = target_text[0][0].decode('utf-8') - - # Prepare prompt for LLM - input_ids = self.parse_input( - text=target_text, - prompt_text=reference_text, + semantic_token_ids_arr = [] + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") + async for generated_ids in self.forward_llm_async( + target_text=target_text, + reference_text=reference_text, prompt_speech_tokens=prompt_speech_tokens, - ) - - # Generate semantic tokens with LLM - generated_ids_iter = self.forward_llm(input_ids) - - if self.decoupled: - response_sender = request.get_response_sender() - - semantic_token_ids_arr = [] - llm_is_done_flag = [False] - - llm_thread = threading.Thread( - target=self._llm_gen_thread, - args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag) - ) - - llm_thread.start() - - token_offset, chunk_index = 0, 0 - start_time = time.time() - this_token_hop_len = self.token_hop_len - + ): + if not generated_ids: + break + semantic_token_ids_arr.append(generated_ids) + print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: pending_num = len(semantic_token_ids_arr) - token_offset - - if llm_is_done_flag[0]: - break - if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - - sub_tts_speech = self.forward_token2wav( + print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") + sub_tts_speech = await self.forward_token2wav( chunk_index, this_tts_speech_token, request_id, wav, wav_len, False ) - + print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -401,6 +379,8 @@ class TritonPythonModel: if self.dynamic_chunk_strategy == "exponential": this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "equal": + this_token_hop_len = self.token_hop_len elif self.dynamic_chunk_strategy == "time_based": # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 cost_time = time.time() - start_time @@ -420,19 +400,36 @@ class TritonPythonModel: this_token_hop_len = max(self.token_hop_len, this_token_hop_len) chunk_index += 1 else: - time.sleep(0.02) + break + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) - this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - response_sender.send(inference_response) + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + raise NotImplementedError("Decoupled mode is not supported") - llm_thread.join() - response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info("send tritonserver_response_complete_final to end") - else: - raise NotImplementedError("Decoupled mode is not supported") + async def execute(self, requests): + """Execute inference on the batched requests. - if not self.decoupled: - return responses + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + tasks = [ + asyncio.create_task(self._process_request(request)) + for request in requests + ] + await asyncio.gather(*tasks) + return None + + def finalize(self): + self.logger.log_info("Finalizing CosyVoice DIT model") + if hasattr(self, "http_client"): + asyncio.run(self.http_client.aclose()) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py deleted file mode 100644 index 2f81786..0000000 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py +++ /dev/null @@ -1,435 +0,0 @@ -# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import json -import math -import os -import re -import time -from typing import Dict, List, Tuple, Optional, Union -import asyncio -import httpx - -import numpy as np -import torch -from torch.utils.dlpack import from_dlpack, to_dlpack -import triton_python_backend_utils as pb_utils -from transformers import AutoTokenizer - -import torchaudio - - -from matcha.utils.audio import mel_spectrogram -from datetime import datetime - -ORIGINAL_VOCAB_SIZE = 151663 -torch.set_num_threads(1) - - -def parse_speech_token_string(response_text: str) -> List[int]: - """ - Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. - """ - speech_tokens = response_text.strip().split('><') - if len(speech_tokens) > 1: - # Add back the missing '<' and '>' for proper parsing - speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] - speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] - - speech_ids = [] - for token_str in speech_tokens: - match = re.match(r'<\|s_(\d+)\|>', token_str) - if match: - speech_ids.append(int(match.group(1))) - return speech_ids - - -class TritonPythonModel: - """Triton Python model for Spark TTS. - - This model orchestrates the end-to-end TTS pipeline by coordinating - between audio tokenizer, LLM, and vocoder components. - """ - - def initialize(self, args): - """Initialize the model. - - Args: - args: Dictionary containing model configuration - """ - self.logger = pb_utils.Logger - # Parse model parameters - self.model_config = json.loads(args['model_config']) - parameters = self.model_config['parameters'] - model_params = {k: v["string_value"] for k, v in parameters.items()} - self.logger.log_info(f"model_params:{model_params}") - self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" - # self.dynamic_chunk_strategy = "equal" - self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") - - # Initialize tokenizer - llm_tokenizer_dir = model_params["llm_tokenizer_dir"] - self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) - self.prompt_template = "<|sos|>{input_text}<|task_id|>" - self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") - - self.device = torch.device("cuda") - self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) - - self.token_frame_rate = 25 - self.flow_pre_lookahead_len = 3 - self.token_hop_len = 15 - - spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") - if not os.path.exists(spk_info_path): - raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") - spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - self.default_spk_info = spk_info["001"] - self.http_client = httpx.AsyncClient() - - def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: - """Converts a tensor or list of speech token IDs to a string representation.""" - if isinstance(speech_tokens, torch.Tensor): - # Ensure tensor is on CPU and flattened - speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() - - speech_id_str = "" - for token_id in speech_tokens: - # Convert token ID back to the speech number N - token_num = token_id - ORIGINAL_VOCAB_SIZE - speech_id_str += f"<|s_{token_num}|>" - return speech_id_str - - async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): - """ - Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response. - """ - full_text = f"{reference_text}{target_text}" - prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) - - chat = [ - {"role": "user", "content": full_text}, - {"role": "assistant", "content": prompt_speech_tokens_str} - ] - - payload = { - "model": "trt_engines_bfloat16", - "messages": chat, - "max_tokens": 750, - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "repetition_penalty": 1.1, - "stop": ["<|eos1|>", "<|eos|>"], - "stream": True, - } - - api_base = "http://localhost:8000/v1/chat/completions" - - buffer = "" - async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: - print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") - print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") - response.raise_for_status() - async for line in response.aiter_lines(): - if line.startswith("data: "): - line_data = line[len("data: "):].strip() - if line_data == "[DONE]": - break - try: - json_data = json.loads(line_data) - content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") - if content: - buffer += content - print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") - while True: - match = re.search(r"<\|s_(\d+)\|>", buffer) - if not match: - break - - token_num = int(match.group(1)) - final_id = token_num + ORIGINAL_VOCAB_SIZE - yield final_id - buffer = buffer[match.end():] - except json.JSONDecodeError: - self.logger.log_info(f"Skipping non-JSON line: {line_data}") - continue - - # Process any remaining complete tokens in the buffer after the stream ends - while True: - match = re.search(r"<\|s_(\d+)\|>", buffer) - if not match: - break - token_num = int(match.group(1)) - final_id = token_num + ORIGINAL_VOCAB_SIZE - yield final_id - buffer = buffer[match.end():] - - - def forward_audio_tokenizer(self, wav, wav_len): - """Forward pass through the audio tokenizer component. - - Args: - wav: Input waveform tensor - wav_len: Waveform length tensor - - Returns: - Tuple of global and semantic tokens - """ - inference_request = pb_utils.InferenceRequest( - model_name='audio_tokenizer', - requested_output_names=['prompt_speech_tokens'], - inputs=[wav, wav_len] - ) - - inference_response = inference_request.exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - - # Extract and convert output tensors - prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') - prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() - - return prompt_speech_tokens - - def forward_speaker_embedding(self, wav): - """Forward pass through the speaker embedding component. - - Args: - wav: Input waveform tensor - - Returns: - Prompt speaker embedding tensor - """ - inference_request = pb_utils.InferenceRequest( - model_name='speaker_embedding', - requested_output_names=['prompt_spk_embedding'], - inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] - ) - - inference_response = inference_request.exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - - # Extract and convert output tensors - prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') - prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) - - return prompt_spk_embedding - - async def forward_token2wav( - self, - index: int, - target_speech_tokens: torch.Tensor, - request_id: str, - reference_wav: object, - reference_wav_len: object, - finalize: bool = None) -> torch.Tensor: - """Forward pass through the vocoder component. - - Args: - prompt_speech_tokens: Prompt speech tokens tensor - prompt_speech_feat: Prompt speech feat tensor - prompt_spk_embedding: Prompt spk embedding tensor - target_speech_tokens: Target speech tokens tensor - - Returns: - Generated waveform tensor - """ - target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) - finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) - inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - - # Create and execute inference request - inference_request = pb_utils.InferenceRequest( - model_name='token2wav_dit', - requested_output_names=[ - "waveform", - ], - inputs=inputs_tensor, - request_id=request_id, - parameters={"priority": index+1}, - ) - - inference_response = await inference_request.async_exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - - # Extract and convert output waveform - waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') - waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() - - return waveform - - def _extract_speech_feat(self, speech): - speech_feat = mel_spectrogram( - speech, - n_fft=1920, - num_mels=80, - sampling_rate=24000, - hop_size=480, - win_size=1920, - fmin=0, - fmax=8000).squeeze( - dim=0).transpose( - 0, - 1).to( - self.device) - speech_feat = speech_feat.unsqueeze(dim=0) - return speech_feat - - async def _process_request(self, request): - request_id = request.request_id() - # Extract input tensors - wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") - - # Process reference audio through audio tokenizer - if wav is not None: - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - - reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) - - # reference_text = self.default_spk_info["prompt_text"] - # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - # prompt_speech_feat = None - # prompt_spk_embedding = None - - else: - # using pre-cached reference text - assert False, "using pre-cached reference text is not supported" - reference_text = self.default_spk_info["prompt_text"] - prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE - prompt_speech_feat = None - prompt_spk_embedding = None - - target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() - target_text = target_text[0][0].decode('utf-8') - print(f"target_text: {target_text}, time: {datetime.now()}") - - if self.decoupled: - response_sender = request.get_response_sender() - - semantic_token_ids_arr = [] - token_offset, chunk_index = 0, 0 - start_time = time.time() - this_token_hop_len = self.token_hop_len - print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") - async for generated_ids in self.forward_llm_async( - target_text=target_text, - reference_text=reference_text, - prompt_speech_tokens=prompt_speech_tokens, - ): - if not generated_ids: - break - semantic_token_ids_arr.append(generated_ids) - print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") - while True: - pending_num = len(semantic_token_ids_arr) - token_offset - if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: - this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] - this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") - sub_tts_speech = await self.forward_token2wav( - chunk_index, - this_tts_speech_token, request_id, wav, wav_len, False - ) - print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - response_sender.send(inference_response) - - token_offset += this_token_hop_len - self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") - - if self.dynamic_chunk_strategy == "exponential": - this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) - elif self.dynamic_chunk_strategy == "equal": - this_token_hop_len = self.token_hop_len - elif self.dynamic_chunk_strategy == "time_based": - # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 - cost_time = time.time() - start_time - duration = token_offset / self.token_frame_rate - if chunk_index > 0 and cost_time > 0: - avg_chunk_processing_time = cost_time / (chunk_index + 1) - if avg_chunk_processing_time > 0: - multiples = (duration - cost_time) / avg_chunk_processing_time - self.logger.log_info(f"multiples: {multiples}") - next_pending_num = len(semantic_token_ids_arr) - token_offset - if multiples > 4: - this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len - elif multiples > 2: - this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len - else: - this_token_hop_len = self.token_hop_len - this_token_hop_len = max(self.token_hop_len, this_token_hop_len) - chunk_index += 1 - else: - break - - this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) - response_sender.send(inference_response) - - response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info("send tritonserver_response_complete_final to end") - else: - raise NotImplementedError("Decoupled mode is not supported") - - async def execute(self, requests): - """Execute inference on the batched requests. - - Args: - requests: List of inference requests - - Returns: - List of inference responses containing generated audio - """ - tasks = [ - asyncio.create_task(self._process_request(request)) - for request in requests - ] - await asyncio.gather(*tasks) - return None - - def finalize(self): - self.logger.log_info("Finalizing CosyVoice DIT model") - if hasattr(self, "http_client"): - asyncio.run(self.http_client.aclose()) diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py index 30c3b3b..d309d18 100644 --- a/runtime/triton_trtllm/offline_inference.py +++ b/runtime/triton_trtllm/offline_inference.py @@ -47,8 +47,6 @@ import requests import asyncio import httpx -from token2wav import CosyVoice2_Token2Wav - sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") try: torch.multiprocessing.set_start_method("spawn") @@ -367,7 +365,12 @@ def main(args): runner = None else: raise ValueError(f"Unsupported backend: {args.backend}") - + + if 'Step-Audio-2-mini' in args.token2wav_path: + from token2wav_dit import CosyVoice2_Token2Wav + else: + assert 'CosyVoice2-0.5B' in args.token2wav_path + from token2wav import CosyVoice2_Token2Wav token2wav_model = CosyVoice2_Token2Wav( model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank ) @@ -589,7 +592,6 @@ def main(args): t2w_prompt_audios_list, t2w_prompt_audios_sample_rate, ) - torch.cuda.synchronize() token2wav_end_time = time.time() total_token2wav_time += (token2wav_end_time - token2wav_start_time) diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index 2c19a1d..463e490 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,28 +1,33 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) export CUDA_VISIBLE_DEVICES=0 -cosyvoice_path=/workspace/CosyVoice +# cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 + export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH + stage=$1 stop_stage=$2 -N_GPUS=2 # set the number of GPUs to use - huggingface_model_local_dir=./cosyvoice2_llm model_scope_model_local_dir=./CosyVoice2-0.5B +step_audio_model_dir=./Step-Audio-2-mini + trt_dtype=bfloat16 trt_weights_dir=./trt_weights_${trt_dtype} trt_engines_dir=./trt_engines_${trt_dtype} model_repo=./model_repo_cosyvoice2_dit - -use_spk2info_cache=False +bls_instance_num=4 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + + echo "Cloning Step-Audio2-mini" + git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path + echo "Cloning CosyVoice" git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path cd $cosyvoice_path @@ -35,8 +40,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir - # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings - wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt + + echo "Step-Audio2-mini" + huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini + cd $stepaudio2_path/token2wav + wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx + wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx + cd - fi @@ -60,40 +70,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --engine_dir=$trt_engines_dir || exit 1 fi - -# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then -# echo "Creating model repository" -# rm -rf $model_repo -# mkdir -p $model_repo -# cosyvoice2_dir="cosyvoice2_dit" -# token2wav_dir="token2wav_dit" - -# cp -r ./model_repo/${cosyvoice2_dir} $model_repo -# cp -r ./model_repo/tensorrt_llm $model_repo -# cp -r ./model_repo/${token2wav_dir} $model_repo -# #if [ $use_spk2info_cache == "False" ]; then -# cp -r ./model_repo/audio_tokenizer $model_repo -# cp -r ./model_repo/speaker_embedding $model_repo -# #fi - -# ENGINE_PATH=$trt_engines_dir -# MAX_QUEUE_DELAY_MICROSECONDS=0 -# MODEL_DIR=$model_scope_model_local_dir -# LLM_TOKENIZER_DIR=$huggingface_model_local_dir -# BLS_INSTANCE_NUM=1 -# TRITON_MAX_BATCH_SIZE=16 -# DECOUPLED_MODE=True # True for streaming, False for offline -# STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav - -# python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 -# #if [ $use_spk2info_cache == "False" ]; then -# python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} -# #fi -# fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then echo "Creating model repository async mode" rm -rf $model_repo @@ -102,122 +78,75 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then token2wav_dir="token2wav_dit" cp -r ./model_repo/${cosyvoice2_dir} $model_repo - cp -r ./model_repo/tensorrt_llm $model_repo cp -r ./model_repo/${token2wav_dir} $model_repo - #if [ $use_spk2info_cache == "False" ]; then - cp -r ./model_repo/audio_tokenizer $model_repo - cp -r ./model_repo/speaker_embedding $model_repo - #fi + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/speaker_embedding $model_repo + ENGINE_PATH=$trt_engines_dir MAX_QUEUE_DELAY_MICROSECONDS=0 MODEL_DIR=$model_scope_model_local_dir LLM_TOKENIZER_DIR=$huggingface_model_local_dir - BLS_INSTANCE_NUM=4 + BLS_INSTANCE_NUM=$bls_instance_num TRITON_MAX_BATCH_SIZE=1 - DECOUPLED_MODE=True # True for streaming, False for offline - STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav + DECOUPLED_MODE=True + STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 - #if [ $use_spk2info_cache == "False" ]; then - python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - #fi - rm -rf $model_repo/tensorrt_llm - # mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4 + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - echo "Starting Triton server on $N_GPUS GPUs" - for i in $(seq 0 $(($N_GPUS - 1))); do - echo "Starting server on GPU $i" - http_port=$((19000 + $i)) - grpc_port=$((18000 + $i)) - metrics_port=$((17000 + $i)) - CUDA_VISIBLE_DEVICES=$i tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & - done - - echo "Servers are running in the background. Press Ctrl+C to stop them and the script." - wait -fi - -if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then - echo "Starting Triton server on $N_GPUS GPUs" - N_GPUS=1 - for i in $(seq 0 $(($N_GPUS - 1))); do - echo "Starting server on GPU $i" - http_port=$((19000 + $i)) - grpc_port=$((18000 + $i)) - metrics_port=$((17000 + $i)) - CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port & - done - - echo "Servers are running in the background. Press Ctrl+C to stop them and the script." + echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve" + tritonserver --model-repository $model_repo --http-port 18000 & + mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 & wait + # Test using curl + # curl http://localhost:8000/v1/chat/completions \ + # -H "Content-Type: application/json" \ + # -d '{ + # "model": "trt_engines_bfloat16", + # "messages":[{"role": "user", "content": "Where is New York?"}, + # {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], + # "max_tokens": 512, + # "temperature": 0.8, + # "top_p": 0.95, + # "top_k": 50, + # "stop": ["<|eos1|>"], + # "repetition_penalty": 1.2, + # "stream": false + # }' fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - echo "Single request test http, only work for offline TTS mode" - python3 client_http.py \ - --reference-audio ./assets/prompt_audio.wav \ - --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ - --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ - --model-name cosyvoice2 + echo "Running benchmark client" + num_task=4 + mode=streaming + BLS_INSTANCE_NUM=$bls_instance_num + + python3 client_grpc.py \ + --server-addr localhost \ + --server-port 8001 \ + --model-name cosyvoice2_dit \ + --num-tasks $num_task \ + --mode $mode \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} + fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - echo "Running benchmark client grpc on $N_GPUS GPUs" - num_task=1 + echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script" - mode=streaming - BLS_INSTANCE_NUM=4 - - for i in $(seq 0 $(($N_GPUS - 1))); do - grpc_port=$((18000 + $i)) - echo "Running client for server on localhost:$grpc_port" - python3 client_grpc.py \ - --server-addr localhost \ - --server-port $grpc_port \ - --model-name cosyvoice2_dit \ - --num-tasks $num_task \ - --mode $mode \ - --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_gpu${i} & - done - wait -fi -if [ $stage -le 50 ] && [ $stop_stage -ge 50 ]; then - echo "Running benchmark client grpc on $N_GPUS GPUs" - num_task=4 - N_GPUS=1 - mode=streaming - BLS_INSTANCE_NUM=4 - - for i in $(seq 0 $(($N_GPUS - 1))); do - grpc_port=$((18000 + $i)) - echo "Running client for server on localhost:$grpc_port" - python3 client_grpc.py \ - --server-addr localhost \ - --server-port $grpc_port \ - --model-name cosyvoice2_dit \ - --num-tasks $num_task \ - --mode $mode \ - --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_single_card_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} & - done - wait -fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - echo "stage 6: Offline inference benchmark" - n_gpus=1 datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh - backend=trtllm-serve # hf, trtllm, vllm + backend=trtllm # hf, trtllm, vllm, trtllm-serve - batch_sizes=(16 8 4 2 1) - batch_sizes=(16 8 4 2) + batch_sizes=(16) token2wav_batch_size=1 + for batch_size in ${batch_sizes[@]}; do for dataset in ${datasets[@]}; do output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} @@ -225,7 +154,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then python3 offline_inference.py \ --output-dir $output_dir \ --llm-model-name-or-path $huggingface_model_local_dir \ - --token2wav-path $model_scope_model_local_dir \ + --token2wav-path $step_audio_model_dir/token2wav \ --backend $backend \ --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ --engine-dir $trt_engines_dir \ @@ -234,34 +163,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then done fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - - CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential - - +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script" + export CUDA_VISIBLE_DEVICES=1 + # Note: Using pre-computed cosyvoice2 tokens + python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential + # Offline Token2wav inference + # python3 token2wav_dit.py --enable-trt fi -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 - -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - #! /usr/bin/env bash - curl http://localhost:8000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "trt_engines_bfloat16", - "messages":[{"role": "user", "content": "Where is New York?"}, - {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}], - "max_tokens": 512, - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "stop": ["<|eos1|>"], - "repetition_penalty": 1.2, - "stream": false - }' -fi \ No newline at end of file diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index 93c6758..026feb5 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -54,7 +54,7 @@ if __name__ == "__main__": token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) flow_pre_lookahead_len = 3 - CHUNK_SIZE = 15 + CHUNK_SIZE = 25 token_frame_rate = 25 OVERLAP_SIZE = 0 @@ -67,20 +67,12 @@ if __name__ == "__main__": ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0] - # if id != "unseen3_text5": - # continue - # else: - # a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt") - # generated_speech_tokens = a["semantic_token_ids_arr"] - # print(generated_speech_tokens) + assert prompt_audio_sample_rate == 16000 prompt_text = prompt_text_list[0] prompt_speech_tokens = prompt_speech_tokens_list[0] - - # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens) - semantic_token_ids_arr, token_offset = [], 0 flow_prompt_speech_token_len = len(prompt_speech_tokens) @@ -114,14 +106,16 @@ if __name__ == "__main__": audios = output_wavs reconstructed_audio = np.concatenate(audios) - # Save reconstructed audio sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16") - - print(f"Saved {id}") end_time = time.time() if _ == 0: token2wav_model.speaker_cache = {} - print(f"Warmup time: {end_time - start_time} seconds") - print(f"Total forward count: {total_forward_count}") + print(f"Warmup time: {end_time - start_time} seconds") + print("clear speaker cache") + elif _ == 1: + print(f"Cost time without speaker cache: {end_time - start_time} seconds") + else: + print(f"Cost time with speaker cache: {end_time - start_time} seconds") + print(f"Total flow matching forward calls: {total_forward_count}") \ No newline at end of file