diff --git a/README.md b/README.md
index 18fcb49..5e3cfd5 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,10 @@ ## Roadmap
+- [x] 2025/08
+
+    - [x] Add Triton TensorRT-LLM runtime support, contributed by Yuekai Zhang (NVIDIA)
+
 - [x] 2025/07
 
     - [x] release cosyvoice 3.0 eval set
diff --git a/runtime/triton_trtllm/Dockerfile.server b/runtime/triton_trtllm/Dockerfile.server
new file mode 100644
index 0000000..38494f4
--- /dev/null
+++ b/runtime/triton_trtllm/Dockerfile.server
@@ -0,0 +1,8 @@
+FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
+LABEL maintainer="zhangyuekai@foxmail.com"
+
+RUN apt-get update && apt-get install -y cmake
+RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
+COPY ./requirements.txt /workspace/requirements.txt
+RUN pip install -r /workspace/requirements.txt
+WORKDIR /workspace
\ No newline at end of file
diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md
new file mode 100644
index 0000000..b1e091c
--- /dev/null
+++ b/runtime/triton_trtllm/README.md
@@ -0,0 +1,91 @@
+## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server
+
+Thanks to Yuekai Zhang (NVIDIA) for contributing this runtime.
+
+### Quick Start
+Launch the service directly with Docker Compose:
+```sh
+docker compose up
+```
+
+### Build the Docker Image
+Build the image from scratch:
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
+```
+
+### Run a Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
+```
+
+### Understanding `run.sh`
+The `run.sh` script orchestrates the entire workflow through numbered stages.
+
+Run a subset of stages with:
+```sh
+bash run.sh <start_stage> <stop_stage> [service_type]
+```
+- `<start_stage>` – stage to start from (0-5).
+- `<stop_stage>` – stage to stop after (0-5).
+
+Stages:
+- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace.
+- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
+- **Stage 2** – Create the Triton model repository and configure the model files (the configuration differs depending on whether `Decoupled=True` or `Decoupled=False` will be used later).
+- **Stage 3** – Launch the Triton Inference Server.
+- **Stage 4** – Run the single-utterance HTTP client.
+- **Stage 5** – Run the gRPC benchmark client.
+
+### Export Models to TensorRT-LLM and Launch the Server
+Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
+```sh
+# Runs stages 0, 1, 2, and 3
+bash run.sh 0 3
+```
+*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
+
+### Single-Utterance HTTP Client
+Send a single HTTP inference request:
+```sh
+bash run.sh 4 4
+```
+
+### Benchmark with a Dataset
+Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.
+```sh
+bash run.sh 5 5
+
+# You can also customise parameters such as num_task and dataset split directly:
+# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
+```
+> [!TIP]
+> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready.
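+
+The following is a minimal, hypothetical sketch (not part of `run.sh`) of how a decoupled/streaming request can be consumed from Python with the synchronous `tritonclient` gRPC API. It assumes the server from stage 3 is listening on `localhost:8001`, that `prompt.wav` is a 16 kHz mono prompt recording, and that the transcript/target strings and file names are placeholders; `client_grpc.py` in this directory is the full benchmark client and additionally cross-fades overlapping chunks.
+```python
+import queue
+
+import numpy as np
+import soundfile as sf
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import np_to_triton_dtype
+
+results = queue.Queue()
+
+def callback(result, error):
+    # Triton calls this once per streamed response (or error).
+    results.put(error if error else result)
+
+wav, sr = sf.read("prompt.wav", dtype="float32")  # assumed 16 kHz mono prompt
+samples = wav.reshape(1, -1).astype(np.float32)
+lengths = np.array([[samples.shape[1]]], dtype=np.int32)
+
+inputs = [
+    grpcclient.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
+    grpcclient.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
+    grpcclient.InferInput("reference_text", [1, 1], "BYTES"),
+    grpcclient.InferInput("target_text", [1, 1], "BYTES"),
+]
+inputs[0].set_data_from_numpy(samples)
+inputs[1].set_data_from_numpy(lengths)
+inputs[2].set_data_from_numpy(np.array([["reference transcript"]], dtype=object))
+inputs[3].set_data_from_numpy(np.array([["text to synthesize"]], dtype=object))
+
+client = grpcclient.InferenceServerClient("localhost:8001")
+client.start_stream(callback=callback)
+client.async_stream_infer(
+    "cosyvoice2",
+    inputs,
+    request_id="0",
+    outputs=[grpcclient.InferRequestedOutput("waveform")],
+    enable_empty_final_response=True,
+)
+
+chunks = []
+while True:
+    result = results.get()
+    if isinstance(result, Exception):
+        raise result
+    # The empty final response only signals completion; audio arrives in earlier chunks.
+    if result.get_response().parameters["triton_final_response"].bool_param:
+        break
+    audio = result.as_numpy("waveform").reshape(-1)
+    if audio.size:
+        chunks.append(audio)
+
+client.stop_stream()
+sf.write("output.wav", np.concatenate(chunks), 24000)  # CosyVoice 2 outputs 24 kHz audio
+```
+For offline (non-decoupled) serving, `client_http.py` shows the equivalent single HTTP request.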
+ +### Benchmark Results +Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio): + +| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | +|------|------|-------------|------------------|------------------|-----| +| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 | +| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 | +| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 | +| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 | +| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 | +| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 | + +### OpenAI-Compatible Server +To launch an OpenAI-compatible service, run: +```sh +git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git +pip install -r requirements.txt +# After the Triton service is up, start the FastAPI bridge: +python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000 +# Test with curl +bash test/test_cosyvoice.sh +``` + +### Acknowledgements +This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details. + diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py new file mode 100644 index 0000000..881b519 --- /dev/null +++ b/runtime/triton_trtllm/client_grpc.py @@ -0,0 +1,834 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Nvidia (authors: Yuekai Zhang) +# 2023 Recurrent.ai (authors: Songtao Shi) +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script supports to load dataset from huggingface and sends it to the server +for decoding, in parallel. 
+ +Usage: +num_task=2 + +# For offline F5-TTS +python3 client_grpc.py \ + --server-addr localhost \ + --model-name f5_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name test_zh \ + --log-dir ./log_concurrent_tasks_${num_task} + +# For offline Spark-TTS-0.5B +python3 client_grpc.py \ + --server-addr localhost \ + --model-name spark_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name wenetspeech4tts \ + --log-dir ./log_concurrent_tasks_${num_task} +""" + +import argparse +import asyncio +import json +import queue # Added +import uuid # Added +import functools # Added + +import os +import time +import types +from pathlib import Path + +import numpy as np +import soundfile as sf +import tritonclient +import tritonclient.grpc.aio as grpcclient_aio # Renamed original import +import tritonclient.grpc as grpcclient_sync # Added sync client import +from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException + + +# --- Added UserData and callback --- +class UserData: + def __init__(self): + self._completed_requests = queue.Queue() + self._first_chunk_time = None + self._start_time = None + + def record_start_time(self): + self._start_time = time.time() + + def get_first_chunk_latency(self): + if self._first_chunk_time and self._start_time: + return self._first_chunk_time - self._start_time + return None + + +def callback(user_data, result, error): + if user_data._first_chunk_time is None and not error: + user_data._first_chunk_time = time.time() # Record time of first successful chunk + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) +# --- End Added UserData and callback --- + + +def write_triton_stats(stats, summary_file): + with open(summary_file, "w") as summary_f: + model_stats = stats["model_stats"] + # write a note, the log is from triton_client.get_inference_statistics(), to better human readability + summary_f.write( + "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n" + ) + summary_f.write("To learn more about the log, please refer to: \n") + summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") + summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") + summary_f.write( + "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" + ) + summary_f.write( + "However, there is a trade-off between the increased queue time and the increased batch size. \n" + ) + summary_f.write( + "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" + ) + summary_f.write( + "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. 
\n\n" + ) + for model_state in model_stats: + if "last_inference" not in model_state: + continue + summary_f.write(f"model name is {model_state['name']} \n") + model_inference_stats = model_state["inference_stats"] + total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 + total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 + total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 + total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 + summary_f.write( + f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa + ) + model_batch_stats = model_state["batch_stats"] + for batch in model_batch_stats: + batch_size = int(batch["batch_size"]) + compute_input = batch["compute_input"] + compute_output = batch["compute_output"] + compute_infer = batch["compute_infer"] + batch_count = int(compute_infer["count"]) + assert compute_infer["count"] == compute_output["count"] == compute_input["count"] + compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 + compute_input_time_ms = int(compute_input["ns"]) / 1e6 + compute_output_time_ms = int(compute_output["ns"]) / 1e6 + summary_f.write( + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa + ) + summary_f.write( + f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa + ) + summary_f.write( + f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa + ) + + +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=8001, + help="Grpc port of the triton server, default is 8001", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--huggingface-dataset", + type=str, + default="yuekai/seed_tts", + help="dataset name in huggingface dataset hub", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="dataset split name, default is 'test'", + ) + + parser.add_argument( + "--manifest-path", + type=str, + default=None, + help="Path to the manifest dir which includes wav.scp trans.txt files.", + ) + + parser.add_argument( + "--model-name", + type=str, + default="f5_tts", + choices=[ + "f5_tts", + "spark_tts", + "cosyvoice2"], + help="triton model_repo module name to request", + ) + + parser.add_argument( + "--num-tasks", + type=int, + default=1, + help="Number of concurrent tasks for sending", + ) + + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="Controls how frequently we print the log.", + ) + + parser.add_argument( + "--compute-wer", + action="store_true", + default=False, + help="""True to compute WER. + """, + ) + + parser.add_argument( + "--log-dir", + type=str, + required=False, + default="./tmp", + help="log directory", + ) + + # --- Added arguments --- + parser.add_argument( + "--mode", + type=str, + default="offline", + choices=["offline", "streaming"], + help="Select offline or streaming benchmark mode." + ) + parser.add_argument( + "--chunk-overlap-duration", + type=float, + default=0.1, + help="Chunk overlap duration for streaming reconstruction (in seconds)." + ) + # --- End Added arguments --- + + return parser.parse_args() + + +def load_audio(wav_path, target_sample_rate=16000): + assert target_sample_rate == 16000, "hard coding in server" + if isinstance(wav_path, dict): + waveform = wav_path["array"] + sample_rate = wav_path["sampling_rate"] + else: + waveform, sample_rate = sf.read(wav_path) + if sample_rate != target_sample_rate: + from scipy.signal import resample + + num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) + waveform = resample(waveform, num_samples) + return waveform, target_sample_rate + + +def prepare_request_input_output( + protocol_client, # Can be grpcclient_aio or grpcclient_sync + waveform, + reference_text, + target_text, + sample_rate=16000, + padding_duration: int = None # Optional padding for offline mode +): + """Prepares inputs for Triton inference (offline or streaming).""" + assert len(waveform.shape) == 1, "waveform should be 1D" + lengths = np.array([[len(waveform)]], dtype=np.int32) + + # Apply padding only if padding_duration is provided (for offline) + if padding_duration: + duration = len(waveform) / sample_rate + # Estimate target duration based on text length ratio (crude estimation) + # Avoid division by zero if reference_text is empty + if reference_text: + estimated_target_duration = duration / len(reference_text) * len(target_text) + else: + estimated_target_duration = duration # Assume target duration similar to reference if no text + + # Calculate required samples based on estimated total duration + required_total_samples = padding_duration * sample_rate * ( + (int(estimated_target_duration + duration) // padding_duration) + 1 + ) + samples = np.zeros((1, required_total_samples), dtype=np.float32) + samples[0, : len(waveform)] = waveform + else: + # 
No padding for streaming or if padding_duration is None + samples = waveform.reshape(1, -1).astype(np.float32) + + # Common input creation logic + inputs = [ + protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), + protocol_client.InferInput( + "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) + ), + protocol_client.InferInput("reference_text", [1, 1], "BYTES"), + protocol_client.InferInput("target_text", [1, 1], "BYTES"), + ] + inputs[0].set_data_from_numpy(samples) + inputs[1].set_data_from_numpy(lengths) + + input_data_numpy = np.array([reference_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[2].set_data_from_numpy(input_data_numpy) + + input_data_numpy = np.array([target_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[3].set_data_from_numpy(input_data_numpy) + + outputs = [protocol_client.InferRequestedOutput("waveform")] + + return inputs, outputs + + +def run_sync_streaming_inference( + sync_triton_client: tritonclient.grpc.InferenceServerClient, + model_name: str, + inputs: list, + outputs: list, + request_id: str, + user_data: UserData, + chunk_overlap_duration: float, + save_sample_rate: int, + audio_save_path: str, +): + """Helper function to run the blocking sync streaming call.""" + start_time_total = time.time() + user_data.record_start_time() # Record start time for first chunk latency calculation + + # Establish stream + sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) + + # Send request + sync_triton_client.async_stream_infer( + model_name, + inputs, + request_id=request_id, + outputs=outputs, + enable_empty_final_response=True, + ) + + # Process results + audios = [] + while True: + try: + result = user_data._completed_requests.get() # Add timeout + if isinstance(result, InferenceServerException): + print(f"Received InferenceServerException: {result}") + sync_triton_client.stop_stream() + return None, None, None # Indicate error + # Get response metadata + response = result.get_response() + final = response.parameters["triton_final_response"].bool_param + if final is True: + break + + audio_chunk = result.as_numpy("waveform").reshape(-1) + if audio_chunk.size > 0: # Only append non-empty chunks + audios.append(audio_chunk) + else: + print("Warning: received empty audio chunk.") + + except queue.Empty: + print(f"Timeout waiting for response for request id {request_id}") + sync_triton_client.stop_stream() + return None, None, None # Indicate error + + sync_triton_client.stop_stream() + end_time_total = time.time() + total_request_latency = end_time_total - start_time_total + first_chunk_latency = user_data.get_first_chunk_latency() + + # Reconstruct audio using cross-fade (from client_grpc_streaming.py) + actual_duration = 0 + if audios: + cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + reconstructed_audio = None + + # Simplified reconstruction based on client_grpc_streaming.py + if not audios: + print("Warning: No audio chunks received.") + reconstructed_audio = np.array([], dtype=np.float32) # Empty array + elif len(audios) == 1: + reconstructed_audio = audios[0] + else: + reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + for i in range(1, len(audios)): + # Cross-fade section + cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + 
audios[i - 1][-cross_fade_samples:] * fade_out) + # Middle section of the current chunk + middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # Concatenate + reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # Add the last part of the final chunk + reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) + + if reconstructed_audio is not None and reconstructed_audio.size > 0: + actual_duration = len(reconstructed_audio) / save_sample_rate + # Save reconstructed audio + os.makedirs(os.path.dirname(audio_save_path), exist_ok=True) + sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") + else: + print("Warning: No audio chunks received or reconstructed.") + actual_duration = 0 # Set duration to 0 if no audio + + else: + print("Warning: No audio chunks received.") + actual_duration = 0 + + return total_request_latency, first_chunk_latency, actual_duration + + +async def send_streaming( + manifest_item_list: list, + name: str, + server_url: str, # Changed from sync_triton_client + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + audio_save_dir: str = "./", + save_sample_rate: int = 16000, + chunk_overlap_duration: float = 0.1, + padding_duration: int = None, +): + total_duration = 0.0 + latency_data = [] + task_id = int(name[5:]) + sync_triton_client = None # Initialize client variable + + try: # Wrap in try...finally to ensure client closing + print(f"{name}: Initializing sync client for streaming...") + sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + + print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: Processing item {i}/{len(manifest_item_list)}") + + try: + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + reference_text, target_text = item["reference_text"], item["target_text"] + + inputs, outputs = prepare_request_input_output( + protocol_client, + waveform, + reference_text, + target_text, + sample_rate, + padding_duration=padding_duration + ) + request_id = str(uuid.uuid4()) + user_data = UserData() + + audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") + + total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread( + run_sync_streaming_inference, + sync_triton_client, + model_name, + inputs, + outputs, + request_id, + user_data, + chunk_overlap_duration, + save_sample_rate, + audio_save_path + ) + + if total_request_latency is not None: + print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s") + latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) + total_duration += actual_duration + else: + print(f"{name}: Item {i} failed.") + + except FileNotFoundError: + print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") + except Exception as e: + print(f"Error processing item {i} ({item['target_audio_path']}): {e}") + import traceback + traceback.print_exc() + + finally: # Ensure client is closed + if sync_triton_client: + try: + print(f"{name}: Closing sync client...") + sync_triton_client.close() + except Exception as e: + print(f"{name}: Error closing sync client: {e}") + + print(f"{name}: Finished streaming 
processing. Total duration synthesized: {total_duration:.4f}s") + return total_duration, latency_data + + +async def send( + manifest_item_list: list, + name: str, + triton_client: tritonclient.grpc.aio.InferenceServerClient, + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + padding_duration: int = None, + audio_save_dir: str = "./", + save_sample_rate: int = 16000, +): + total_duration = 0.0 + latency_data = [] + task_id = int(name[5:]) + + print(f"manifest_item_list: {manifest_item_list}") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: {i}/{len(manifest_item_list)}") + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + reference_text, target_text = item["reference_text"], item["target_text"] + + inputs, outputs = prepare_request_input_output( + protocol_client, + waveform, + reference_text, + target_text, + sample_rate, + padding_duration=padding_duration + ) + sequence_id = 100000000 + i + task_id * 10 + start = time.time() + response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) + + audio = response.as_numpy("waveform").reshape(-1) + actual_duration = len(audio) / save_sample_rate + + end = time.time() - start + + audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") + sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") + + latency_data.append((end, actual_duration)) + total_duration += actual_duration + + return total_duration, latency_data + + +def load_manifests(manifest_path): + with open(manifest_path, "r") as f: + manifest_list = [] + for line in f: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) + manifest_list.append( + { + "audio_filepath": prompt_wav, + "reference_text": prompt_text, + "target_text": gt_text, + "target_audio_path": utt, + } + ) + return manifest_list + + +def split_data(data, k): + n = len(data) + if n < k: + print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.") + k = n + + quotient = n // k + remainder = n % k + + result = [] + start = 0 + for i in range(k): + if i < remainder: + end = start + quotient + 1 + else: + end = start + quotient + + result.append(data[start:end]) + start = end + + return result + + +async def main(): + args = get_args() + url = f"{args.server_addr}:{args.server_port}" + + # --- Client Initialization based on mode --- + triton_client = None + protocol_client = None + if args.mode == "offline": + print("Initializing gRPC client for offline mode...") + # Use the async client for offline tasks + triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) + protocol_client = grpcclient_aio + elif args.mode == "streaming": + print("Initializing gRPC client for streaming mode...") + # Use the sync client for streaming tasks, handled via asyncio.to_thread + # We will create one sync client instance PER TASK inside send_streaming. 
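+        # (A tritonclient.grpc client supports only one active stream at a time, so sharing a single
+        # sync client across concurrent streaming tasks would not work.)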
+ # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now + protocol_client = grpcclient_sync # protocol client for input prep + else: + raise ValueError(f"Invalid mode: {args.mode}") + # --- End Client Initialization --- + + if args.reference_audio: + args.num_tasks = 1 + args.log_interval = 1 + manifest_item_list = [ + { + "reference_text": args.reference_text, + "target_text": args.target_text, + "audio_filepath": args.reference_audio, + "target_audio_path": "test", + } + ] + elif args.huggingface_dataset: + import datasets + + dataset = datasets.load_dataset( + args.huggingface_dataset, + split=args.split_name, + trust_remote_code=True, + ) + manifest_item_list = [] + for i in range(len(dataset)): + manifest_item_list.append( + { + "audio_filepath": dataset[i]["prompt_audio"], + "reference_text": dataset[i]["prompt_text"], + "target_audio_path": dataset[i]["id"], + "target_text": dataset[i]["target_text"], + } + ) + else: + manifest_item_list = load_manifests(args.manifest_path) + + num_tasks = min(args.num_tasks, len(manifest_item_list)) + manifest_item_list = split_data(manifest_item_list, num_tasks) + + os.makedirs(args.log_dir, exist_ok=True) + tasks = [] + start_time = time.time() + for i in range(num_tasks): + # --- Task Creation based on mode --- + if args.mode == "offline": + task = asyncio.create_task( + send( + manifest_item_list[i], + name=f"task-{i}", + triton_client=triton_client, + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=1, + save_sample_rate=16000 if args.model_name == "spark_tts" else 24000, + ) + ) + elif args.mode == "streaming": + task = asyncio.create_task( + send_streaming( + manifest_item_list[i], + name=f"task-{i}", + server_url=url, # Pass URL instead of client + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=10, + save_sample_rate=16000 if args.model_name == "spark_tts" else 24000, + chunk_overlap_duration=args.chunk_overlap_duration, + ) + ) + # --- End Task Creation --- + tasks.append(task) + + ans_list = await asyncio.gather(*tasks) + + end_time = time.time() + elapsed = end_time - start_time + + total_duration = 0.0 + latency_data = [] + for ans in ans_list: + if ans: + total_duration += ans[0] + latency_data.extend(ans[1]) # Use extend for list of lists + else: + print("Warning: A task returned None, possibly due to an error.") + + if total_duration == 0: + print("Total synthesized duration is zero. 
Cannot calculate RTF or latency percentiles.") + rtf = float('inf') + else: + rtf = elapsed / total_duration + + s = f"Mode: {args.mode}\n" + s += f"RTF: {rtf:.4f}\n" + s += f"total_duration: {total_duration:.3f} seconds\n" + s += f"({total_duration / 3600:.2f} hours)\n" + s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" + + # --- Statistics Reporting based on mode --- + if latency_data: + if args.mode == "offline": + # Original offline latency calculation + latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] + if latency_list: + latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 + latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 + s += f"latency_variance: {latency_variance:.2f}\n" + s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" + s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" + s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" + s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" + s += f"average_latency_ms: {latency_ms:.2f}\n" + else: + s += "No latency data collected for offline mode.\n" + + elif args.mode == "streaming": + # Calculate stats for total request latency and first chunk latency + total_latency_list = [total for (total, first, duration) in latency_data if total is not None] + first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None] + + s += "\n--- Total Request Latency ---\n" + if total_latency_list: + avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0 + variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0 + s += f"total_request_latency_variance: {variance_total_latency:.2f}\n" + s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n" + s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n" + s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n" + s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n" + s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n" + else: + s += "No total request latency data collected.\n" + + s += "\n--- First Chunk Latency ---\n" + if first_chunk_latency_list: + avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0 + variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0 + s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n" + s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n" + s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n" + s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n" + s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n" + s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" + else: + s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" + else: + s += "No latency data collected.\n" + # --- End Statistics Reporting --- + + print(s) + if args.manifest_path: + name = Path(args.manifest_path).stem + elif 
args.split_name: + name = args.split_name + elif args.reference_audio: + name = Path(args.reference_audio).stem + else: + name = "results" # Default name if no manifest/split/audio provided + with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: + f.write(s) + + # --- Statistics Fetching using temporary Async Client --- + # Use a separate async client for fetching stats regardless of mode + stats_client = None + try: + print("Initializing temporary async client for fetching stats...") + stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) + print("Fetching inference statistics...") + # Fetching for all models, filtering might be needed depending on server setup + stats = await stats_client.get_inference_statistics(model_name="", as_json=True) + print("Fetching model config...") + metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) + + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) + + except Exception as e: + print(f"Could not retrieve statistics or config: {e}") + finally: + if stats_client: + try: + print("Closing temporary async stats client...") + await stats_client.close() + except Exception as e: + print(f"Error closing async stats client: {e}") + # --- End Statistics Fetching --- + + +if __name__ == "__main__": + # asyncio.run(main()) # Use TaskGroup for better exception handling if needed + async def run_main(): + try: + await main() + except Exception as e: + print(f"An error occurred in main: {e}") + import traceback + traceback.print_exc() + + asyncio.run(run_main()) diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py new file mode 100644 index 0000000..4d73e0b --- /dev/null +++ b/runtime/triton_trtllm/client_http.py @@ -0,0 +1,173 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
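+
+# Minimal single-utterance HTTP client: reads a 16 kHz reference wav, posts it together with the
+# reference/target text to Triton's /v2/models/<model_name>/infer endpoint, and writes the returned
+# waveform to --output-audio.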
+import requests +import soundfile as sf +import json +import numpy as np +import argparse + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-url", + type=str, + default="localhost:8000", + help="Address of the server", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default="../../example/prompt_audio.wav", + help="Path to a single audio file. It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。", + help="", + ) + + parser.add_argument( + "--model-name", + type=str, + default="spark_tts", + choices=[ + "f5_tts", + "spark_tts", + "cosyvoice2"], + help="triton model_repo module name to request", + ) + + parser.add_argument( + "--output-audio", + type=str, + default="output.wav", + help="Path to save the output audio", + ) + return parser.parse_args() + + +def prepare_request( + waveform, + reference_text, + target_text, + sample_rate=16000, + padding_duration: int = None, + audio_save_dir: str = "./", +): + assert len(waveform.shape) == 1, "waveform should be 1D" + lengths = np.array([[len(waveform)]], dtype=np.int32) + if padding_duration: + # padding to nearset 10 seconds + samples = np.zeros( + ( + 1, + padding_duration + * sample_rate + * ((int(len(waveform) / sample_rate) // padding_duration) + 1), + ), + dtype=np.float32, + ) + + samples[0, : len(waveform)] = waveform + else: + samples = waveform + + samples = samples.reshape(1, -1).astype(np.float32) + + data = { + "inputs": [ + { + "name": "reference_wav", + "shape": samples.shape, + "datatype": "FP32", + "data": samples.tolist() + }, + { + "name": "reference_wav_len", + "shape": lengths.shape, + "datatype": "INT32", + "data": lengths.tolist(), + }, + { + "name": "reference_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [reference_text] + }, + { + "name": "target_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [target_text] + } + ] + } + + return data + + +if __name__ == "__main__": + args = get_args() + server_url = args.server_url + if not server_url.startswith(("http://", "https://")): + server_url = f"http://{server_url}" + + url = f"{server_url}/v2/models/{args.model_name}/infer" + waveform, sr = sf.read(args.reference_audio) + assert sr == 16000, "sample rate hardcoded in server" + + samples = np.array(waveform, dtype=np.float32) + data = prepare_request(samples, args.reference_text, args.target_text) + + rsp = requests.post( + url, + headers={"Content-Type": "application/json"}, + json=data, + verify=False, + params={"request_id": '0'} + ) + result = rsp.json() + audio = result["outputs"][0]["data"] + audio = np.array(audio, dtype=np.float32) + if args.model_name == "spark_tts": + sample_rate = 16000 + else: + sample_rate = 24000 + sf.write(args.output_audio, audio, sample_rate, "PCM_16") diff --git a/runtime/triton_trtllm/docker-compose.yml b/runtime/triton_trtllm/docker-compose.yml new file mode 100644 index 0000000..e221e56 --- /dev/null +++ b/runtime/triton_trtllm/docker-compose.yml @@ -0,0 +1,20 @@ +services: + tts: + image: soar97/triton-cosyvoice:25.06 + shm_size: '1gb' + ports: + - "8000:8000" + - "8001:8001" + - "8002:8002" + environment: + - PYTHONIOENCODING=utf-8 + - MODEL_ID=${MODEL_ID} + deploy: + resources: + 
reservations: + devices: + - driver: nvidia + device_ids: ['0'] + capabilities: [gpu] + command: > + /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3" \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py new file mode 100644 index 0000000..47383e2 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -0,0 +1,97 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +import os +import numpy as np +import s3tokenizer + +ORIGINAL_VOCAB_SIZE = 151663 + + +class TritonPythonModel: + """Triton Python model for audio tokenization. + + This model takes reference audio input and extracts semantic tokens + using s3tokenizer. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + + self.device = torch.device("cuda") + model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx") + self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device) + + def execute(self, requests): + """Execute inference on the batched requests. 
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing tokenized outputs + """ + mels = [] + + # Process each request in batch + for request in requests: + # Extract input tensors + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() + + wav_array = torch.from_numpy(wav_array).to(self.device) + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + mels.append(s3tokenizer.log_mel_spectrogram(wav)) + + mels, mels_lens = s3tokenizer.padding(mels) + codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device)) + codes = codes.clone() + ORIGINAL_VOCAB_SIZE + + responses = [] + for i in range(len(requests)): + prompt_speech_tokens = codes[i, :codes_lens[i].item()] + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack( + "prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + inference_response = pb_utils.InferenceResponse( + output_tensors=[prompt_speech_tokens_tensor]) + responses.append(inference_response) + + return responses diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt b/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt new file mode 100644 index 0000000..6d8bd0c --- /dev/null +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "audio_tokenizer" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + } +] +output [ + { + name: "prompt_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py new file mode 100644 index 0000000..77a440b --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -0,0 +1,346 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import math +import os +import re +from typing import Dict, List, Tuple, Optional, Union + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer +import torchaudio.compliance.kaldi as kaldi +import torchaudio +import onnxruntime + + +from matcha.utils.audio import mel_spectrogram + + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + self.logger = pb_utils.Logger + # Parse model parameters + self.model_config = json.loads(args['model_config']) + parameters = self.model_config['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + self.logger.log_info(f"model_params:{model_params}") + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.prompt_template = "<|sos|>{input_text}<|task_id|>" + self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") + + self.device = torch.device("cuda") + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) + + campplus_model = f'{model_params["model_dir"]}/campplus.onnx' + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) + + def forward_llm(self, input_ids): + """ + Prepares the response from the language model based on the provided + inputs. Creates a `pb_utils.InferenceRequest` object with passed + `llm_request_inputs` to send to a decoupled TensorRTLLM model. + For each response from the language model: + - Checks for errors and raise an exception if any are found. + - Extracts the "output_ids" tensor from the response. + - Determines the finish reason based on the presence of the + end-of-sequence token or reaching the maximum length. + - Appends the generated token IDs to `output_ids`. + - If the finish reason is determined, decodes the output IDs to text + and prepares the final response. 
+ + The final response includes the generated text, finish reason, + completion tokens, prompt tokens, and total tokens. + + Parameters + ---------- + - llm_request_inputs (dict): A dictionary containing the inputs for the language model. + + Returns + ------- + - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. + """ + # convert input_ids to numpy, with shape [1, sequence_length] + input_ids = input_ids.cpu().numpy() + max_tokens = 1024 + input_dict = { + "request_output_len": np.array([[max_tokens]], dtype=np.int32), + "end_id": np.array([[self.eos_token_id]], dtype=np.int32), + "pad_id": np.array([[self.eos_token_id]], dtype=np.int32), + "streaming": np.array([[self.decoupled]], dtype=np.bool_), + "runtime_top_p": np.array([[0.95]], dtype=np.float32), + "runtime_top_k": np.array([[50]], dtype=np.int32), + "temperature": np.array([[0.8]], dtype=np.float32), + "input_ids": input_ids, + "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), + } + + # Convert inputs to Triton tensors + input_tensor_list = [ + pb_utils.Tensor(k, v) for k, v in input_dict.items() + ] + + # Create and execute inference request + llm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + requested_output_names=["output_ids", "sequence_length"], + inputs=input_tensor_list, + ) + + llm_responses = llm_request.exec(decoupled=self.decoupled) + if self.decoupled: + for llm_response in llm_responses: + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + else: + llm_response = llm_responses + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. + + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['prompt_speech_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') + prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() + + return prompt_speech_tokens + + def forward_token2wav( + self, + prompt_speech_tokens: torch.Tensor, + prompt_speech_feat: torch.Tensor, + prompt_spk_embedding: torch.Tensor, + target_speech_tokens: torch.Tensor) -> torch.Tensor: + """Forward pass through the vocoder component. 
+ + Args: + prompt_speech_tokens: Prompt speech tokens tensor + prompt_speech_feat: Prompt speech feat tensor + prompt_spk_embedding: Prompt spk embedding tensor + target_speech_tokens: Target speech tokens tensor + + Returns: + Generated waveform tensor + """ + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) + prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) + target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='token2wav', + requested_output_names=['waveform'], + inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def parse_input(self, text, prompt_text, prompt_speech_tokens): + total_text = f"{prompt_text}{text}" + prompt = self.prompt_template.format(input_text=total_text) + input_ids = self.tokenizer.encode(prompt) + input_ids = torch.tensor([input_ids], dtype=torch.int32) + input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) + return input_ids + + def _extract_spk_embedding(self, speech): + feat = kaldi.fbank(speech, + num_mel_bins=80, + dither=0, + sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = self.campplus_session.run(None, + {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() + embedding = torch.tensor([embedding]).to(self.device).half() + return embedding + + def _extract_speech_feat(self, speech): + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 1).to( + self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + return speech_feat + + def execute(self, requests): + """Execute inference on the batched requests. 
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + responses = [] + + for request in requests: + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + + # Process reference audio through audio tokenizer + + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + # Prepare prompt for LLM + input_ids = self.parse_input( + text=target_text, + prompt_text=reference_text, + prompt_speech_tokens=prompt_speech_tokens, + ) + + # Generate semantic tokens with LLM + generated_ids_iter = self.forward_llm(input_ids) + + if self.decoupled: + response_sender = request.get_response_sender() + request_id = request.request_id() + generated_ids = [] + for generated_id in generated_ids_iter: + # convert the numpy array into a int32 tensor + generated_id = generated_id.tolist() + if len(generated_id) > 0: + assert len(generated_id) == 1, "Generated ID is not a single integer" + generated_ids.append(generated_id[0]) + generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device) + prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) + audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + generated_ids = next(generated_ids_iter) + generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) + if generated_ids is None or len(generated_ids) == 0: + raise pb_utils.TritonModelException("Generated IDs is None or empty") + + prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) + audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + responses.append(inference_response) + + if not self.decoupled: + return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt new file mode 100644 index 0000000..c370336 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt @@ 
-0,0 +1,70 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "cosyvoice2" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +model_transaction_policy { + decoupled: ${decoupled_mode} +} +parameters [ + { + key: "llm_tokenizer_dir", + value: {string_value:"${llm_tokenizer_dir}"} + }, + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: ${bls_instance_num} + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep b/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt b/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt new file mode 100644 index 0000000..3da7894 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt @@ -0,0 +1,857 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
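+# NOTE: this appears to be the standard TensorRT-LLM backend config template; the
+# ${...} placeholders below (triton_backend, triton_max_batch_size, decoupled_mode,
+# engine_dir, the kv-cache settings, ...) are substituted by scripts/fill_template.py
+# during stage 2 of run.sh. Abridged example of that substitution (see run.sh for the
+# full argument list):
+#
+#   python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt \
+#       triton_backend:tensorrtllm,triton_max_batch_size:16,decoupled_mode:False,...
+#
+# Only those substituted values normally need adjusting for a CosyVoice2 deployment;
+# the remaining inputs/outputs follow the generic TensorRT-LLM backend contract.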
+ +name: "tensorrt_llm" +backend: "${triton_backend}" +max_batch_size: ${triton_max_batch_size} + +model_transaction_policy { + decoupled: ${decoupled_mode} +} + +dynamic_batching { + preferred_batch_size: [ ${triton_max_batch_size} ] + max_queue_delay_microseconds: ${max_queue_delay_microseconds} + default_queue_policy: { max_queue_size: ${max_queue_size} } +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + allow_ragged_batch: true + optional: true + }, + { + name: "encoder_input_features" + data_type: ${encoder_input_features_data_type} + dims: [ -1, -1 ] + allow_ragged_batch: true + optional: true + }, + { + name: "encoder_output_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "num_return_sequences" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "draft_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "decoder_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "decoder_input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + reshape: { shape: [ ] } + }, + { + name: "draft_logits" + data_type: ${logits_datatype} + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "draft_acceptance_threshold" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "bad_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "embedding_bias" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_min" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_decay" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_reset_ids" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "early_stopping" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: 
"beam_search_diversity_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_perf_metrics" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "exclude_input_in_output" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_table_extra_ids" + data_type: TYPE_UINT64 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]` + { + name: "cross_attention_mask" + data_type: TYPE_BOOL + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # Mrope param when mrope is used + { + name: "mrope_rotary_cos_sin" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + }, + { + name: "mrope_position_deltas" + data_type: TYPE_INT64 + dims: [ 1 ] + optional: true + }, + # the unique task ID for the given LoRA. + # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given. + # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. + # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached. + { + name: "lora_task_id" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] + # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer + # each of the in / out tensors are first flattened and then concatenated together in the format above. + # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out. 
+ { + name: "lora_weights" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # module identifier (same size a first dimension of lora_weights) + # See LoraModule::ModuleType for model id mapping + # + # "attn_qkv": 0 # compbined qkv adapter + # "attn_q": 1 # q adapter + # "attn_k": 2 # k adapter + # "attn_v": 3 # v adapter + # "attn_dense": 4 # adapter for the dense layer in attention + # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection + # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection + # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate + # + # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ] + { + name: "lora_config" + data_type: TYPE_INT32 + dims: [ -1, 3 ] + optional: true + allow_ragged_batch: true + }, + { + name: "context_phase_params" + data_type: TYPE_UINT8 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama + { + name: "skip_cross_attn_blocks" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_starts" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_ends" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_priorities" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_durations_ms" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_decode_priority" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_decode_duration_ms" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "guided_decoding_guide_type" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "guided_decoding_guide" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_window_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_ngram_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_verification_set_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "sequence_length" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: ${logits_datatype} + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: ${logits_datatype} + dims: [ -1, -1, -1 ] + }, + { + name: "batch_index" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "sequence_index" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "context_phase_params" + data_type: TYPE_UINT8 + dims: [ -1 ] + }, + { + name: "kv_cache_alloc_new_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "kv_cache_reused_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: 
"kv_cache_alloc_total_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "${max_beam_width}" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "${batching_strategy}" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "${engine_dir}" + } +} +parameters: { + key: "encoder_model_path" + value: { + string_value: "${encoder_engine_dir}" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "max_attention_window_size" + value: { + string_value: "${max_attention_window_size}" + } +} +parameters: { + key: "sink_token_length" + value: { + string_value: "${sink_token_length}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "${batch_scheduler_policy}" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "${kv_cache_free_gpu_mem_fraction}" + } +} +parameters: { + key: "cross_kv_cache_fraction" + value: { + string_value: "${cross_kv_cache_fraction}" + } +} +parameters: { + key: "kv_cache_host_memory_bytes" + value: { + string_value: "${kv_cache_host_memory_bytes}" + } +} +# kv_cache_onboard_blocks is for internal implementation. 
+parameters: { + key: "kv_cache_onboard_blocks" + value: { + string_value: "${kv_cache_onboard_blocks}" + } +} +# enable_trt_overlap is deprecated and doesn't have any effect on the runtime +# parameters: { +# key: "enable_trt_overlap" +# value: { +# string_value: "${enable_trt_overlap}" +# } +# } +parameters: { + key: "exclude_input_in_output" + value: { + string_value: "${exclude_input_in_output}" + } +} +parameters: { + key: "cancellation_check_period_ms" + value: { + string_value: "${cancellation_check_period_ms}" + } +} +parameters: { + key: "stats_check_period_ms" + value: { + string_value: "${stats_check_period_ms}" + } +} +parameters: { + key: "iter_stats_max_iterations" + value: { + string_value: "${iter_stats_max_iterations}" + } +} +parameters: { + key: "request_stats_max_iterations" + value: { + string_value: "${request_stats_max_iterations}" + } +} +parameters: { + key: "enable_kv_cache_reuse" + value: { + string_value: "${enable_kv_cache_reuse}" + } +} +parameters: { + key: "normalize_log_probs" + value: { + string_value: "${normalize_log_probs}" + } +} +parameters: { + key: "enable_chunked_context" + value: { + string_value: "${enable_chunked_context}" + } +} +parameters: { + key: "gpu_device_ids" + value: { + string_value: "${gpu_device_ids}" + } +} +parameters: { + key: "participant_ids" + value: { + string_value: "${participant_ids}" + } +} +parameters: { + key: "lora_cache_optimal_adapter_size" + value: { + string_value: "${lora_cache_optimal_adapter_size}" + } +} +parameters: { + key: "lora_cache_max_adapter_size" + value: { + string_value: "${lora_cache_max_adapter_size}" + } +} +parameters: { + key: "lora_cache_gpu_memory_fraction" + value: { + string_value: "${lora_cache_gpu_memory_fraction}" + } +} +parameters: { + key: "lora_cache_host_memory_bytes" + value: { + string_value: "${lora_cache_host_memory_bytes}" + } +} +parameters: { + key: "lora_prefetch_dir" + value: { + string_value: "${lora_prefetch_dir}" + } +} +parameters: { + key: "decoding_mode" + value: { + string_value: "${decoding_mode}" + } +} +parameters: { + key: "executor_worker_path" + value: { + string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker" + } +} +parameters: { + key: "lookahead_window_size" + value: { + string_value: "${lookahead_window_size}" + } +} +parameters: { + key: "lookahead_ngram_size" + value: { + string_value: "${lookahead_ngram_size}" + } +} +parameters: { + key: "lookahead_verification_set_size" + value: { + string_value: "${lookahead_verification_set_size}" + } +} +parameters: { + key: "medusa_choices" + value: { + string_value: "${medusa_choices}" + } +} +parameters: { + key: "eagle_choices" + value: { + string_value: "${eagle_choices}" + } +} +parameters: { + key: "gpu_weights_percent" + value: { + string_value: "${gpu_weights_percent}" + } +} +parameters: { + key: "enable_context_fmha_fp32_acc" + value: { + string_value: "${enable_context_fmha_fp32_acc}" + } +} +parameters: { + key: "multi_block_mode" + value: { + string_value: "${multi_block_mode}" + } +} +parameters: { + key: "cuda_graph_mode" + value: { + string_value: "${cuda_graph_mode}" + } +} +parameters: { + key: "cuda_graph_cache_size" + value: { + string_value: "${cuda_graph_cache_size}" + } +} +parameters: { + key: "speculative_decoding_fast_logits" + value: { + string_value: "${speculative_decoding_fast_logits}" + } +} +parameters: { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} +parameters: { + key: "guided_decoding_backend" + value: { + string_value: 
"${guided_decoding_backend}" + } +} +parameters: { + key: "xgrammar_tokenizer_info_path" + value: { + string_value: "${xgrammar_tokenizer_info_path}" + } +} diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py new file mode 100644 index 0000000..d38f8a4 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -0,0 +1,195 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import os + +import logging +from typing import List, Dict + +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +from hyperpyyaml import load_hyperpyyaml +from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm +from cosyvoice.utils.common import TrtContextWrapper + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +ORIGINAL_VOCAB_SIZE = 151663 + + +class CosyVoice2: + + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): + + self.model_dir = model_dir + self.fp16 = fp16 + + hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: + configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) + self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16) + self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + if load_jit: + self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) + if load_trt: + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + trt_concurrent, + self.fp16) + + +class CosyVoice2Model: + + def __init__(self, + flow: torch.nn.Module, + hift: torch.nn.Module, + fp16: bool = False): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.flow = flow + self.hift = hift + self.fp16 = fp16 + if self.fp16 is True: + self.flow.half() + + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder + + def load(self, flow_model, hift_model): + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) + self.flow.to(self.device).eval() + # in case hift_model is a hifigan model + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+
+class TritonPythonModel:
+    """Triton Python model for the CosyVoice2 token2wav stage.
+
+    This model takes prompt and target speech tokens as input and generates audio
+    waveforms using the CosyVoice2 flow-matching decoder and HiFT vocoder.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {key: value["string_value"] for key, value in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        # Initialize device and vocoder
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
+
+        self.token2wav_model = CosyVoice2(
+            model_dir, load_jit=True, load_trt=True, fp16=True
+        )
+
+        logger.info("Token2Wav initialized successfully")
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated waveforms + """ + responses = [] + # Process each request in batch + for request in requests: + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() + prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy() + prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() + prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() + + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) + prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) + prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) + prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) + + # shift the speech tokens according to the original vocab size + prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE + target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + + tts_mel, _ = self.token2wav_model.model.flow.inference( + token=target_speech_tokens, + token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_token=prompt_speech_tokens, + prompt_token_len=torch.tensor( + [prompt_speech_tokens.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_speech_feat, + prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=prompt_spk_embedding, + streaming=False, + finalize=True, + ) + + audio_hat, _ = self.token2wav_model.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + generated_wave = audio_hat.squeeze(0).cpu().numpy() + + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) + inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) + responses.append(inference_response) + + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt new file mode 100644 index 0000000..36489ff --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
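+# The inputs declared below mirror the tensors produced by the cosyvoice2 BLS model:
+# INT32 target/prompt speech tokens, FP16 prompt mel features with 80 bins, and the
+# FP16 speaker embedding. As with the other configs, ${model_dir} and the batching
+# placeholders are filled by scripts/fill_template.py in run.sh stage 2.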
+ +name: "token2wav" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "target_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "prompt_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "prompt_speech_feat" + data_type: TYPE_FP16 + dims: [-1, 80] + }, + { + name: "prompt_spk_embedding" + data_type: TYPE_FP16 + dims: [-1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/requirements.txt b/runtime/triton_trtllm/requirements.txt new file mode 100644 index 0000000..ba0623b --- /dev/null +++ b/runtime/triton_trtllm/requirements.txt @@ -0,0 +1,14 @@ +hyperpyyaml +s3tokenizer +onnxruntime-gpu +omegaconf +conformer +hydra-core +lightning +gdown +wget +librosa +pyworld +openai-whisper +tritonclient +modelscope diff --git a/runtime/triton_trtllm/run.sh b/runtime/triton_trtllm/run.sh new file mode 100644 index 0000000..2e81896 --- /dev/null +++ b/runtime/triton_trtllm/run.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) +export CUDA_VISIBLE_DEVICES=0 +cosyvoice_path=/workspace/CosyVoice +export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH +export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH +stage=$1 +stop_stage=$2 + +huggingface_model_local_dir=./cosyvoice2_llm +model_scope_model_local_dir=./CosyVoice2-0.5B +trt_dtype=bfloat16 +trt_weights_dir=./trt_weights_${trt_dtype} +trt_engines_dir=./trt_engines_${trt_dtype} + +model_repo=./model_repo_cosyvoice2 + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + echo "Cloning CosyVoice" + git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path + cd $cosyvoice_path + git submodule update --init --recursive + cd runtime/triton_trtllm +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Downloading CosyVoice2-0.5B" + huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir +fi + + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "Converting checkpoint to TensorRT weights" + python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \ + --output_dir $trt_weights_dir \ + --dtype $trt_dtype || exit 1 + + echo "Building TensorRT engines" + trtllm-build --checkpoint_dir $trt_weights_dir \ + --output_dir $trt_engines_dir \ + --max_batch_size 16 \ + --max_num_tokens 32768 \ + --gemm_plugin $trt_dtype || exit 1 + + echo "Testing TensorRT engines" + python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" 
\ + --tokenizer_dir $huggingface_model_local_dir \ + --top_k 50 --top_p 0.95 --temperature 0.8 \ + --engine_dir=$trt_engines_dir || exit 1 +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "Creating model repository" + rm -rf $model_repo + mkdir -p $model_repo + cosyvoice2_dir="cosyvoice2" + + cp -r ./model_repo/${cosyvoice2_dir} $model_repo + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/tensorrt_llm $model_repo + cp -r ./model_repo/token2wav $model_repo + + ENGINE_PATH=$trt_engines_dir + MAX_QUEUE_DELAY_MICROSECONDS=0 + MODEL_DIR=$model_scope_model_local_dir + LLM_TOKENIZER_DIR=$huggingface_model_local_dir + BLS_INSTANCE_NUM=4 + TRITON_MAX_BATCH_SIZE=16 + DECOUPLED_MODE=False + + python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 + +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Starting Triton server" + tritonserver --model-repository $model_repo +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + echo "Single request test http" + python3 client_http.py \ + --reference-audio ./assets/prompt_audio.wav \ + --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ + --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ + --model-name cosyvoice2 +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Running benchmark client grpc" + num_task=4 + # set mode=streaming, when decoupled=True + # set mode=offline, when decoupled=False + mode=offline + python3 client_grpc.py \ + --server-addr localhost \ + --model-name cosyvoice2 \ + --num-tasks $num_task \ + --mode $mode \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype} +fi \ No newline at end of file diff --git a/runtime/triton_trtllm/scripts/convert_checkpoint.py b/runtime/triton_trtllm/scripts/convert_checkpoint.py new file mode 100644 index 0000000..7cd166f --- /dev/null +++ b/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -0,0 +1,330 @@ +import argparse +import os +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +from transformers import AutoConfig + +import tensorrt_llm +from tensorrt_llm._utils import release_gc +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from 
tensorrt_llm.models import QWenForCausalLM +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization import QuantAlgo + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None, required=True) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--cp_size', + type=int, + default=1, + help='N-way context parallelism size') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'float16', 'bfloat16', 'float32'], + help="The data type for the model weights and activations if not quantized. " + "If 'auto', the data type is automatically inferred from the source model; " + "however, if the source dtype is float32, it is converted to float16.") + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--disable_weight_only_quant_plugin', + default=False, + action="store_true", + help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4', 'int4_gptq'], + help='Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--calib_dataset', + type=str, + default='ccdv/cnn_dailymail', + help="The huggingface dataset name or the local directory of the dataset for calibration." + ) + parser.add_argument( + "--smoothquant", + "-sq", + type=float, + default=None, + help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" + " to Smoothquant the model, and output int8 weights." + " A good first try is 0.5. Must be in [0, 1]") + parser.add_argument( + '--per_channel', + action="store_true", + default=False, + help='By default, we use a single static scaling factor for the GEMM\'s result. ' + 'per_channel instead uses a different static scaling factor for each channel. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--per_token', + action="store_true", + default=False, + help='By default, we use a single static scaling factor to scale activations in the int8 range. ' + 'per_token chooses at run time, and for each token, a custom scaling factor. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--int8_kv_cache', + default=False, + action="store_true", + help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + ) + parser.add_argument( + '--per_group', + default=False, + action="store_true", + help='By default, we use a single static scaling factor to scale weights in the int4 range. ' + 'per_group chooses at run time, and for each group, a custom scaling factor. 
' + 'The flag is built for GPTQ/AWQ quantization.') + + parser.add_argument('--group_size', + type=int, + default=128, + help='Group size used in GPTQ quantization.') + + parser.add_argument("--load_model_on_cpu", action="store_true") + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument( + '--moe_tp_size', + type=int, + default=-1, + help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' + ) + parser.add_argument( + '--moe_ep_size', + type=int, + default=-1, + help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' + ) + args = parser.parse_args() + return args + + +def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: + '''return config dict with quantization info based on the command line args + ''' + quant_config = QuantConfig() + if args.use_weight_only: + if args.weight_only_precision == 'int8': + quant_config.quant_algo = QuantAlgo.W8A16 + elif args.weight_only_precision == 'int4': + quant_config.quant_algo = QuantAlgo.W4A16 + elif args.smoothquant: + quant_config.smoothquant_val = args.smoothquant + if args.per_channel: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN + else: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN + + if args.int8_kv_cache: + quant_config.kv_cache_quant_algo = QuantAlgo.INT8 + + if args.weight_only_precision == 'int4_gptq': + quant_config.group_size = args.group_size + quant_config.has_zero_point = True + quant_config.pre_quant_scale = False + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + + return quant_config + + +def update_quant_config_from_hf(quant_config, hf_config, + override_fields) -> tuple[QuantConfig, dict]: + hf_config_dict = hf_config.to_dict() + if hf_config_dict.get('quantization_config'): + # update the quant_algo, and clamp_val. 
+ if hf_config_dict['quantization_config'].get('quant_method') == 'awq': + logger.info( + "Load quantization configs from huggingface model_config.") + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + quant_config.group_size = hf_config_dict['quantization_config'].get( + 'group_size', 128) + quant_config.has_zero_point = hf_config_dict[ + 'quantization_config'].get('zero_point', False) + override_fields.update({"use_autoawq": True}) + elif hf_config_dict['quantization_config'].get( + 'quant_method') == 'gptq': + logger.info( + "Load quantization configs from huggingface model_config.") + desc_act = hf_config_dict['quantization_config'].get( + 'desc_act', False) + if desc_act: + raise ValueError("GPTQ with desc_act=True is not implemented!") + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + quant_config.group_size = hf_config_dict['quantization_config'].get( + 'group_size', 128) + quant_config.has_zero_point = hf_config_dict[ + 'quantization_config'].get('sym', False) + return quant_config, override_fields + + +def args_to_build_options(args): + return { + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'disable_weight_only_quant_plugin': + args.disable_weight_only_quant_plugin + } + + +def convert_and_save_hf(args): + model_dir = args.model_dir + world_size = args.tp_size * args.pp_size + # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. + # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, + # before the refactor is done. + override_fields = {} + override_fields.update(args_to_build_options(args)) + quant_config = args_to_quant_config(args) + + try: + hf_config = AutoConfig.from_pretrained(model_dir, + trust_remote_code=True) + quant_config, override_fields = update_quant_config_from_hf( + quant_config, hf_config, override_fields) + except BaseException: + logger.warning("AutoConfig cannot load the huggingface config.") + + if args.smoothquant is not None or args.int8_kv_cache: + mapping = Mapping(world_size=world_size, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size, + cp_size=args.cp_size) + QWenForCausalLM.quantize(args.model_dir, + args.output_dir, + dtype=args.dtype, + mapping=mapping, + quant_config=quant_config, + calib_dataset=args.calib_dataset, + **override_fields) + else: + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size) + qwen = QWenForCausalLM.from_hugging_face(model_dir, + args.dtype, + mapping=mapping, + quant_config=quant_config, + **override_fields) + qwen.config.mapping.cp_size = args.cp_size + qwen.config.mapping.attn_tp_size = -1 + qwen.config.mapping.attn_cp_size = -1 + qwen.config.mapping.world_size *= args.cp_size + qwen.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del qwen + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + 
exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + +def main(): + print(tensorrt_llm.__version__) + args = parse_arguments() + + if (args.moe_tp_size == -1 and args.moe_ep_size == -1): + # moe default to tp-only + args.moe_tp_size = args.tp_size + args.moe_ep_size = 1 + elif (args.moe_tp_size == -1): + args.moe_tp_size = args.tp_size // args.moe_ep_size + elif (args.moe_ep_size == -1): + args.moe_ep_size = args.tp_size // args.moe_tp_size + assert (args.moe_tp_size * args.moe_ep_size == args.tp_size + ), "moe_tp_size * moe_ep_size must equal to tp_size" + + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + assert args.model_dir is not None + convert_and_save_hf(args) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/runtime/triton_trtllm/scripts/fill_template.py b/runtime/triton_trtllm/scripts/fill_template.py new file mode 100644 index 0000000..6e6a2bc --- /dev/null +++ b/runtime/triton_trtllm/scripts/fill_template.py @@ -0,0 +1,69 @@ +# /usr/bin/env python3 +from argparse import ArgumentParser +from string import Template + + +def split(string, delimiter): + """Split a string using delimiter. Supports escaping. + + Args: + string (str): The string to split. + delimiter (str): The delimiter to split the string with. + + Returns: + list: A list of strings. + """ + result = [] + current = "" + escape = False + for char in string: + if escape: + current += char + escape = False + elif char == delimiter: + result.append(current) + current = "" + elif char == "\\": + escape = True + else: + current += char + result.append(current) + return result + + +def main(file_path, substitutions, in_place): + with open(file_path) as f: + pbtxt = Template(f.read()) + + sub_dict = { + "max_queue_size": 0, + 'max_queue_delay_microseconds': 0, + } + for sub in split(substitutions, ","): + key, value = split(sub, ":") + sub_dict[key] = value + + assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}." + + pbtxt = pbtxt.safe_substitute(sub_dict) + + if in_place: + with open(file_path, "w") as f: + f.write(pbtxt) + else: + print(pbtxt) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("file_path", help="path of the .pbtxt to modify") + parser.add_argument( + "substitutions", + help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." + ) + parser.add_argument("--in_place", + "-i", + action="store_true", + help="do the operation in-place") + args = parser.parse_args() + main(**vars(args)) diff --git a/runtime/triton_trtllm/scripts/test_llm.py b/runtime/triton_trtllm/scripts/test_llm.py new file mode 100644 index 0000000..d52d724 --- /dev/null +++ b/runtime/triton_trtllm/scripts/test_llm.py @@ -0,0 +1,143 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import ast +import csv +import os +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch + +import tensorrt_llm +from tensorrt_llm.logger import logger + +from tensorrt_llm.runtime import ModelRunnerCpp +from transformers import AutoTokenizer + + +def parse_arguments(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '--input_text', + type=str, + nargs='+', + default=["Born in north-east France, Soyer trained as a"]) + parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") + parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") + parser.add_argument('--log_level', type=str, default="debug") + parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6) + parser.add_argument('--temperature', type=float, default=0.8) + parser.add_argument('--top_k', type=int, default=50) + parser.add_argument('--top_p', type=float, default=0.95) + + return parser.parse_args(args=args) + + +def parse_input(tokenizer, + input_text=None, + prompt_template=None): + batch_input_ids = [] + for curr_text in input_text: + if prompt_template is not None: + curr_text = prompt_template.format(input_text=curr_text) + input_ids = tokenizer.encode( + curr_text) + batch_input_ids.append(input_ids) + + batch_input_ids = [ + torch.tensor(x, dtype=torch.int32) for x in batch_input_ids + ] + + logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):") + for i, input_ids in enumerate(batch_input_ids): + logger.debug(f"Request {i}: {input_ids.tolist()}") + + return batch_input_ids + + +def main(args): + runtime_rank = tensorrt_llm.mpi_rank() + logger.set_level(args.log_level) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + prompt_template = "<|sos|>{input_text}<|task_id|>" + end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") + + batch_input_ids = parse_input(tokenizer=tokenizer, + input_text=args.input_text, + prompt_template=prompt_template) + + input_lengths = [x.size(0) for x in batch_input_ids] + + runner_kwargs = dict( + engine_dir=args.engine_dir, + rank=runtime_rank, + max_output_len=1024, + enable_context_fmha_fp32_acc=False, + max_batch_size=len(batch_input_ids), + max_input_len=max(input_lengths), + kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction, + cuda_graph_mode=False, + gather_generation_logits=False, + ) + + runner = ModelRunnerCpp.from_dir(**runner_kwargs) + + with torch.no_grad(): + outputs = runner.generate( + batch_input_ids=batch_input_ids, + max_new_tokens=1024, + end_id=end_id, + pad_id=end_id, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + num_return_sequences=1, + repetition_penalty=1.1, + random_seed=42, + streaming=False, + output_sequence_lengths=True, + output_generation_logits=False, + return_dict=True, + return_all_generated_tokens=False) + torch.cuda.synchronize() + output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"] + num_output_sents, num_beams, _ = output_ids.size() + assert num_beams == 
1 + beam = 0 + batch_size = len(input_lengths) + num_return_sequences = num_output_sents // batch_size + assert num_return_sequences == 1 + for i in range(batch_size * num_return_sequences): + batch_idx = i // num_return_sequences + seq_idx = i % num_return_sequences + inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist() + input_text = tokenizer.decode(inputs) + print(f'Input [Text {batch_idx}]: \"{input_text}\"') + output_begin = input_lengths[batch_idx] + output_end = sequence_lengths[i][beam] + outputs = output_ids[i][beam][output_begin:output_end].tolist() + output_text = tokenizer.decode(outputs) + print(f'Output [Text {batch_idx}]: \"{output_text}\"') + logger.debug(str(outputs)) + + +if __name__ == '__main__': + args = parse_arguments() + main(args)
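+# Example invocation (paths follow the defaults used in run.sh stage 1; adjust them
+# if you changed trt_dtype or the download directories):
+#   python3 scripts/test_llm.py \
+#       --input_text "你好,请问你叫什么?" \
+#       --tokenizer_dir ./cosyvoice2_llm \
+#       --top_k 50 --top_p 0.95 --temperature 0.8 \
+#       --engine_dir ./trt_engines_bfloat16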