From 5427c274e3d0aa067199256593d71283da1aac06 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Tue, 22 Jul 2025 06:50:13 -0700 Subject: [PATCH 1/6] add triton solution --- runtime/triton_trtllm/Dockerfile.server | 6 + runtime/triton_trtllm/README.md | 94 ++ runtime/triton_trtllm/client_grpc.py | 831 +++++++++++++++++ runtime/triton_trtllm/client_http.py | 169 ++++ runtime/triton_trtllm/docker-compose.yml | 20 + .../model_repo/audio_tokenizer/1/model.py | 95 ++ .../model_repo/audio_tokenizer/config.pbtxt | 53 ++ .../model_repo/cosyvoice2/1/model.py | 331 +++++++ .../model_repo/cosyvoice2/config.pbtxt | 70 ++ .../model_repo/tensorrt_llm/1/.gitkeep | 0 .../model_repo/tensorrt_llm/config.pbtxt | 857 ++++++++++++++++++ .../model_repo/token2wav/1/model.py | 198 ++++ .../model_repo/token2wav/config.pbtxt | 63 ++ runtime/triton_trtllm/requirements.txt | 13 + runtime/triton_trtllm/run.sh | 92 ++ .../scripts/convert_checkpoint.py | 342 +++++++ .../triton_trtllm/scripts/fill_template.py | 70 ++ runtime/triton_trtllm/scripts/test_llm.py | 144 +++ 18 files changed, 3448 insertions(+) create mode 100644 runtime/triton_trtllm/Dockerfile.server create mode 100644 runtime/triton_trtllm/README.md create mode 100644 runtime/triton_trtllm/client_grpc.py create mode 100644 runtime/triton_trtllm/client_http.py create mode 100644 runtime/triton_trtllm/docker-compose.yml create mode 100644 runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py create mode 100644 runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt create mode 100644 runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py create mode 100644 runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt create mode 100644 runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep create mode 100644 runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt create mode 100644 runtime/triton_trtllm/model_repo/token2wav/1/model.py create mode 100644 runtime/triton_trtllm/model_repo/token2wav/config.pbtxt create mode 100644 runtime/triton_trtllm/requirements.txt create mode 100644 runtime/triton_trtllm/run.sh create mode 100644 runtime/triton_trtllm/scripts/convert_checkpoint.py create mode 100644 runtime/triton_trtllm/scripts/fill_template.py create mode 100644 runtime/triton_trtllm/scripts/test_llm.py diff --git a/runtime/triton_trtllm/Dockerfile.server b/runtime/triton_trtllm/Dockerfile.server new file mode 100644 index 0000000..8e827b4 --- /dev/null +++ b/runtime/triton_trtllm/Dockerfile.server @@ -0,0 +1,6 @@ +FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3 +RUN apt-get update && apt-get install -y cmake +RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop +COPY ./requirements.txt /workspace/requirements.txt +RUN pip install -r /workspace/requirements.txt +WORKDIR /workspace \ No newline at end of file diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md new file mode 100644 index 0000000..999d698 --- /dev/null +++ b/runtime/triton_trtllm/README.md @@ -0,0 +1,94 @@ +## Nvidia Triton Inference Serving Best Practice for Spark TTS + +### Quick Start +Directly launch the service using docker compose. +```sh +docker compose up +``` + +### Build Image +Build the docker image from scratch. +```sh +docker build . 
-f Dockerfile.server -t soar97/triton-spark-tts:25.02 +``` + +### Create Docker Container +```sh +your_mount_dir=/mnt:/mnt +docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02 +``` + +### Understanding `run.sh` + +The `run.sh` script automates various steps using stages. You can run specific stages using: +```sh +bash run.sh [service_type] +``` +- ``: The stage to begin execution from (0-5). +- ``: The stage to end execution at (0-5). +- `[service_type]`: Optional, specifies the service type ('streaming' or 'offline', defaults may apply based on script logic). Required for stages 4 and 5. + +Stages: +- **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace. +- **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines. +- **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline). +- **Stage 3**: Launch the Triton Inference Server. +- **Stage 4**: Run the gRPC benchmark client. +- **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline). + +### Export Models to TensorRT-LLM and Launch Server +Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server. +```sh +# This runs stages 0, 1, 2, and 3 +bash run.sh 0 3 +``` +*Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.* + + +### Single Utterance Client +Run a single inference request. Specify `streaming` or `offline` as the third argument. + +**Streaming Mode (gRPC):** +```sh +bash run.sh 5 5 streaming +``` +This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode. + +**Offline Mode (HTTP):** +```sh +bash run.sh 5 5 offline +``` + +### Benchmark using Dataset +Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument. +```sh +# Run benchmark in streaming mode +bash run.sh 4 4 streaming + +# Run benchmark in offline mode +bash run.sh 4 4 offline + +# You can also customize parameters like num_task directly in client_grpc.py or via args if supported +# Example from run.sh (streaming): +# python3 client_grpc.py \ +# --server-addr localhost \ +# --model-name spark_tts \ +# --num-tasks 2 \ +# --mode streaming \ +# --log-dir ./log_concurrent_tasks_2_streaming_new + +# Example customizing dataset (requires modifying client_grpc.py or adding args): +# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline] +``` + +### Benchmark Results +Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs. 
+ +| Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF | +|-------|-----------|-----------------------|---------|----------------|-| +| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362| +| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737| +| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704| +| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 | +| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 | +| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 | \ No newline at end of file diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py new file mode 100644 index 0000000..19f13e2 --- /dev/null +++ b/runtime/triton_trtllm/client_grpc.py @@ -0,0 +1,831 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Nvidia (authors: Yuekai Zhang) +# 2023 Recurrent.ai (authors: Songtao Shi) +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script supports to load dataset from huggingface and sends it to the server +for decoding, in parallel. 
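+Both offline (single response per request) and streaming (chunked responses) modes
+are supported; select one with --mode.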
+ +Usage: +num_task=2 + +# For offline F5-TTS +python3 client_grpc.py \ + --server-addr localhost \ + --model-name f5_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name test_zh \ + --log-dir ./log_concurrent_tasks_${num_task} + +# For offline Spark-TTS-0.5B +python3 client_grpc.py \ + --server-addr localhost \ + --model-name spark_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name wenetspeech4tts \ + --log-dir ./log_concurrent_tasks_${num_task} +""" + +import argparse +import asyncio +import json +import queue # Added +import uuid # Added +import functools # Added + +import os +import time +import types +from pathlib import Path + +import numpy as np +import soundfile as sf +import tritonclient +import tritonclient.grpc.aio as grpcclient_aio # Renamed original import +import tritonclient.grpc as grpcclient_sync # Added sync client import +from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException + + +# --- Added UserData and callback --- +class UserData: + def __init__(self): + self._completed_requests = queue.Queue() + self._first_chunk_time = None + self._start_time = None + + def record_start_time(self): + self._start_time = time.time() + + def get_first_chunk_latency(self): + if self._first_chunk_time and self._start_time: + return self._first_chunk_time - self._start_time + return None + +def callback(user_data, result, error): + if user_data._first_chunk_time is None and not error: + user_data._first_chunk_time = time.time() # Record time of first successful chunk + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) +# --- End Added UserData and callback --- + + +def write_triton_stats(stats, summary_file): + with open(summary_file, "w") as summary_f: + model_stats = stats["model_stats"] + # write a note, the log is from triton_client.get_inference_statistics(), to better human readability + summary_f.write( + "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n" + ) + summary_f.write("To learn more about the log, please refer to: \n") + summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") + summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") + summary_f.write( + "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" + ) + summary_f.write( + "However, there is a trade-off between the increased queue time and the increased batch size. \n" + ) + summary_f.write( + "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" + ) + summary_f.write( + "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. 
\n\n" + ) + for model_state in model_stats: + if "last_inference" not in model_state: + continue + summary_f.write(f"model name is {model_state['name']} \n") + model_inference_stats = model_state["inference_stats"] + total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 + total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 + total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 + total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 + summary_f.write( + f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa + ) + model_batch_stats = model_state["batch_stats"] + for batch in model_batch_stats: + batch_size = int(batch["batch_size"]) + compute_input = batch["compute_input"] + compute_output = batch["compute_output"] + compute_infer = batch["compute_infer"] + batch_count = int(compute_infer["count"]) + assert compute_infer["count"] == compute_output["count"] == compute_input["count"] + compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 + compute_input_time_ms = int(compute_input["ns"]) / 1e6 + compute_output_time_ms = int(compute_output["ns"]) / 1e6 + summary_f.write( + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa + ) + summary_f.write( + f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa + ) + summary_f.write( + f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa + ) + + +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=8001, + help="Grpc port of the triton server, default is 8001", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--huggingface-dataset", + type=str, + default="yuekai/seed_tts", + help="dataset name in huggingface dataset hub", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="dataset split name, default is 'test'", + ) + + parser.add_argument( + "--manifest-path", + type=str, + default=None, + help="Path to the manifest dir which includes wav.scp trans.txt files.", + ) + + parser.add_argument( + "--model-name", + type=str, + default="f5_tts", + choices=["f5_tts", "spark_tts", "cosyvoice2"], + help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + ) + + parser.add_argument( + "--num-tasks", + type=int, + default=1, + help="Number of concurrent tasks for sending", + ) + + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="Controls how frequently we print the log.", + ) + + parser.add_argument( + "--compute-wer", + action="store_true", + default=False, + help="""True to compute WER. + """, + ) + + parser.add_argument( + "--log-dir", + type=str, + required=False, + default="./tmp", + help="log directory", + ) + + # --- Added arguments --- + parser.add_argument( + "--mode", + type=str, + default="offline", + choices=["offline", "streaming"], + help="Select offline or streaming benchmark mode." + ) + parser.add_argument( + "--chunk-overlap-duration", + type=float, + default=0.1, + help="Chunk overlap duration for streaming reconstruction (in seconds)." 
+ ) + # --- End Added arguments --- + + return parser.parse_args() + + +def load_audio(wav_path, target_sample_rate=16000): + assert target_sample_rate == 16000, "hard coding in server" + if isinstance(wav_path, dict): + waveform = wav_path["array"] + sample_rate = wav_path["sampling_rate"] + else: + waveform, sample_rate = sf.read(wav_path) + if sample_rate != target_sample_rate: + from scipy.signal import resample + + num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) + waveform = resample(waveform, num_samples) + return waveform, target_sample_rate + +def prepare_request_input_output( + protocol_client, # Can be grpcclient_aio or grpcclient_sync + waveform, + reference_text, + target_text, + sample_rate=16000, + padding_duration: int = None # Optional padding for offline mode +): + """Prepares inputs for Triton inference (offline or streaming).""" + assert len(waveform.shape) == 1, "waveform should be 1D" + lengths = np.array([[len(waveform)]], dtype=np.int32) + + # Apply padding only if padding_duration is provided (for offline) + if padding_duration: + duration = len(waveform) / sample_rate + # Estimate target duration based on text length ratio (crude estimation) + # Avoid division by zero if reference_text is empty + if reference_text: + estimated_target_duration = duration / len(reference_text) * len(target_text) + else: + estimated_target_duration = duration # Assume target duration similar to reference if no text + + # Calculate required samples based on estimated total duration + required_total_samples = padding_duration * sample_rate * ( + (int(estimated_target_duration + duration) // padding_duration) + 1 + ) + samples = np.zeros((1, required_total_samples), dtype=np.float32) + samples[0, : len(waveform)] = waveform + else: + # No padding for streaming or if padding_duration is None + samples = waveform.reshape(1, -1).astype(np.float32) + + # Common input creation logic + inputs = [ + protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), + protocol_client.InferInput( + "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) + ), + protocol_client.InferInput("reference_text", [1, 1], "BYTES"), + protocol_client.InferInput("target_text", [1, 1], "BYTES"), + ] + inputs[0].set_data_from_numpy(samples) + inputs[1].set_data_from_numpy(lengths) + + input_data_numpy = np.array([reference_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[2].set_data_from_numpy(input_data_numpy) + + input_data_numpy = np.array([target_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[3].set_data_from_numpy(input_data_numpy) + + outputs = [protocol_client.InferRequestedOutput("waveform")] + + return inputs, outputs + +def run_sync_streaming_inference( + sync_triton_client: tritonclient.grpc.InferenceServerClient, + model_name: str, + inputs: list, + outputs: list, + request_id: str, + user_data: UserData, + chunk_overlap_duration: float, + save_sample_rate: int, + audio_save_path: str, +): + """Helper function to run the blocking sync streaming call.""" + start_time_total = time.time() + user_data.record_start_time() # Record start time for first chunk latency calculation + + # Establish stream + sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) + + # Send request + sync_triton_client.async_stream_infer( + model_name, + inputs, + request_id=request_id, + outputs=outputs, + enable_empty_final_response=True, + ) + + # Process results + 
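+    # The stream callback enqueues every response (or error) into user_data's queue.
+    # The loop below drains that queue, appending non-empty 'waveform' chunks until the
+    # empty final response (flagged via 'triton_final_response', enabled by
+    # enable_empty_final_response=True above) signals the end of the stream.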
audios = [] + while True: + try: + result = user_data._completed_requests.get() # Add timeout + if isinstance(result, InferenceServerException): + print(f"Received InferenceServerException: {result}") + sync_triton_client.stop_stream() + return None, None, None # Indicate error + # Get response metadata + response = result.get_response() + final = response.parameters["triton_final_response"].bool_param + if final is True: + break + + audio_chunk = result.as_numpy("waveform").reshape(-1) + if audio_chunk.size > 0: # Only append non-empty chunks + audios.append(audio_chunk) + else: + print("Warning: received empty audio chunk.") + + except queue.Empty: + print(f"Timeout waiting for response for request id {request_id}") + sync_triton_client.stop_stream() + return None, None, None # Indicate error + + sync_triton_client.stop_stream() + end_time_total = time.time() + total_request_latency = end_time_total - start_time_total + first_chunk_latency = user_data.get_first_chunk_latency() + + # Reconstruct audio using cross-fade (from client_grpc_streaming.py) + actual_duration = 0 + if audios: + cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + reconstructed_audio = None + + # Simplified reconstruction based on client_grpc_streaming.py + if not audios: + print("Warning: No audio chunks received.") + reconstructed_audio = np.array([], dtype=np.float32) # Empty array + elif len(audios) == 1: + reconstructed_audio = audios[0] + else: + reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + for i in range(1, len(audios)): + # Cross-fade section + cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + audios[i - 1][-cross_fade_samples:] * fade_out) + # Middle section of the current chunk + middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # Concatenate + reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # Add the last part of the final chunk + reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) + + if reconstructed_audio is not None and reconstructed_audio.size > 0: + actual_duration = len(reconstructed_audio) / save_sample_rate + # Save reconstructed audio + os.makedirs(os.path.dirname(audio_save_path), exist_ok=True) + sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") + else: + print("Warning: No audio chunks received or reconstructed.") + actual_duration = 0 # Set duration to 0 if no audio + + else: + print("Warning: No audio chunks received.") + actual_duration = 0 + + return total_request_latency, first_chunk_latency, actual_duration + + +async def send_streaming( + manifest_item_list: list, + name: str, + server_url: str, # Changed from sync_triton_client + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + audio_save_dir: str = "./", + save_sample_rate: int = 16000, + chunk_overlap_duration: float = 0.1, + padding_duration: int = None, +): + total_duration = 0.0 + latency_data = [] + task_id = int(name[5:]) + sync_triton_client = None # Initialize client variable + + try: # Wrap in try...finally to ensure client closing + print(f"{name}: Initializing sync client for streaming...") + sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + + print(f"{name}: Starting streaming processing for {len(manifest_item_list)} 
items.") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: Processing item {i}/{len(manifest_item_list)}") + + try: + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + reference_text, target_text = item["reference_text"], item["target_text"] + + inputs, outputs = prepare_request_input_output( + protocol_client, + waveform, + reference_text, + target_text, + sample_rate, + padding_duration=padding_duration + ) + request_id = str(uuid.uuid4()) + user_data = UserData() + + audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") + + total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread( + run_sync_streaming_inference, + sync_triton_client, + model_name, + inputs, + outputs, + request_id, + user_data, + chunk_overlap_duration, + save_sample_rate, + audio_save_path + ) + + if total_request_latency is not None: + print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s") + latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) + total_duration += actual_duration + else: + print(f"{name}: Item {i} failed.") + + + except FileNotFoundError: + print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") + except Exception as e: + print(f"Error processing item {i} ({item['target_audio_path']}): {e}") + import traceback + traceback.print_exc() + + + finally: # Ensure client is closed + if sync_triton_client: + try: + print(f"{name}: Closing sync client...") + sync_triton_client.close() + except Exception as e: + print(f"{name}: Error closing sync client: {e}") + + + print(f"{name}: Finished streaming processing. 
Total duration synthesized: {total_duration:.4f}s") + return total_duration, latency_data + +async def send( + manifest_item_list: list, + name: str, + triton_client: tritonclient.grpc.aio.InferenceServerClient, + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + padding_duration: int = None, + audio_save_dir: str = "./", + save_sample_rate: int = 16000, +): + total_duration = 0.0 + latency_data = [] + task_id = int(name[5:]) + + print(f"manifest_item_list: {manifest_item_list}") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: {i}/{len(manifest_item_list)}") + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + reference_text, target_text = item["reference_text"], item["target_text"] + + inputs, outputs = prepare_request_input_output( + protocol_client, + waveform, + reference_text, + target_text, + sample_rate, + padding_duration=padding_duration + ) + sequence_id = 100000000 + i + task_id * 10 + start = time.time() + response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) + + audio = response.as_numpy("waveform").reshape(-1) + actual_duration = len(audio) / save_sample_rate + + end = time.time() - start + + audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") + sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") + + latency_data.append((end, actual_duration)) + total_duration += actual_duration + + return total_duration, latency_data + + +def load_manifests(manifest_path): + with open(manifest_path, "r") as f: + manifest_list = [] + for line in f: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) + manifest_list.append( + { + "audio_filepath": prompt_wav, + "reference_text": prompt_text, + "target_text": gt_text, + "target_audio_path": utt, + } + ) + return manifest_list + + +def split_data(data, k): + n = len(data) + if n < k: + print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.") + k = n + + quotient = n // k + remainder = n % k + + result = [] + start = 0 + for i in range(k): + if i < remainder: + end = start + quotient + 1 + else: + end = start + quotient + + result.append(data[start:end]) + start = end + + return result + +async def main(): + args = get_args() + url = f"{args.server_addr}:{args.server_port}" + + # --- Client Initialization based on mode --- + triton_client = None + protocol_client = None + if args.mode == "offline": + print("Initializing gRPC client for offline mode...") + # Use the async client for offline tasks + triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) + protocol_client = grpcclient_aio + elif args.mode == "streaming": + print("Initializing gRPC client for streaming mode...") + # Use the sync client for streaming tasks, handled via asyncio.to_thread + # We will create one sync client instance PER TASK inside send_streaming. 
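+        # Each task owns a private blocking client/stream; the blocking streaming call is
+        # offloaded to a worker thread via asyncio.to_thread so tasks still run concurrently.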
+ # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now + protocol_client = grpcclient_sync # protocol client for input prep + else: + raise ValueError(f"Invalid mode: {args.mode}") + # --- End Client Initialization --- + + if args.reference_audio: + args.num_tasks = 1 + args.log_interval = 1 + manifest_item_list = [ + { + "reference_text": args.reference_text, + "target_text": args.target_text, + "audio_filepath": args.reference_audio, + "target_audio_path": "test", + } + ] + elif args.huggingface_dataset: + import datasets + + dataset = datasets.load_dataset( + args.huggingface_dataset, + split=args.split_name, + trust_remote_code=True, + ) + manifest_item_list = [] + for i in range(len(dataset)): + manifest_item_list.append( + { + "audio_filepath": dataset[i]["prompt_audio"], + "reference_text": dataset[i]["prompt_text"], + "target_audio_path": dataset[i]["id"], + "target_text": dataset[i]["target_text"], + } + ) + else: + manifest_item_list = load_manifests(args.manifest_path) + + num_tasks = min(args.num_tasks, len(manifest_item_list)) + manifest_item_list = split_data(manifest_item_list, num_tasks) + + os.makedirs(args.log_dir, exist_ok=True) + tasks = [] + start_time = time.time() + for i in range(num_tasks): + # --- Task Creation based on mode --- + if args.mode == "offline": + task = asyncio.create_task( + send( + manifest_item_list[i], + name=f"task-{i}", + triton_client=triton_client, + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=1, + save_sample_rate=16000 if args.model_name == "spark_tts" else 24000, + ) + ) + elif args.mode == "streaming": + task = asyncio.create_task( + send_streaming( + manifest_item_list[i], + name=f"task-{i}", + server_url=url, # Pass URL instead of client + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=10, + save_sample_rate=24000 if args.model_name == "f5_tts" else 16000, + chunk_overlap_duration=args.chunk_overlap_duration, + ) + ) + # --- End Task Creation --- + tasks.append(task) + + ans_list = await asyncio.gather(*tasks) + + end_time = time.time() + elapsed = end_time - start_time + + total_duration = 0.0 + latency_data = [] + for ans in ans_list: + if ans: + total_duration += ans[0] + latency_data.extend(ans[1]) # Use extend for list of lists + else: + print("Warning: A task returned None, possibly due to an error.") + + + if total_duration == 0: + print("Total synthesized duration is zero. 
Cannot calculate RTF or latency percentiles.") + rtf = float('inf') + else: + rtf = elapsed / total_duration + + s = f"Mode: {args.mode}\n" + s += f"RTF: {rtf:.4f}\n" + s += f"total_duration: {total_duration:.3f} seconds\n" + s += f"({total_duration / 3600:.2f} hours)\n" + s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" + + # --- Statistics Reporting based on mode --- + if latency_data: + if args.mode == "offline": + # Original offline latency calculation + latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] + if latency_list: + latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 + latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 + s += f"latency_variance: {latency_variance:.2f}\n" + s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" + s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" + s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" + s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" + s += f"average_latency_ms: {latency_ms:.2f}\n" + else: + s += "No latency data collected for offline mode.\n" + + elif args.mode == "streaming": + # Calculate stats for total request latency and first chunk latency + total_latency_list = [total for (total, first, duration) in latency_data if total is not None] + first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None] + + s += "\n--- Total Request Latency ---\n" + if total_latency_list: + avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0 + variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0 + s += f"total_request_latency_variance: {variance_total_latency:.2f}\n" + s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n" + s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n" + s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n" + s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n" + s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n" + else: + s += "No total request latency data collected.\n" + + s += "\n--- First Chunk Latency ---\n" + if first_chunk_latency_list: + avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0 + variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0 + s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n" + s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n" + s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n" + s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n" + s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n" + s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" + else: + s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" + else: + s += "No latency data collected.\n" + # --- End Statistics Reporting --- + + print(s) + if args.manifest_path: + name = Path(args.manifest_path).stem + elif 
args.split_name: + name = args.split_name + elif args.reference_audio: + name = Path(args.reference_audio).stem + else: + name = "results" # Default name if no manifest/split/audio provided + with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: + f.write(s) + + # --- Statistics Fetching using temporary Async Client --- + # Use a separate async client for fetching stats regardless of mode + stats_client = None + try: + print("Initializing temporary async client for fetching stats...") + stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) + print("Fetching inference statistics...") + # Fetching for all models, filtering might be needed depending on server setup + stats = await stats_client.get_inference_statistics(model_name="", as_json=True) + print("Fetching model config...") + metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) + + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) + + except Exception as e: + print(f"Could not retrieve statistics or config: {e}") + finally: + if stats_client: + try: + print("Closing temporary async stats client...") + await stats_client.close() + except Exception as e: + print(f"Error closing async stats client: {e}") + # --- End Statistics Fetching --- + + +if __name__ == "__main__": + # asyncio.run(main()) # Use TaskGroup for better exception handling if needed + async def run_main(): + try: + await main() + except Exception as e: + print(f"An error occurred in main: {e}") + import traceback + traceback.print_exc() + + asyncio.run(run_main()) diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py new file mode 100644 index 0000000..970d417 --- /dev/null +++ b/runtime/triton_trtllm/client_http.py @@ -0,0 +1,169 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
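+# Minimal offline HTTP client: it packs the reference waveform, its length, the
+# reference transcript, and the target text into a KServe-v2 style JSON request,
+# POSTs it to /v2/models/<model_name>/infer, and writes the returned waveform to disk.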
+import requests +import soundfile as sf +import json +import numpy as np +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-url", + type=str, + default="localhost:8000", + help="Address of the server", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default="../../example/prompt_audio.wav", + help="Path to a single audio file. It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。", + help="", + ) + + parser.add_argument( + "--model-name", + type=str, + default="spark_tts", + choices=[ + "f5_tts", "spark_tts", "cosyvoice2" + ], + help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + ) + + parser.add_argument( + "--output-audio", + type=str, + default="output.wav", + help="Path to save the output audio", + ) + return parser.parse_args() + +def prepare_request( + waveform, + reference_text, + target_text, + sample_rate=16000, + padding_duration: int = None, + audio_save_dir: str = "./", +): + assert len(waveform.shape) == 1, "waveform should be 1D" + lengths = np.array([[len(waveform)]], dtype=np.int32) + if padding_duration: + # padding to nearset 10 seconds + samples = np.zeros( + ( + 1, + padding_duration + * sample_rate + * ((int(duration) // padding_duration) + 1), + ), + dtype=np.float32, + ) + + samples[0, : len(waveform)] = waveform + else: + samples = waveform + + samples = samples.reshape(1, -1).astype(np.float32) + + data = { + "inputs":[ + { + "name": "reference_wav", + "shape": samples.shape, + "datatype": "FP32", + "data": samples.tolist() + }, + { + "name": "reference_wav_len", + "shape": lengths.shape, + "datatype": "INT32", + "data": lengths.tolist(), + }, + { + "name": "reference_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [reference_text] + }, + { + "name": "target_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [target_text] + } + ] + } + + return data + +if __name__ == "__main__": + args = get_args() + server_url = args.server_url + if not server_url.startswith(("http://", "https://")): + server_url = f"http://{server_url}" + + url = f"{server_url}/v2/models/{args.model_name}/infer" + waveform, sr = sf.read(args.reference_audio) + assert sr == 16000, "sample rate hardcoded in server" + + samples = np.array(waveform, dtype=np.float32) + data = prepare_request(samples, args.reference_text, args.target_text) + + rsp = requests.post( + url, + headers={"Content-Type": "application/json"}, + json=data, + verify=False, + params={"request_id": '0'} + ) + result = rsp.json() + audio = result["outputs"][0]["data"] + audio = np.array(audio, dtype=np.float32) + if args.model_name == "cosyvoice2": + sample_rate = 24000 + else: + sample_rate = 16000 + sf.write(args.output_audio, audio, sample_rate, "PCM_16") \ No newline at end of file diff --git a/runtime/triton_trtllm/docker-compose.yml b/runtime/triton_trtllm/docker-compose.yml new file mode 100644 index 0000000..eca94bc --- /dev/null +++ b/runtime/triton_trtllm/docker-compose.yml @@ -0,0 +1,20 @@ +services: + tts: + image: soar97/triton-spark-tts:25.02 + shm_size: '1gb' + ports: + - 
"8000:8000" + - "8001:8001" + - "8002:8002" + environment: + - PYTHONIOENCODING=utf-8 + - MODEL_ID=${MODEL_ID} + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0'] + capabilities: [gpu] + command: > + /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3" diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py new file mode 100644 index 0000000..a197ec9 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -0,0 +1,95 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +import os +import numpy as np +import s3tokenizer + + +class TritonPythonModel: + """Triton Python model for audio tokenization. + + This model takes reference audio input and extracts semantic tokens + using s3tokenizer. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + + self.device = torch.device("cuda") + model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx") + self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device) + + def execute(self, requests): + """Execute inference on the batched requests. 
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing tokenized outputs + """ + mels = [] + + # Process each request in batch + for request in requests: + # Extract input tensors + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() + + wav_array = torch.from_numpy(wav_array).to(self.device) + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + mels.append(s3tokenizer.log_mel_spectrogram(wav)) + + mels, mels_lens = s3tokenizer.padding(mels) + codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device)) + codes = codes.clone() + 151663 + + responses = [] + for i in range(len(requests)): + prompt_speech_tokens = codes[i, :codes_lens[i].item()] + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack( + "prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + inference_response = pb_utils.InferenceResponse( + output_tensors=[prompt_speech_tokens_tensor]) + responses.append(inference_response) + + return responses \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt b/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt new file mode 100644 index 0000000..6d8bd0c --- /dev/null +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "audio_tokenizer" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + } +] +output [ + { + name: "prompt_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py new file mode 100644 index 0000000..97b9fee --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -0,0 +1,331 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import math +import os +import re +from typing import Dict, List, Tuple, Optional, Union + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer +import torchaudio.compliance.kaldi as kaldi +import torchaudio +import onnxruntime + + +from matcha.utils.audio import mel_spectrogram + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + self.logger = pb_utils.Logger + # Parse model parameters + self.model_config = json.loads(args['model_config']) + parameters = self.model_config['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + self.logger.log_info(f"model_params:{model_params}") + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.prompt_template = "<|sos|>{input_text}<|task_id|>" + self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") + + self.device = torch.device("cuda") + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) + + campplus_model = f'{model_params["model_dir"]}/campplus.onnx' + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) + + def forward_llm(self, input_ids): + """ + Prepares the response from the language model based on the provided + inputs. Creates a `pb_utils.InferenceRequest` object with passed + `llm_request_inputs` to send to a decoupled TensorRTLLM model. + For each response from the language model: + - Checks for errors and raise an exception if any are found. + - Extracts the "output_ids" tensor from the response. + - Determines the finish reason based on the presence of the + end-of-sequence token or reaching the maximum length. + - Appends the generated token IDs to `output_ids`. + - If the finish reason is determined, decodes the output IDs to text + and prepares the final response. 
+ + The final response includes the generated text, finish reason, + completion tokens, prompt tokens, and total tokens. + + Parameters + ---------- + - llm_request_inputs (dict): A dictionary containing the inputs for the language model. + + Returns + ------- + - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. + """ + # convert input_ids to numpy, with shape [1, sequence_length] + input_ids = input_ids.cpu().numpy() + max_tokens = 1024 + input_dict = { + "request_output_len": np.array([[max_tokens]], dtype=np.int32), + "end_id": np.array([[self.eos_token_id]], dtype=np.int32), + "pad_id": np.array([[self.eos_token_id]], dtype=np.int32), + "streaming": np.array([[self.decoupled]], dtype=np.bool_), + "runtime_top_p": np.array([[0.95]], dtype=np.float32), + "runtime_top_k": np.array([[50]], dtype=np.int32), + "temperature": np.array([[0.8]], dtype=np.float32), + "input_ids": input_ids, + "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), + } + + # Convert inputs to Triton tensors + input_tensor_list = [ + pb_utils.Tensor(k, v) for k, v in input_dict.items() + ] + + # Create and execute inference request + llm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + requested_output_names=["output_ids", "sequence_length"], + inputs=input_tensor_list, + ) + + llm_responses = llm_request.exec(decoupled=self.decoupled) + if self.decoupled: + for llm_response in llm_responses: + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + else: + llm_response = llm_responses + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. + + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['prompt_speech_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') + prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() + + return prompt_speech_tokens + + def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor: + """Forward pass through the vocoder component. 
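+
+        Delegates to the 'token2wav' Triton model via BLS (pb_utils.InferenceRequest)
+        and returns the waveform synthesized from the prompt and target speech tokens.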
+ + Args: + prompt_speech_tokens: Prompt speech tokens tensor + prompt_speech_feat: Prompt speech feat tensor + prompt_spk_embedding: Prompt spk embedding tensor + target_speech_tokens: Target speech tokens tensor + + Returns: + Generated waveform tensor + """ + print(prompt_speech_tokens.shape, prompt_speech_feat.shape, prompt_spk_embedding.shape, target_speech_tokens.shape) + # Convert tensors to Triton format + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) + prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) + target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='token2wav', + requested_output_names=['waveform'], + inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def parse_input(self, text, prompt_text, prompt_speech_tokens): + total_text = f"{prompt_text}{text}" + prompt = self.prompt_template.format(input_text=total_text) + input_ids = self.tokenizer.encode(prompt) + input_ids = torch.tensor([input_ids], dtype=torch.int32) + print(input_ids.shape, "before cat") + input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) + print(input_ids.shape, "after cat", prompt_speech_tokens.shape) + return input_ids + + def _extract_spk_embedding(self, speech): + feat = kaldi.fbank(speech, + num_mel_bins=80, + dither=0, + sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = self.campplus_session.run(None, + {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() + embedding = torch.tensor([embedding]).to(self.device).half() + return embedding + + + def _extract_speech_feat(self, speech): + speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + return speech_feat + + def execute(self, requests): + """Execute inference on the batched requests. 
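+
+        Per request: tokenize the reference audio (audio_tokenizer), extract a 24 kHz mel
+        prompt and a speaker embedding (campplus.onnx), generate target speech tokens with
+        the tensorrt_llm model, then synthesize the waveform with token2wav. Only the
+        non-decoupled (offline) path is implemented; the decoupled branch raises
+        NotImplementedError.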
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + responses = [] + + for request in requests: + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + + # Process reference audio through audio tokenizer + + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + + # TODO: FIX ME + wav_tensor = wav.as_numpy() + print(wav_tensor.shape, "wav_tensor") + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + print(wav_tensor.shape, "wav_tensor after") + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + print(speech_feat.shape, "speech_feat") + print(prompt_speech_tokens.shape, "prompt_speech_tokens here") + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + print(prompt_speech_tokens.shape, "prompt_speech_tokens after") + print(speech_feat.shape, "speech_feat after") + print(token_len, "token_len") + + # Extract text inputs + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + # Prepare prompt for LLM + input_ids = self.parse_input( + text=target_text, + prompt_text=reference_text, + prompt_speech_tokens=prompt_speech_tokens, + ) + + # Generate semantic tokens with LLM + generated_ids_iter = self.forward_llm(input_ids) + + if self.decoupled: + response_sender = request.get_response_sender() + request_id = request.request_id() + for generated_ids in generated_ids_iter: + raise NotImplementedError("Decoupled mode is not implemented") + else: + generated_ids = next(generated_ids_iter) + generated_ids = torch.tensor([generated_ids]).to(self.device) + if generated_ids is None or len(generated_ids) == 0: + raise pb_utils.TritonModelException("Generated IDs is None or empty") + + prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) + audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + responses.append(inference_response) + + if self.decoupled: + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info(f"send tritonserver_response_complete_final to end") + + if not self.decoupled: + return responses \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt new file mode 100644 index 0000000..c370336 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt @@ -0,0 +1,70 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "cosyvoice2" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +model_transaction_policy { + decoupled: ${decoupled_mode} +} +parameters [ + { + key: "llm_tokenizer_dir", + value: {string_value:"${llm_tokenizer_dir}"} + }, + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: ${bls_instance_num} + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep b/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt b/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt new file mode 100644 index 0000000..3da7894 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt @@ -0,0 +1,857 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "tensorrt_llm" +backend: "${triton_backend}" +max_batch_size: ${triton_max_batch_size} + +model_transaction_policy { + decoupled: ${decoupled_mode} +} + +dynamic_batching { + preferred_batch_size: [ ${triton_max_batch_size} ] + max_queue_delay_microseconds: ${max_queue_delay_microseconds} + default_queue_policy: { max_queue_size: ${max_queue_size} } +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + allow_ragged_batch: true + optional: true + }, + { + name: "encoder_input_features" + data_type: ${encoder_input_features_data_type} + dims: [ -1, -1 ] + allow_ragged_batch: true + optional: true + }, + { + name: "encoder_output_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "num_return_sequences" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "draft_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "decoder_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "decoder_input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + reshape: { shape: [ ] } + }, + { + name: "draft_logits" + data_type: ${logits_datatype} + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "draft_acceptance_threshold" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "bad_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "embedding_bias" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_min" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_decay" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_reset_ids" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "early_stopping" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: 
"beam_search_diversity_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_perf_metrics" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "exclude_input_in_output" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_table_extra_ids" + data_type: TYPE_UINT64 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]` + { + name: "cross_attention_mask" + data_type: TYPE_BOOL + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # Mrope param when mrope is used + { + name: "mrope_rotary_cos_sin" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + }, + { + name: "mrope_position_deltas" + data_type: TYPE_INT64 + dims: [ 1 ] + optional: true + }, + # the unique task ID for the given LoRA. + # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given. + # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. + # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached. + { + name: "lora_task_id" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] + # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer + # each of the in / out tensors are first flattened and then concatenated together in the format above. + # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out. 
+ { + name: "lora_weights" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # module identifier (same size a first dimension of lora_weights) + # See LoraModule::ModuleType for model id mapping + # + # "attn_qkv": 0 # compbined qkv adapter + # "attn_q": 1 # q adapter + # "attn_k": 2 # k adapter + # "attn_v": 3 # v adapter + # "attn_dense": 4 # adapter for the dense layer in attention + # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection + # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection + # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate + # + # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ] + { + name: "lora_config" + data_type: TYPE_INT32 + dims: [ -1, 3 ] + optional: true + allow_ragged_batch: true + }, + { + name: "context_phase_params" + data_type: TYPE_UINT8 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama + { + name: "skip_cross_attn_blocks" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_starts" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_ends" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_priorities" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_durations_ms" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_decode_priority" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_decode_duration_ms" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "guided_decoding_guide_type" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "guided_decoding_guide" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_window_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_ngram_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_verification_set_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "sequence_length" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: ${logits_datatype} + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: ${logits_datatype} + dims: [ -1, -1, -1 ] + }, + { + name: "batch_index" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "sequence_index" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "context_phase_params" + data_type: TYPE_UINT8 + dims: [ -1 ] + }, + { + name: "kv_cache_alloc_new_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "kv_cache_reused_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: 
"kv_cache_alloc_total_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "${max_beam_width}" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "${batching_strategy}" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "${engine_dir}" + } +} +parameters: { + key: "encoder_model_path" + value: { + string_value: "${encoder_engine_dir}" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "max_attention_window_size" + value: { + string_value: "${max_attention_window_size}" + } +} +parameters: { + key: "sink_token_length" + value: { + string_value: "${sink_token_length}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "${batch_scheduler_policy}" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "${kv_cache_free_gpu_mem_fraction}" + } +} +parameters: { + key: "cross_kv_cache_fraction" + value: { + string_value: "${cross_kv_cache_fraction}" + } +} +parameters: { + key: "kv_cache_host_memory_bytes" + value: { + string_value: "${kv_cache_host_memory_bytes}" + } +} +# kv_cache_onboard_blocks is for internal implementation. 
+parameters: { + key: "kv_cache_onboard_blocks" + value: { + string_value: "${kv_cache_onboard_blocks}" + } +} +# enable_trt_overlap is deprecated and doesn't have any effect on the runtime +# parameters: { +# key: "enable_trt_overlap" +# value: { +# string_value: "${enable_trt_overlap}" +# } +# } +parameters: { + key: "exclude_input_in_output" + value: { + string_value: "${exclude_input_in_output}" + } +} +parameters: { + key: "cancellation_check_period_ms" + value: { + string_value: "${cancellation_check_period_ms}" + } +} +parameters: { + key: "stats_check_period_ms" + value: { + string_value: "${stats_check_period_ms}" + } +} +parameters: { + key: "iter_stats_max_iterations" + value: { + string_value: "${iter_stats_max_iterations}" + } +} +parameters: { + key: "request_stats_max_iterations" + value: { + string_value: "${request_stats_max_iterations}" + } +} +parameters: { + key: "enable_kv_cache_reuse" + value: { + string_value: "${enable_kv_cache_reuse}" + } +} +parameters: { + key: "normalize_log_probs" + value: { + string_value: "${normalize_log_probs}" + } +} +parameters: { + key: "enable_chunked_context" + value: { + string_value: "${enable_chunked_context}" + } +} +parameters: { + key: "gpu_device_ids" + value: { + string_value: "${gpu_device_ids}" + } +} +parameters: { + key: "participant_ids" + value: { + string_value: "${participant_ids}" + } +} +parameters: { + key: "lora_cache_optimal_adapter_size" + value: { + string_value: "${lora_cache_optimal_adapter_size}" + } +} +parameters: { + key: "lora_cache_max_adapter_size" + value: { + string_value: "${lora_cache_max_adapter_size}" + } +} +parameters: { + key: "lora_cache_gpu_memory_fraction" + value: { + string_value: "${lora_cache_gpu_memory_fraction}" + } +} +parameters: { + key: "lora_cache_host_memory_bytes" + value: { + string_value: "${lora_cache_host_memory_bytes}" + } +} +parameters: { + key: "lora_prefetch_dir" + value: { + string_value: "${lora_prefetch_dir}" + } +} +parameters: { + key: "decoding_mode" + value: { + string_value: "${decoding_mode}" + } +} +parameters: { + key: "executor_worker_path" + value: { + string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker" + } +} +parameters: { + key: "lookahead_window_size" + value: { + string_value: "${lookahead_window_size}" + } +} +parameters: { + key: "lookahead_ngram_size" + value: { + string_value: "${lookahead_ngram_size}" + } +} +parameters: { + key: "lookahead_verification_set_size" + value: { + string_value: "${lookahead_verification_set_size}" + } +} +parameters: { + key: "medusa_choices" + value: { + string_value: "${medusa_choices}" + } +} +parameters: { + key: "eagle_choices" + value: { + string_value: "${eagle_choices}" + } +} +parameters: { + key: "gpu_weights_percent" + value: { + string_value: "${gpu_weights_percent}" + } +} +parameters: { + key: "enable_context_fmha_fp32_acc" + value: { + string_value: "${enable_context_fmha_fp32_acc}" + } +} +parameters: { + key: "multi_block_mode" + value: { + string_value: "${multi_block_mode}" + } +} +parameters: { + key: "cuda_graph_mode" + value: { + string_value: "${cuda_graph_mode}" + } +} +parameters: { + key: "cuda_graph_cache_size" + value: { + string_value: "${cuda_graph_cache_size}" + } +} +parameters: { + key: "speculative_decoding_fast_logits" + value: { + string_value: "${speculative_decoding_fast_logits}" + } +} +parameters: { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} +parameters: { + key: "guided_decoding_backend" + value: { + string_value: 
"${guided_decoding_backend}" + } +} +parameters: { + key: "xgrammar_tokenizer_info_path" + value: { + string_value: "${xgrammar_tokenizer_info_path}" + } +} diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py new file mode 100644 index 0000000..1232ba9 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -0,0 +1,198 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import os + +import logging +from typing import List, Dict + +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +from hyperpyyaml import load_hyperpyyaml +from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm +from cosyvoice.utils.common import TrtContextWrapper +#import sys +#sys.path.append("/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS") + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +class CosyVoice2: + + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): + + self.model_dir = model_dir + self.fp16 = fp16 + + hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: + configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) + self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16) + self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + if load_jit: + self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) + if load_trt: + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + trt_concurrent, + self.fp16) + +class CosyVoice2Model: + + def __init__(self, + flow: torch.nn.Module, + hift: torch.nn.Module, + fp16: bool = False): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.flow = flow + self.hift = hift + self.fp16 = fp16 + if self.fp16 is True: + self.flow.half() + + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder + + def load(self, flow_model, hift_model): + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) + self.flow.to(self.device).eval() + # in case hift_model is a hifigan model + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' 
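+        # Convert the ONNX flow-decoder estimator to a TensorRT plan on first
+        # use (or when the cached plan file is empty), then deserialize it.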
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs(self): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] + opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] + max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] + input_names = ["x", "mask", "mu", "cond"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + +class TritonPythonModel: + """Triton Python model for vocoder. + + This model takes global and semantic tokens as input and generates audio waveforms + using the BiCodec vocoder. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {key: value["string_value"] for key, value in parameters.items()} + model_dir = model_params["model_dir"] + + # Initialize device and vocoder + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Initializing vocoder from {model_dir} on {self.device}") + + self.token2wav_model = CosyVoice2( + model_dir, load_jit=True, load_trt=True, fp16=True + ) + + logger.info("Token2Wav initialized successfully") + + + def execute(self, requests): + """Execute inference on the batched requests. 
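+        Incoming speech tokens are expressed in the LLM vocabulary and are
+        shifted back by the codec offset (151663) before the flow-matching
+        decoder produces a mel spectrogram, which the HiFT generator converts
+        to a waveform.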
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated waveforms + """ + responses = [] + # Process each request in batch + for request in requests: + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() + prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy() + prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() + prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() + + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) + prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) + prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) + prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) + + prompt_speech_tokens = prompt_speech_tokens - 151663 + target_speech_tokens = target_speech_tokens - 151663 + + tts_mel, _ = self.token2wav_model.model.flow.inference( + token=target_speech_tokens, + token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_token=prompt_speech_tokens, + prompt_token_len=torch.tensor( + [prompt_speech_tokens.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_speech_feat, + prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=prompt_spk_embedding, + streaming=False, + finalize=True, + ) + + audio_hat, _ = self.token2wav_model.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + generated_wave = audio_hat.squeeze(0).cpu().numpy() + + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) + inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) + responses.append(inference_response) + + return responses + + + + diff --git a/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt new file mode 100644 index 0000000..36489ff --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
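+
+# Inputs mirror the tensors sent by the cosyvoice2 BLS model: speech tokens in
+# the LLM vocabulary, 80-dim prompt mel features (fp16) and a speaker
+# embedding; the output is the synthesized 24 kHz waveform.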
+ +name: "token2wav" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "target_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "prompt_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "prompt_speech_feat" + data_type: TYPE_FP16 + dims: [-1, 80] + }, + { + name: "prompt_spk_embedding" + data_type: TYPE_FP16 + dims: [-1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/requirements.txt b/runtime/triton_trtllm/requirements.txt new file mode 100644 index 0000000..9f675b6 --- /dev/null +++ b/runtime/triton_trtllm/requirements.txt @@ -0,0 +1,13 @@ +hyperpyyaml +s3tokenizer +onnxruntime-gpu +omegaconf +conformer +hydra-core +lightning +gdown +wget +librosa +pyworld +openai-whisper +tritonclient \ No newline at end of file diff --git a/runtime/triton_trtllm/run.sh b/runtime/triton_trtllm/run.sh new file mode 100644 index 0000000..3e4a1a8 --- /dev/null +++ b/runtime/triton_trtllm/run.sh @@ -0,0 +1,92 @@ +# huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm +# modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/ +# git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git +# cd CosyVoice +# git submodule update --init --recursive +export CUDA_VISIBLE_DEVICES=0 +export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice:$PYTHONPATH +export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS:$PYTHONPATH +stage=$1 +stop_stage=$2 + +huggingface_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/cosyvoice2_llm +model_scope_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice2-0.5B +trt_dtype=bfloat16 +trt_dtype=float16 +trt_weights_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_weights_${trt_dtype} +trt_engines_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_engines_${trt_dtype} + +model_repo=./model_repo_cosyvoice2 +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "Converting checkpoint to TensorRT weights" + python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \ + --output_dir $trt_weights_dir \ + --dtype $trt_dtype || exit 1 + + echo "Building TensorRT engines" + trtllm-build --checkpoint_dir $trt_weights_dir \ + --output_dir $trt_engines_dir \ + --max_batch_size 16 \ + --max_num_tokens 32768 \ + --gemm_plugin $trt_dtype || exit 1 +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "Testing TensorRT engines" + python3 ./test_llm.py --input_text "你好,请问你叫什么?" 
\ + --tokenizer_dir $huggingface_model_local_dir \ + --top_k 50 --top_p 0.95 --temperature 0.8 \ + --engine_dir=$trt_engines_dir || exit 1 +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Creating model repository" + rm -rf $model_repo + mkdir -p $model_repo + cosyvoice2_dir="cosyvoice2" + + cp -r ./model_repo/${cosyvoice2_dir} $model_repo + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/tensorrt_llm $model_repo + cp -r ./model_repo/token2wav $model_repo + + ENGINE_PATH=$trt_engines_dir + MAX_QUEUE_DELAY_MICROSECONDS=0 + MODEL_DIR=$model_scope_model_local_dir + LLM_TOKENIZER_DIR=$huggingface_model_local_dir + BLS_INSTANCE_NUM=4 + TRITON_MAX_BATCH_SIZE=16 + DECOUPLED_MODE=False + + python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 + +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + + tritonserver --model-repository $model_repo +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Testing TensorRT engines" + python3 client_http.py \ + --reference-audio ./prompt_audio.wav \ + --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ + --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ + --model-name cosyvoice2 +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "Running benchmark client" + num_task=4 + python3 client_grpc.py \ + --server-addr localhost \ + --model-name cosyvoice2 \ + --num-tasks $num_task \ + --mode offline \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_concurrent_tasks_${num_task}_offline_bls_4_${trt_dtype} +fi \ No newline at end of file diff --git a/runtime/triton_trtllm/scripts/convert_checkpoint.py b/runtime/triton_trtllm/scripts/convert_checkpoint.py new file mode 100644 index 0000000..932cdf8 --- /dev/null +++ b/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -0,0 +1,342 @@ +import argparse +import os +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +from transformers import AutoConfig + +import tensorrt_llm +from tensorrt_llm._utils import release_gc +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import QWenForCausalLM +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization import 
QuantAlgo + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None, required=True) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--cp_size', + type=int, + default=1, + help='N-way context parallelism size') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'float16', 'bfloat16', 'float32'], + help= + "The data type for the model weights and activations if not quantized. " + "If 'auto', the data type is automatically inferred from the source model; " + "however, if the source dtype is float32, it is converted to float16.") + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--disable_weight_only_quant_plugin', + default=False, + action="store_true", + help= + 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4', 'int4_gptq'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--calib_dataset', + type=str, + default='ccdv/cnn_dailymail', + help= + "The huggingface dataset name or the local directory of the dataset for calibration." + ) + parser.add_argument( + "--smoothquant", + "-sq", + type=float, + default=None, + help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" + " to Smoothquant the model, and output int8 weights." + " A good first try is 0.5. Must be in [0, 1]") + parser.add_argument( + '--per_channel', + action="store_true", + default=False, + help= + 'By default, we use a single static scaling factor for the GEMM\'s result. ' + 'per_channel instead uses a different static scaling factor for each channel. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--per_token', + action="store_true", + default=False, + help= + 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + 'per_token chooses at run time, and for each token, a custom scaling factor. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--int8_kv_cache', + default=False, + action="store_true", + help= + 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + ) + parser.add_argument( + '--per_group', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor to scale weights in the int4 range. ' + 'per_group chooses at run time, and for each group, a custom scaling factor. 
' + 'The flag is built for GPTQ/AWQ quantization.') + + parser.add_argument('--group_size', + type=int, + default=128, + help='Group size used in GPTQ quantization.') + + parser.add_argument("--load_model_on_cpu", action="store_true") + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument( + '--moe_tp_size', + type=int, + default=-1, + help= + 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' + ) + parser.add_argument( + '--moe_ep_size', + type=int, + default=-1, + help= + 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' + ) + args = parser.parse_args() + return args + + +def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: + '''return config dict with quantization info based on the command line args + ''' + quant_config = QuantConfig() + if args.use_weight_only: + if args.weight_only_precision == 'int8': + quant_config.quant_algo = QuantAlgo.W8A16 + elif args.weight_only_precision == 'int4': + quant_config.quant_algo = QuantAlgo.W4A16 + elif args.smoothquant: + quant_config.smoothquant_val = args.smoothquant + if args.per_channel: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN + else: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN + + if args.int8_kv_cache: + quant_config.kv_cache_quant_algo = QuantAlgo.INT8 + + if args.weight_only_precision == 'int4_gptq': + quant_config.group_size = args.group_size + quant_config.has_zero_point = True + quant_config.pre_quant_scale = False + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + + return quant_config + + +def update_quant_config_from_hf(quant_config, hf_config, + override_fields) -> tuple[QuantConfig, dict]: + hf_config_dict = hf_config.to_dict() + if hf_config_dict.get('quantization_config'): + # update the quant_algo, and clamp_val. 
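+        # Map HF AWQ/GPTQ settings onto the TensorRT-LLM QuantConfig; both are
+        # treated as W4A16 group-wise (GPTQ-style) quantization here.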
+ if hf_config_dict['quantization_config'].get('quant_method') == 'awq': + logger.info( + "Load quantization configs from huggingface model_config.") + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + quant_config.group_size = hf_config_dict['quantization_config'].get( + 'group_size', 128) + quant_config.has_zero_point = hf_config_dict[ + 'quantization_config'].get('zero_point', False) + override_fields.update({"use_autoawq": True}) + elif hf_config_dict['quantization_config'].get( + 'quant_method') == 'gptq': + logger.info( + "Load quantization configs from huggingface model_config.") + desc_act = hf_config_dict['quantization_config'].get( + 'desc_act', False) + if desc_act: + raise ValueError("GPTQ with desc_act=True is not implemented!") + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + quant_config.group_size = hf_config_dict['quantization_config'].get( + 'group_size', 128) + quant_config.has_zero_point = hf_config_dict[ + 'quantization_config'].get('sym', False) + return quant_config, override_fields + + +def args_to_build_options(args): + return { + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'disable_weight_only_quant_plugin': + args.disable_weight_only_quant_plugin + } + + +def convert_and_save_hf(args): + model_dir = args.model_dir + world_size = args.tp_size * args.pp_size + # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. + # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, + # before the refactor is done. + override_fields = {} + override_fields.update(args_to_build_options(args)) + quant_config = args_to_quant_config(args) + + try: + hf_config = AutoConfig.from_pretrained(model_dir, + trust_remote_code=True) + quant_config, override_fields = update_quant_config_from_hf( + quant_config, hf_config, override_fields) + except: + logger.warning("AutoConfig cannot load the huggingface config.") + + if args.smoothquant is not None or args.int8_kv_cache: + mapping = Mapping(world_size=world_size, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size, + cp_size=args.cp_size) + QWenForCausalLM.quantize(args.model_dir, + args.output_dir, + dtype=args.dtype, + mapping=mapping, + quant_config=quant_config, + calib_dataset=args.calib_dataset, + **override_fields) + else: + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size) + qwen = QWenForCausalLM.from_hugging_face(model_dir, + args.dtype, + mapping=mapping, + quant_config=quant_config, + **override_fields) + qwen.config.mapping.cp_size = args.cp_size + qwen.config.mapping.attn_tp_size = -1 + qwen.config.mapping.attn_cp_size = -1 + qwen.config.mapping.world_size *= args.cp_size + qwen.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del qwen + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + 
assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + +def main(): + print(tensorrt_llm.__version__) + args = parse_arguments() + + if (args.moe_tp_size == -1 and args.moe_ep_size == -1): + # moe default to tp-only + args.moe_tp_size = args.tp_size + args.moe_ep_size = 1 + elif (args.moe_tp_size == -1): + args.moe_tp_size = args.tp_size // args.moe_ep_size + elif (args.moe_ep_size == -1): + args.moe_ep_size = args.tp_size // args.moe_tp_size + assert (args.moe_tp_size * args.moe_ep_size == args.tp_size + ), "moe_tp_size * moe_ep_size must equal to tp_size" + + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + assert args.model_dir is not None + convert_and_save_hf(args) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/runtime/triton_trtllm/scripts/fill_template.py b/runtime/triton_trtllm/scripts/fill_template.py new file mode 100644 index 0000000..5c629f7 --- /dev/null +++ b/runtime/triton_trtllm/scripts/fill_template.py @@ -0,0 +1,70 @@ +#! /usr/bin/env python3 +from argparse import ArgumentParser +from string import Template + + +def split(string, delimiter): + """Split a string using delimiter. Supports escaping. + + Args: + string (str): The string to split. + delimiter (str): The delimiter to split the string with. + + Returns: + list: A list of strings. + """ + result = [] + current = "" + escape = False + for char in string: + if escape: + current += char + escape = False + elif char == delimiter: + result.append(current) + current = "" + elif char == "\\": + escape = True + else: + current += char + result.append(current) + return result + + +def main(file_path, substitutions, in_place): + with open(file_path) as f: + pbtxt = Template(f.read()) + + sub_dict = { + "max_queue_size": 0, + 'max_queue_delay_microseconds': 0, + } + for sub in split(substitutions, ","): + key, value = split(sub, ":") + sub_dict[key] = value + + assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}." + + pbtxt = pbtxt.safe_substitute(sub_dict) + + if in_place: + with open(file_path, "w") as f: + f.write(pbtxt) + else: + print(pbtxt) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("file_path", help="path of the .pbtxt to modify") + parser.add_argument( + "substitutions", + help= + "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." + ) + parser.add_argument("--in_place", + "-i", + action="store_true", + help="do the operation in-place") + args = parser.parse_args() + main(**vars(args)) diff --git a/runtime/triton_trtllm/scripts/test_llm.py b/runtime/triton_trtllm/scripts/test_llm.py new file mode 100644 index 0000000..9ffe9cf --- /dev/null +++ b/runtime/triton_trtllm/scripts/test_llm.py @@ -0,0 +1,144 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import ast +import csv +import os +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch + +import tensorrt_llm +from tensorrt_llm.logger import logger + +from tensorrt_llm.runtime import ModelRunnerCpp +from transformers import AutoTokenizer + + +def parse_arguments(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '--input_text', + type=str, + nargs='+', + default=["Born in north-east France, Soyer trained as a"]) + parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") + parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") + parser.add_argument('--log_level', type=str, default="debug") + parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6) + parser.add_argument('--temperature', type=float, default=0.8) + parser.add_argument('--top_k', type=int, default=50) + parser.add_argument('--top_p', type=float, default=0.95) + + + return parser.parse_args(args=args) + + +def parse_input(tokenizer, + input_text=None, + prompt_template=None): + batch_input_ids = [] + for curr_text in input_text: + if prompt_template is not None: + curr_text = prompt_template.format(input_text=curr_text) + input_ids = tokenizer.encode( + curr_text) + batch_input_ids.append(input_ids) + + batch_input_ids = [ + torch.tensor(x, dtype=torch.int32) for x in batch_input_ids + ] + + logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):") + for i, input_ids in enumerate(batch_input_ids): + logger.debug(f"Request {i}: {input_ids.tolist()}") + + return batch_input_ids + + +def main(args): + runtime_rank = tensorrt_llm.mpi_rank() + logger.set_level(args.log_level) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) + prompt_template = "<|sos|>{input_text}<|task_id|>" + end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") + + batch_input_ids = parse_input(tokenizer=tokenizer, + input_text=args.input_text, + prompt_template=prompt_template) + + input_lengths = [x.size(0) for x in batch_input_ids] + + runner_kwargs = dict( + engine_dir=args.engine_dir, + rank=runtime_rank, + max_output_len=1024, + enable_context_fmha_fp32_acc=False, + max_batch_size=len(batch_input_ids), + max_input_len=max(input_lengths), + kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction, + cuda_graph_mode=False, + gather_generation_logits=False, + ) + + runner = ModelRunnerCpp.from_dir(**runner_kwargs) + + with torch.no_grad(): + outputs = runner.generate( + batch_input_ids=batch_input_ids, + max_new_tokens=1024, + end_id=end_id, + pad_id=end_id, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + num_return_sequences=1, + repetition_penalty=1.1, + random_seed=42, + streaming=False, + output_sequence_lengths=True, + output_generation_logits=False, + return_dict=True, + return_all_generated_tokens=False) + torch.cuda.synchronize() + output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"] + num_output_sents, num_beams, _ = output_ids.size() + assert num_beams 
== 1 + beam = 0 + batch_size = len(input_lengths) + num_return_sequences = num_output_sents // batch_size + assert num_return_sequences == 1 + for i in range(batch_size * num_return_sequences): + batch_idx = i // num_return_sequences + seq_idx = i % num_return_sequences + inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist() + input_text = tokenizer.decode(inputs) + print(f'Input [Text {batch_idx}]: \"{input_text}\"') + output_begin = input_lengths[batch_idx] + output_end = sequence_lengths[i][beam] + outputs = output_ids[i][beam][output_begin:output_end].tolist() + output_text = tokenizer.decode(outputs) + print(f'Output [Text {batch_idx}]: \"{output_text}\"') + logger.debug(str(outputs)) + + +if __name__ == '__main__': + args = parse_arguments() + main(args) From 178da09993e9b794bc1f0e726e35b7546059bd14 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Sun, 27 Jul 2025 23:33:10 -0700 Subject: [PATCH 2/6] clean code --- runtime/triton_trtllm/client_grpc.py | 2 +- runtime/triton_trtllm/client_http.py | 6 +++--- .../model_repo/audio_tokenizer/1/model.py | 3 ++- .../model_repo/cosyvoice2/1/model.py | 14 +------------- .../model_repo/token2wav/1/model.py | 9 ++++----- runtime/triton_trtllm/run.sh | 18 ++++++++++++------ 6 files changed, 23 insertions(+), 29 deletions(-) diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 19f13e2..7dba493 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -692,7 +692,7 @@ async def main(): model_name=args.model_name, audio_save_dir=args.log_dir, padding_duration=10, - save_sample_rate=24000 if args.model_name == "f5_tts" else 16000, + save_sample_rate=16000 if args.model_name == "spark_tts" else 24000, chunk_overlap_duration=args.chunk_overlap_duration, ) ) diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py index 970d417..e22f4eb 100644 --- a/runtime/triton_trtllm/client_http.py +++ b/runtime/triton_trtllm/client_http.py @@ -162,8 +162,8 @@ if __name__ == "__main__": result = rsp.json() audio = result["outputs"][0]["data"] audio = np.array(audio, dtype=np.float32) - if args.model_name == "cosyvoice2": - sample_rate = 24000 - else: + if args.model_name == "spark_tts": sample_rate = 16000 + else: + sample_rate = 24000 sf.write(args.output_audio, audio, sample_rate, "PCM_16") \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py index a197ec9..105ffa1 100644 --- a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -33,6 +33,7 @@ import os import numpy as np import s3tokenizer +ORIGINAL_VOCAB_SIZE = 151663 class TritonPythonModel: """Triton Python model for audio tokenization. 
@@ -81,7 +82,7 @@ class TritonPythonModel: mels, mels_lens = s3tokenizer.padding(mels) codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device)) - codes = codes.clone() + 151663 + codes = codes.clone() + ORIGINAL_VOCAB_SIZE responses = [] for i in range(len(requests)): diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index 97b9fee..cd61931 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -199,8 +199,6 @@ class TritonPythonModel: Returns: Generated waveform tensor """ - print(prompt_speech_tokens.shape, prompt_speech_feat.shape, prompt_spk_embedding.shape, target_speech_tokens.shape) - # Convert tensors to Triton format prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) @@ -228,9 +226,7 @@ class TritonPythonModel: prompt = self.prompt_template.format(input_text=total_text) input_ids = self.tokenizer.encode(prompt) input_ids = torch.tensor([input_ids], dtype=torch.int32) - print(input_ids.shape, "before cat") input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) - print(input_ids.shape, "after cat", prompt_speech_tokens.shape) return input_ids def _extract_spk_embedding(self, speech): @@ -271,23 +267,15 @@ class TritonPythonModel: prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - # TODO: FIX ME + wav_tensor = wav.as_numpy() - print(wav_tensor.shape, "wav_tensor") wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - print(wav_tensor.shape, "wav_tensor after") prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) speech_feat = self._extract_speech_feat(prompt_speech_resample) - print(speech_feat.shape, "speech_feat") - print(prompt_speech_tokens.shape, "prompt_speech_tokens here") token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - print(prompt_speech_tokens.shape, "prompt_speech_tokens after") - print(speech_feat.shape, "speech_feat after") - print(token_len, "token_len") - # Extract text inputs reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index 1232ba9..d6735a1 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -38,13 +38,11 @@ import triton_python_backend_utils as pb_utils from hyperpyyaml import load_hyperpyyaml from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm from cosyvoice.utils.common import TrtContextWrapper -#import sys -#sys.path.append("/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS") -# Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = 
logging.getLogger(__name__) +ORIGINAL_VOCAB_SIZE = 151663 class CosyVoice2: @@ -162,8 +160,9 @@ class TritonPythonModel: prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) - prompt_speech_tokens = prompt_speech_tokens - 151663 - target_speech_tokens = target_speech_tokens - 151663 + # shift the speech tokens according to the original vocab size + prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE + target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE tts_mel, _ = self.token2wav_model.model.flow.inference( token=target_speech_tokens, diff --git a/runtime/triton_trtllm/run.sh b/runtime/triton_trtllm/run.sh index 3e4a1a8..10e6a67 100644 --- a/runtime/triton_trtllm/run.sh +++ b/runtime/triton_trtllm/run.sh @@ -1,8 +1,4 @@ -# huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm -# modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/ -# git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git -# cd CosyVoice -# git submodule update --init --recursive + export CUDA_VISIBLE_DEVICES=0 export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice:$PYTHONPATH export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS:$PYTHONPATH @@ -12,11 +8,21 @@ stop_stage=$2 huggingface_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/cosyvoice2_llm model_scope_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice2-0.5B trt_dtype=bfloat16 -trt_dtype=float16 trt_weights_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_weights_${trt_dtype} trt_engines_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_engines_${trt_dtype} model_repo=./model_repo_cosyvoice2 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo " " + huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm + modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/ + git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git + cd CosyVoice + git submodule update --init --recursive +fi + + if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then echo "Converting checkpoint to TensorRT weights" python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \ From dc196df9402507bf44e2ac81e6caae2c140a4393 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Tue, 29 Jul 2025 11:13:07 +0800 Subject: [PATCH 3/6] fix decoupled mode --- .../model_repo/cosyvoice2/1/model.py | 25 +++++--- runtime/triton_trtllm/requirements.txt | 3 +- runtime/triton_trtllm/run.sh | 57 +++++++++++-------- 3 files changed, 52 insertions(+), 33 deletions(-) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index cd61931..cb91677 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -295,11 +295,26 @@ class TritonPythonModel: if self.decoupled: response_sender = request.get_response_sender() request_id = request.request_id() - for generated_ids in generated_ids_iter: - raise NotImplementedError("Decoupled mode is not implemented") + generated_ids = [] + for generated_id in generated_ids_iter: + # convert the numpy array into a int32 tensor + generated_id = generated_id.tolist() + if len(generated_id) > 0: + assert len(generated_id) == 1, "Generated ID is not a single integer" + 
generated_ids.append(generated_id[0]) + generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device) + prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) + audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info(f"send tritonserver_response_complete_final to end") else: generated_ids = next(generated_ids_iter) - generated_ids = torch.tensor([generated_ids]).to(self.device) + generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) if generated_ids is None or len(generated_ids) == 0: raise pb_utils.TritonModelException("Generated IDs is None or empty") @@ -311,9 +326,5 @@ class TritonPythonModel: inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) responses.append(inference_response) - if self.decoupled: - response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info(f"send tritonserver_response_complete_final to end") - if not self.decoupled: return responses \ No newline at end of file diff --git a/runtime/triton_trtllm/requirements.txt b/runtime/triton_trtllm/requirements.txt index 9f675b6..ba0623b 100644 --- a/runtime/triton_trtllm/requirements.txt +++ b/runtime/triton_trtllm/requirements.txt @@ -10,4 +10,5 @@ wget librosa pyworld openai-whisper -tritonclient \ No newline at end of file +tritonclient +modelscope diff --git a/runtime/triton_trtllm/run.sh b/runtime/triton_trtllm/run.sh index 10e6a67..922105d 100644 --- a/runtime/triton_trtllm/run.sh +++ b/runtime/triton_trtllm/run.sh @@ -1,25 +1,31 @@ export CUDA_VISIBLE_DEVICES=0 -export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice:$PYTHONPATH -export PYTHONPATH=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS:$PYTHONPATH +cosyvoice_path=/workspace/CosyVoice +export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH +export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH stage=$1 stop_stage=$2 -huggingface_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/cosyvoice2_llm -model_scope_model_local_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice2-0.5B +huggingface_model_local_dir=./cosyvoice2_llm +model_scope_model_local_dir=./CosyVoice2-0.5B trt_dtype=bfloat16 -trt_weights_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_weights_${trt_dtype} -trt_engines_dir=/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/trt_engines_${trt_dtype} +trt_weights_dir=./trt_weights_${trt_dtype} +trt_engines_dir=./trt_engines_${trt_dtype} model_repo=./model_repo_cosyvoice2 -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - echo " " - huggingface-cli download --local-dir cosyvoice2_llm yuekai/cosyvoice2_llm - modelscope download --model iic/CosyVoice2-0.5B --local_dir ./CosyVoice2-0.5B/ - git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git - cd CosyVoice +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + echo "Cloning CosyVoice" + git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path + cd $cosyvoice_path git submodule update --init --recursive + cd runtime/triton_trtllm +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Downloading CosyVoice2-0.5B" 
+ huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir fi @@ -35,17 +41,15 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --max_batch_size 16 \ --max_num_tokens 32768 \ --gemm_plugin $trt_dtype || exit 1 -fi -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then echo "Testing TensorRT engines" - python3 ./test_llm.py --input_text "你好,请问你叫什么?" \ + python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \ --tokenizer_dir $huggingface_model_local_dir \ --top_k 50 --top_p 0.95 --temperature 0.8 \ --engine_dir=$trt_engines_dir || exit 1 fi -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then echo "Creating model repository" rm -rf $model_repo mkdir -p $model_repo @@ -71,28 +75,31 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Starting Triton server" tritonserver --model-repository $model_repo fi -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - echo "Testing TensorRT engines" +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + echo "Single request test http" python3 client_http.py \ - --reference-audio ./prompt_audio.wav \ + --reference-audio ./assets/prompt_audio.wav \ --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ --model-name cosyvoice2 fi -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - echo "Running benchmark client" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Running benchmark client grpc" num_task=4 + # set mode=streaming, when decoupled=True + # set mode=offline, when decoupled=False + mode=offline python3 client_grpc.py \ --server-addr localhost \ --model-name cosyvoice2 \ --num-tasks $num_task \ - --mode offline \ + --mode $mode \ --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_concurrent_tasks_${num_task}_offline_bls_4_${trt_dtype} + --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype} fi \ No newline at end of file From b44f12110224cb11c03aee4084b1597e7b9331cb Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Tue, 29 Jul 2025 11:58:23 +0800 Subject: [PATCH 4/6] update readme --- runtime/triton_trtllm/README.md | 115 +++++++++++------------ runtime/triton_trtllm/docker-compose.yml | 4 +- 2 files changed, 57 insertions(+), 62 deletions(-) diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md index 999d698..0fc7c48 100644 --- a/runtime/triton_trtllm/README.md +++ b/runtime/triton_trtllm/README.md @@ -1,94 +1,89 @@ -## Nvidia Triton Inference Serving Best Practice for Spark TTS +## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server ### Quick Start -Directly launch the service using docker compose. +Launch the service directly with Docker Compose: ```sh docker compose up ``` -### Build Image -Build the docker image from scratch. +### Build the Docker Image +Build the image from scratch: ```sh -docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02 +docker build . 
-f Dockerfile.server -t soar97/triton-cosyvoice:25.06 ``` -### Create Docker Container +### Run a Docker Container ```sh your_mount_dir=/mnt:/mnt -docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02 +docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06 ``` ### Understanding `run.sh` +The `run.sh` script orchestrates the entire workflow through numbered stages. -The `run.sh` script automates various steps using stages. You can run specific stages using: +Run a subset of stages with: ```sh bash run.sh [service_type] ``` -- ``: The stage to begin execution from (0-5). -- ``: The stage to end execution at (0-5). -- `[service_type]`: Optional, specifies the service type ('streaming' or 'offline', defaults may apply based on script logic). Required for stages 4 and 5. +- `` – stage to start from (0-5). +- `` – stage to stop after (0-5). Stages: -- **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace. -- **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines. -- **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline). -- **Stage 3**: Launch the Triton Inference Server. -- **Stage 4**: Run the gRPC benchmark client. -- **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline). +- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace. +- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines. +- **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later). +- **Stage 3** – Launch the Triton Inference Server. +- **Stage 4** – Run the single-utterance HTTP client. +- **Stage 5** – Run the gRPC benchmark client. -### Export Models to TensorRT-LLM and Launch Server -Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server. +### Export Models to TensorRT-LLM and Launch the Server +Inside the Docker container, prepare the models and start the Triton server by running stages 0-3: ```sh -# This runs stages 0, 1, 2, and 3 +# Runs stages 0, 1, 2, and 3 bash run.sh 0 3 ``` -*Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.* +*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.* - -### Single Utterance Client -Run a single inference request. Specify `streaming` or `offline` as the third argument. - -**Streaming Mode (gRPC):** +### Single-Utterance HTTP Client +Send a single HTTP inference request: ```sh -bash run.sh 5 5 streaming -``` -This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode. - -**Offline Mode (HTTP):** -```sh -bash run.sh 5 5 offline +bash run.sh 4 4 ``` -### Benchmark using Dataset -Run the benchmark client against the running Triton server. 
Specify `streaming` or `offline` as the third argument. +### Benchmark with a Dataset +Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument. ```sh -# Run benchmark in streaming mode -bash run.sh 4 4 streaming +bash run.sh 5 5 -# Run benchmark in offline mode -bash run.sh 4 4 offline - -# You can also customize parameters like num_task directly in client_grpc.py or via args if supported -# Example from run.sh (streaming): -# python3 client_grpc.py \ -# --server-addr localhost \ -# --model-name spark_tts \ -# --num-tasks 2 \ -# --mode streaming \ -# --log-dir ./log_concurrent_tasks_2_streaming_new - -# Example customizing dataset (requires modifying client_grpc.py or adding args): -# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline] +# You can also customise parameters such as num_task and dataset split directly: +# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline] ``` +> [!TIP] +> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready. ### Benchmark Results -Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs. +Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio): + +| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | +|------|------|-------------|------------------|------------------|-----| +| Decoupled=False | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 758.04 | 615.79 | 0.0891 | +| Decoupled=False | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 1025.93 | 901.68 | 0.0657 | +| Decoupled=False | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1914.13 | 1783.58 | 0.0610 | +| Decoupled=True | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 659.87 | 655.63 | 0.0891 | +| Decoupled=True | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 1103.16 | 992.96 | 0.0693 | +| Decoupled=True | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1790.91 | 1668.63 | 0.0604 | + +### OpenAI-Compatible Server +To launch an OpenAI-compatible service, run: +```sh +git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git +pip install -r requirements.txt +# After the Triton service is up, start the FastAPI bridge: +python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000 +# Test with curl +bash test/test_cosyvoice.sh +``` + +### Acknowledgements +This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details. 
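
(Aside, not part of the patch itself: for readers who want to script a single offline request without going through `client_http.py` or `client_grpc.py`, a minimal gRPC sketch is shown below. The tensor names mirror the clients in this directory; the prompt file path, the localhost:8001 port, and the 16 kHz mono prompt are illustrative assumptions.)

```python
# Minimal offline request sketch; assumes the Triton server from run.sh stage 3
# is reachable on localhost:8001 and prompt_audio.wav is a 16 kHz mono file.
import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")

waveform, sr = sf.read("prompt_audio.wav", dtype="float32")
samples = waveform.reshape(1, -1)
lengths = np.array([[samples.shape[1]]], dtype=np.int32)

def text_input(name: str, text: str) -> grpcclient.InferInput:
    # Triton BYTES tensors are sent as an object-dtype numpy array of bytes.
    tensor = grpcclient.InferInput(name, [1, 1], "BYTES")
    tensor.set_data_from_numpy(np.array([[text.encode("utf-8")]], dtype=object))
    return tensor

wav_in = grpcclient.InferInput("reference_wav", list(samples.shape), "FP32")
wav_in.set_data_from_numpy(samples)
len_in = grpcclient.InferInput("reference_wav_len", list(lengths.shape), "INT32")
len_in.set_data_from_numpy(lengths)

result = client.infer(
    "cosyvoice2",
    inputs=[wav_in, len_in,
            text_input("reference_text", "transcript of the reference audio"),
            text_input("target_text", "text to synthesize")],
    outputs=[grpcclient.InferRequestedOutput("waveform")],
)
# CosyVoice2 output is 24 kHz, as in client_http.py.
sf.write("output.wav", result.as_numpy("waveform").reshape(-1), 24000, "PCM_16")
```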
-| Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF | -|-------|-----------|-----------------------|---------|----------------|-| -| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362| -| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737| -| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704| -| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 | -| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 | -| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 | \ No newline at end of file diff --git a/runtime/triton_trtllm/docker-compose.yml b/runtime/triton_trtllm/docker-compose.yml index eca94bc..e221e56 100644 --- a/runtime/triton_trtllm/docker-compose.yml +++ b/runtime/triton_trtllm/docker-compose.yml @@ -1,6 +1,6 @@ services: tts: - image: soar97/triton-spark-tts:25.02 + image: soar97/triton-cosyvoice:25.06 shm_size: '1gb' ports: - "8000:8000" @@ -17,4 +17,4 @@ services: device_ids: ['0'] capabilities: [gpu] command: > - /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3" + /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3" \ No newline at end of file From 1b8d194b673a83c528f3d64794f0c68780baea88 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Tue, 29 Jul 2025 12:01:55 +0800 Subject: [PATCH 5/6] fix commit --- runtime/triton_trtllm/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md index 0fc7c48..017d8fc 100644 --- a/runtime/triton_trtllm/README.md +++ b/runtime/triton_trtllm/README.md @@ -66,12 +66,12 @@ Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://hu | Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | |------|------|-------------|------------------|------------------|-----| -| Decoupled=False | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 758.04 | 615.79 | 0.0891 | -| Decoupled=False | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 1025.93 | 901.68 | 0.0657 | -| Decoupled=False | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1914.13 | 1783.58 | 0.0610 | -| Decoupled=True | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 659.87 | 655.63 | 0.0891 | -| Decoupled=True | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 1103.16 | 992.96 | 0.0693 | -| 
Decoupled=True | [Commit](https://github.com/SparkAudio/cosyvoice/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1790.91 | 1668.63 | 0.0604 | +| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 | +| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 | +| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 | +| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 | +| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 | +| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 | ### OpenAI-Compatible Server To launch an OpenAI-compatible service, run: From 07cbc51cd1d95abc354c03095b31adfd3d8737f4 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 29 Jul 2025 08:39:41 +0000 Subject: [PATCH 6/6] fix lint --- runtime/triton_trtllm/client_grpc.py | 101 +++++++++-------- runtime/triton_trtllm/client_http.py | 22 ++-- .../model_repo/audio_tokenizer/1/model.py | 21 ++-- .../model_repo/cosyvoice2/1/model.py | 106 ++++++++++-------- .../model_repo/token2wav/1/model.py | 24 ++-- .../scripts/convert_checkpoint.py | 40 +++---- .../triton_trtllm/scripts/fill_template.py | 5 +- runtime/triton_trtllm/scripts/test_llm.py | 3 +- 8 files changed, 165 insertions(+), 157 deletions(-) diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 7dba493..881b519 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) # 2023 Nvidia (authors: Yuekai Zhang) # 2023 Recurrent.ai (authors: Songtao Shi) @@ -46,7 +45,7 @@ import asyncio import json import queue # Added import uuid # Added -import functools # Added +import functools # Added import os import time @@ -56,9 +55,9 @@ from pathlib import Path import numpy as np import soundfile as sf import tritonclient -import tritonclient.grpc.aio as grpcclient_aio # Renamed original import -import tritonclient.grpc as grpcclient_sync # Added sync client import -from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException +import tritonclient.grpc.aio as grpcclient_aio # Renamed original import +import tritonclient.grpc as grpcclient_sync # Added sync client import +from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException # --- Added UserData and callback --- @@ -76,9 +75,10 @@ class UserData: return self._first_chunk_time - self._start_time return None + def callback(user_data, result, error): if user_data._first_chunk_time is None and not error: - user_data._first_chunk_time = time.time() # Record time of first successful chunk + user_data._first_chunk_time = time.time() # Record time of first successful chunk if error: user_data._completed_requests.put(error) else: @@ -206,8 +206,11 @@ def get_args(): "--model-name", type=str, default="f5_tts", - choices=["f5_tts", "spark_tts", "cosyvoice2"], - help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + choices=[ + "f5_tts", + "spark_tts", + "cosyvoice2"], + help="triton model_repo module name to request", ) parser.add_argument( @@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000): waveform = resample(waveform, num_samples) return waveform, target_sample_rate + def prepare_request_input_output( - protocol_client, # Can be grpcclient_aio or grpcclient_sync + protocol_client, # Can be grpcclient_aio or grpcclient_sync waveform, reference_text, target_text, sample_rate=16000, - padding_duration: int = None # Optional padding for offline mode + padding_duration: int = None # Optional padding for offline mode ): """Prepares inputs for Triton inference (offline or streaming).""" assert len(waveform.shape) == 1, "waveform should be 1D" @@ -291,9 +295,9 @@ def prepare_request_input_output( # Estimate target duration based on text length ratio (crude estimation) # Avoid division by zero if reference_text is empty if reference_text: - estimated_target_duration = duration / len(reference_text) * len(target_text) + estimated_target_duration = duration / len(reference_text) * len(target_text) else: - estimated_target_duration = duration # Assume target duration similar to reference if no text + estimated_target_duration = duration # Assume target duration similar to reference if no text # Calculate required samples based on estimated total duration required_total_samples = padding_duration * sample_rate * ( @@ -329,6 +333,7 @@ def prepare_request_input_output( return inputs, outputs + def run_sync_streaming_inference( sync_triton_client: tritonclient.grpc.InferenceServerClient, model_name: str, @@ -342,7 +347,7 @@ def run_sync_streaming_inference( ): """Helper function to run the blocking sync streaming call.""" start_time_total = time.time() - user_data.record_start_time() # Record start time for first chunk latency calculation + 
user_data.record_start_time() # Record start time for first chunk latency calculation # Establish stream sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) @@ -360,11 +365,11 @@ def run_sync_streaming_inference( audios = [] while True: try: - result = user_data._completed_requests.get() # Add timeout + result = user_data._completed_requests.get() # Add timeout if isinstance(result, InferenceServerException): print(f"Received InferenceServerException: {result}") sync_triton_client.stop_stream() - return None, None, None # Indicate error + return None, None, None # Indicate error # Get response metadata response = result.get_response() final = response.parameters["triton_final_response"].bool_param @@ -372,15 +377,15 @@ def run_sync_streaming_inference( break audio_chunk = result.as_numpy("waveform").reshape(-1) - if audio_chunk.size > 0: # Only append non-empty chunks - audios.append(audio_chunk) + if audio_chunk.size > 0: # Only append non-empty chunks + audios.append(audio_chunk) else: print("Warning: received empty audio chunk.") except queue.Empty: print(f"Timeout waiting for response for request id {request_id}") sync_triton_client.stop_stream() - return None, None, None # Indicate error + return None, None, None # Indicate error sync_triton_client.stop_stream() end_time_total = time.time() @@ -398,19 +403,19 @@ def run_sync_streaming_inference( # Simplified reconstruction based on client_grpc_streaming.py if not audios: print("Warning: No audio chunks received.") - reconstructed_audio = np.array([], dtype=np.float32) # Empty array + reconstructed_audio = np.array([], dtype=np.float32) # Empty array elif len(audios) == 1: reconstructed_audio = audios[0] else: - reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap for i in range(1, len(audios)): - # Cross-fade section - cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + - audios[i - 1][-cross_fade_samples:] * fade_out) - # Middle section of the current chunk - middle_part = audios[i][cross_fade_samples:-cross_fade_samples] - # Concatenate - reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # Cross-fade section + cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + audios[i - 1][-cross_fade_samples:] * fade_out) + # Middle section of the current chunk + middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # Concatenate + reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) # Add the last part of the final chunk reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) @@ -421,11 +426,11 @@ def run_sync_streaming_inference( sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") else: print("Warning: No audio chunks received or reconstructed.") - actual_duration = 0 # Set duration to 0 if no audio + actual_duration = 0 # Set duration to 0 if no audio else: - print("Warning: No audio chunks received.") - actual_duration = 0 + print("Warning: No audio chunks received.") + actual_duration = 0 return total_request_latency, first_chunk_latency, actual_duration @@ -433,7 +438,7 @@ def run_sync_streaming_inference( async def send_streaming( manifest_item_list: list, name: str, - server_url: str, # Changed from sync_triton_client + server_url: str, # Changed from sync_triton_client protocol_client: 
types.ModuleType, log_interval: int, model_name: str, @@ -445,11 +450,11 @@ async def send_streaming( total_duration = 0.0 latency_data = [] task_id = int(name[5:]) - sync_triton_client = None # Initialize client variable + sync_triton_client = None # Initialize client variable - try: # Wrap in try...finally to ensure client closing + try: # Wrap in try...finally to ensure client closing print(f"{name}: Initializing sync client for streaming...") - sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") for i, item in enumerate(manifest_item_list): @@ -491,8 +496,7 @@ async def send_streaming( latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) total_duration += actual_duration else: - print(f"{name}: Item {i} failed.") - + print(f"{name}: Item {i} failed.") except FileNotFoundError: print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") @@ -501,8 +505,7 @@ async def send_streaming( import traceback traceback.print_exc() - - finally: # Ensure client is closed + finally: # Ensure client is closed if sync_triton_client: try: print(f"{name}: Closing sync client...") @@ -510,10 +513,10 @@ async def send_streaming( except Exception as e: print(f"{name}: Error closing sync client: {e}") - print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s") return total_duration, latency_data + async def send( manifest_item_list: list, name: str, @@ -605,6 +608,7 @@ def split_data(data, k): return result + async def main(): args = get_args() url = f"{args.server_addr}:{args.server_port}" @@ -622,7 +626,7 @@ async def main(): # Use the sync client for streaming tasks, handled via asyncio.to_thread # We will create one sync client instance PER TASK inside send_streaming. # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now - protocol_client = grpcclient_sync # protocol client for input prep + protocol_client = grpcclient_sync # protocol client for input prep else: raise ValueError(f"Invalid mode: {args.mode}") # --- End Client Initialization --- @@ -682,11 +686,11 @@ async def main(): ) ) elif args.mode == "streaming": - task = asyncio.create_task( + task = asyncio.create_task( send_streaming( manifest_item_list[i], name=f"task-{i}", - server_url=url, # Pass URL instead of client + server_url=url, # Pass URL instead of client protocol_client=protocol_client, log_interval=args.log_interval, model_name=args.model_name, @@ -709,16 +713,15 @@ async def main(): for ans in ans_list: if ans: total_duration += ans[0] - latency_data.extend(ans[1]) # Use extend for list of lists + latency_data.extend(ans[1]) # Use extend for list of lists else: - print("Warning: A task returned None, possibly due to an error.") - + print("Warning: A task returned None, possibly due to an error.") if total_duration == 0: print("Total synthesized duration is zero. 
Cannot calculate RTF or latency percentiles.") rtf = float('inf') else: - rtf = elapsed / total_duration + rtf = elapsed / total_duration s = f"Mode: {args.mode}\n" s += f"RTF: {rtf:.4f}\n" @@ -759,7 +762,7 @@ async def main(): s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n" s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n" else: - s += "No total request latency data collected.\n" + s += "No total request latency data collected.\n" s += "\n--- First Chunk Latency ---\n" if first_chunk_latency_list: @@ -772,7 +775,7 @@ async def main(): s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n" s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" else: - s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" + s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" else: s += "No latency data collected.\n" # --- End Statistics Reporting --- @@ -785,7 +788,7 @@ async def main(): elif args.reference_audio: name = Path(args.reference_audio).stem else: - name = "results" # Default name if no manifest/split/audio provided + name = "results" # Default name if no manifest/split/audio provided with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: f.write(s) diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py index e22f4eb..4d73e0b 100644 --- a/runtime/triton_trtllm/client_http.py +++ b/runtime/triton_trtllm/client_http.py @@ -29,6 +29,7 @@ import json import numpy as np import argparse + def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -67,9 +68,10 @@ def get_args(): type=str, default="spark_tts", choices=[ - "f5_tts", "spark_tts", "cosyvoice2" - ], - help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + "f5_tts", + "spark_tts", + "cosyvoice2"], + help="triton model_repo module name to request", ) parser.add_argument( @@ -80,6 +82,7 @@ def get_args(): ) return parser.parse_args() + def prepare_request( waveform, reference_text, @@ -97,7 +100,7 @@ def prepare_request( 1, padding_duration * sample_rate - * ((int(duration) // padding_duration) + 1), + * ((int(len(waveform) / sample_rate) // padding_duration) + 1), ), dtype=np.float32, ) @@ -105,11 +108,11 @@ def prepare_request( samples[0, : len(waveform)] = waveform else: samples = waveform - + samples = samples.reshape(1, -1).astype(np.float32) data = { - "inputs":[ + "inputs": [ { "name": "reference_wav", "shape": samples.shape, @@ -139,16 +142,17 @@ def prepare_request( return data + if __name__ == "__main__": args = get_args() server_url = args.server_url if not server_url.startswith(("http://", "https://")): server_url = f"http://{server_url}" - + url = f"{server_url}/v2/models/{args.model_name}/infer" waveform, sr = sf.read(args.reference_audio) assert sr == 16000, "sample rate hardcoded in server" - + samples = np.array(waveform, dtype=np.float32) data = prepare_request(samples, args.reference_text, args.target_text) @@ -166,4 +170,4 @@ if __name__ == "__main__": sample_rate = 16000 else: sample_rate = 24000 - sf.write(args.output_audio, audio, sample_rate, "PCM_16") \ No newline at end of file + sf.write(args.output_audio, audio, sample_rate, "PCM_16") diff 
--git a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py index 105ffa1..47383e2 100644 --- a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -35,33 +35,34 @@ import s3tokenizer ORIGINAL_VOCAB_SIZE = 151663 + class TritonPythonModel: """Triton Python model for audio tokenization. - + This model takes reference audio input and extracts semantic tokens using s3tokenizer. """ def initialize(self, args): """Initialize the model. - + Args: args: Dictionary containing model configuration """ # Parse model parameters parameters = json.loads(args['model_config'])['parameters'] model_params = {k: v["string_value"] for k, v in parameters.items()} - + self.device = torch.device("cuda") model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx") self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device) def execute(self, requests): """Execute inference on the batched requests. - + Args: requests: List of inference requests - + Returns: List of inference responses containing tokenized outputs """ @@ -79,18 +80,18 @@ class TritonPythonModel: # Prepare inputs wav = wav_array[:, :wav_len].squeeze(0) mels.append(s3tokenizer.log_mel_spectrogram(wav)) - + mels, mels_lens = s3tokenizer.padding(mels) codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device)) codes = codes.clone() + ORIGINAL_VOCAB_SIZE - + responses = [] for i in range(len(requests)): - prompt_speech_tokens = codes[i, :codes_lens[i].item()] + prompt_speech_tokens = codes[i, :codes_lens[i].item()] prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack( "prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) inference_response = pb_utils.InferenceResponse( output_tensors=[prompt_speech_tokens_tensor]) responses.append(inference_response) - - return responses \ No newline at end of file + + return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index cb91677..77a440b 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -42,16 +42,17 @@ import onnxruntime from matcha.utils.audio import mel_spectrogram + class TritonPythonModel: """Triton Python model for Spark TTS. - + This model orchestrates the end-to-end TTS pipeline by coordinating between audio tokenizer, LLM, and vocoder components. """ - + def initialize(self, args): """Initialize the model. 
- + Args: args: Dictionary containing model configuration """ @@ -116,58 +117,58 @@ class TritonPythonModel: "input_ids": input_ids, "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), } - + # Convert inputs to Triton tensors input_tensor_list = [ pb_utils.Tensor(k, v) for k, v in input_dict.items() ] - + # Create and execute inference request llm_request = pb_utils.InferenceRequest( model_name="tensorrt_llm", requested_output_names=["output_ids", "sequence_length"], inputs=input_tensor_list, ) - + llm_responses = llm_request.exec(decoupled=self.decoupled) if self.decoupled: for llm_response in llm_responses: if llm_response.has_error(): raise pb_utils.TritonModelException(llm_response.error().message()) - + # Extract and process output output_ids = pb_utils.get_output_tensor_by_name( llm_response, "output_ids").as_numpy() seq_lens = pb_utils.get_output_tensor_by_name( llm_response, "sequence_length").as_numpy() - + # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] - + yield actual_output_ids else: llm_response = llm_responses if llm_response.has_error(): raise pb_utils.TritonModelException(llm_response.error().message()) - + # Extract and process output output_ids = pb_utils.get_output_tensor_by_name( llm_response, "output_ids").as_numpy() seq_lens = pb_utils.get_output_tensor_by_name( llm_response, "sequence_length").as_numpy() - + # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] - - yield actual_output_ids - + + yield actual_output_ids + def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. - + Args: wav: Input waveform tensor wav_len: Waveform length tensor - + Returns: Tuple of global and semantic tokens """ @@ -176,26 +177,31 @@ class TritonPythonModel: requested_output_names=['prompt_speech_tokens'], inputs=[wav, wav_len] ) - + inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) - + # Extract and convert output tensors prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() return prompt_speech_tokens - def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor: + def forward_token2wav( + self, + prompt_speech_tokens: torch.Tensor, + prompt_speech_feat: torch.Tensor, + prompt_spk_embedding: torch.Tensor, + target_speech_tokens: torch.Tensor) -> torch.Tensor: """Forward pass through the vocoder component. 
- + Args: prompt_speech_tokens: Prompt speech tokens tensor prompt_speech_feat: Prompt speech feat tensor prompt_spk_embedding: Prompt spk embedding tensor target_speech_tokens: Target speech tokens tensor - + Returns: Generated waveform tensor """ @@ -203,22 +209,22 @@ class TritonPythonModel: prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) - + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav', requested_output_names=['waveform'], inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] ) - + inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) - + # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() - + return waveform def parse_input(self, text, prompt_text, prompt_speech_tokens): @@ -231,43 +237,53 @@ class TritonPythonModel: def _extract_spk_embedding(self, speech): feat = kaldi.fbank(speech, - num_mel_bins=80, - dither=0, - sample_frequency=16000) + num_mel_bins=80, + dither=0, + sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) embedding = self.campplus_session.run(None, - {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() + {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() embedding = torch.tensor([embedding]).to(self.device).half() return embedding - def _extract_speech_feat(self, speech): - speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device) + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 1).to( + self.device) speech_feat = speech_feat.unsqueeze(dim=0) return speech_feat def execute(self, requests): """Execute inference on the batched requests. 
- + Args: requests: List of inference requests - + Returns: List of inference responses containing generated audio """ responses = [] - + for request in requests: # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - + # Process reference audio through audio tokenizer prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - wav_tensor = wav.as_numpy() wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) @@ -275,20 +291,20 @@ class TritonPythonModel: token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') - + # Prepare prompt for LLM input_ids = self.parse_input( text=target_text, prompt_text=reference_text, prompt_speech_tokens=prompt_speech_tokens, ) - + # Generate semantic tokens with LLM generated_ids_iter = self.forward_llm(input_ids) @@ -305,13 +321,13 @@ class TritonPythonModel: generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device) prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) - + # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info(f"send tritonserver_response_complete_final to end") + self.logger.log_info("send tritonserver_response_complete_final to end") else: generated_ids = next(generated_ids_iter) generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) @@ -320,11 +336,11 @@ class TritonPythonModel: prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) - + # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) responses.append(inference_response) - + if not self.decoupled: - return responses \ No newline at end of file + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index d6735a1..d38f8a4 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -44,6 +44,7 @@ logger = logging.getLogger(__name__) ORIGINAL_VOCAB_SIZE = 151663 + class CosyVoice2: def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): @@ -66,6 +67,7 @@ class CosyVoice2: trt_concurrent, self.fp16) + class CosyVoice2Model: def __init__(self, @@ -109,16 +111,17 @@ class CosyVoice2Model: input_names = ["x", "mask", "mu", 
"cond"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + class TritonPythonModel: """Triton Python model for vocoder. - + This model takes global and semantic tokens as input and generates audio waveforms using the BiCodec vocoder. """ def initialize(self, args): """Initialize the model. - + Args: args: Dictionary containing model configuration """ @@ -126,24 +129,23 @@ class TritonPythonModel: parameters = json.loads(args['model_config'])['parameters'] model_params = {key: value["string_value"] for key, value in parameters.items()} model_dir = model_params["model_dir"] - + # Initialize device and vocoder self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Initializing vocoder from {model_dir} on {self.device}") - + self.token2wav_model = CosyVoice2( model_dir, load_jit=True, load_trt=True, fp16=True ) logger.info("Token2Wav initialized successfully") - def execute(self, requests): """Execute inference on the batched requests. - + Args: requests: List of inference requests - + Returns: List of inference responses containing generated waveforms """ @@ -163,7 +165,7 @@ class TritonPythonModel: # shift the speech tokens according to the original vocab size prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE - + tts_mel, _ = self.token2wav_model.model.flow.inference( token=target_speech_tokens, token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( @@ -189,9 +191,5 @@ class TritonPythonModel: wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) responses.append(inference_response) - + return responses - - - - diff --git a/runtime/triton_trtllm/scripts/convert_checkpoint.py b/runtime/triton_trtllm/scripts/convert_checkpoint.py index 932cdf8..7cd166f 100644 --- a/runtime/triton_trtllm/scripts/convert_checkpoint.py +++ b/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -35,8 +35,7 @@ def parse_arguments(): type=str, default='auto', choices=['auto', 'float16', 'bfloat16', 'float32'], - help= - "The data type for the model weights and activations if not quantized. " + help="The data type for the model weights and activations if not quantized. " "If 'auto', the data type is automatically inferred from the source model; " "however, if the source dtype is float32, it is converted to float16.") parser.add_argument( @@ -49,8 +48,7 @@ def parse_arguments(): '--disable_weight_only_quant_plugin', default=False, action="store_true", - help= - 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' 'You must also use --use_weight_only for that argument to have an impact.' ) parser.add_argument( @@ -60,16 +58,14 @@ def parse_arguments(): nargs='?', default='int8', choices=['int8', 'int4', 'int4_gptq'], - help= - 'Define the precision for the weights when using weight-only quantization.' + help='Define the precision for the weights when using weight-only quantization.' 'You must also use --use_weight_only for that argument to have an impact.' 
) parser.add_argument( '--calib_dataset', type=str, default='ccdv/cnn_dailymail', - help= - "The huggingface dataset name or the local directory of the dataset for calibration." + help="The huggingface dataset name or the local directory of the dataset for calibration." ) parser.add_argument( "--smoothquant", @@ -83,31 +79,27 @@ def parse_arguments(): '--per_channel', action="store_true", default=False, - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' + help='By default, we use a single static scaling factor for the GEMM\'s result. ' 'per_channel instead uses a different static scaling factor for each channel. ' 'The latter is usually more accurate, but a little slower.') parser.add_argument( '--per_token', action="store_true", default=False, - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + help='By default, we use a single static scaling factor to scale activations in the int8 range. ' 'per_token chooses at run time, and for each token, a custom scaling factor. ' 'The latter is usually more accurate, but a little slower.') parser.add_argument( '--int8_kv_cache', default=False, action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' ) parser.add_argument( '--per_group', default=False, action="store_true", - help= - 'By default, we use a single static scaling factor to scale weights in the int4 range. ' + help='By default, we use a single static scaling factor to scale weights in the int4 range. ' 'per_group chooses at run time, and for each group, a custom scaling factor. ' 'The flag is built for GPTQ/AWQ quantization.') @@ -121,16 +113,14 @@ def parse_arguments(): '--use_parallel_embedding', action="store_true", default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' ) parser.add_argument( '--embedding_sharding_dim', type=int, default=0, choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). 
' 'To shard it along hidden dimension, set embedding_sharding_dim=1' 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' ) @@ -147,15 +137,13 @@ def parse_arguments(): '--moe_tp_size', type=int, default=-1, - help= - 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' + help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' ) parser.add_argument( '--moe_ep_size', type=int, default=-1, - help= - 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' + help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' ) args = parser.parse_args() return args @@ -249,7 +237,7 @@ def convert_and_save_hf(args): trust_remote_code=True) quant_config, override_fields = update_quant_config_from_hf( quant_config, hf_config, override_fields) - except: + except BaseException: logger.warning("AutoConfig cannot load the huggingface config.") if args.smoothquant is not None or args.int8_kv_cache: @@ -339,4 +327,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/runtime/triton_trtllm/scripts/fill_template.py b/runtime/triton_trtllm/scripts/fill_template.py index 5c629f7..6e6a2bc 100644 --- a/runtime/triton_trtllm/scripts/fill_template.py +++ b/runtime/triton_trtllm/scripts/fill_template.py @@ -1,4 +1,4 @@ -#! /usr/bin/env python3 +# /usr/bin/env python3 from argparse import ArgumentParser from string import Template @@ -59,8 +59,7 @@ if __name__ == "__main__": parser.add_argument("file_path", help="path of the .pbtxt to modify") parser.add_argument( "substitutions", - help= - "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." + help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." ) parser.add_argument("--in_place", "-i", diff --git a/runtime/triton_trtllm/scripts/test_llm.py b/runtime/triton_trtllm/scripts/test_llm.py index 9ffe9cf..d52d724 100644 --- a/runtime/triton_trtllm/scripts/test_llm.py +++ b/runtime/triton_trtllm/scripts/test_llm.py @@ -46,7 +46,6 @@ def parse_arguments(args=None): parser.add_argument('--top_k', type=int, default=50) parser.add_argument('--top_p', type=float, default=0.95) - return parser.parse_args(args=args) @@ -60,7 +59,7 @@ def parse_input(tokenizer, input_ids = tokenizer.encode( curr_text) batch_input_ids.append(input_ids) - + batch_input_ids = [ torch.tensor(x, dtype=torch.int32) for x in batch_input_ids ]
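
(Aside, closing out the series, not part of the patches: the pieces above fit together as text prompt, then speech-prompt tokens, then TensorRT-LLM generation, then token2wav. A minimal sketch of the prompt assembly, using the template and stop token from `scripts/test_llm.py`; the helper name and the concatenation of reference plus target text are assumptions based on the `parse_input` call site in `model_repo/cosyvoice2/1/model.py`.)

```python
# Sketch only: how the text prompt and the already-offset speech-prompt tokens
# are combined into one input_ids tensor before the TensorRT-LLM call.
import torch
from transformers import AutoTokenizer

PROMPT_TEMPLATE = "<|sos|>{input_text}<|task_id|>"  # template used in scripts/test_llm.py

def build_llm_inputs(tokenizer, reference_text, target_text, prompt_speech_tokens):
    """Tokenize the templated text and append the speech-prompt token ids."""
    prompt = PROMPT_TEMPLATE.format(input_text=reference_text + target_text)
    text_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.int32)
    return torch.cat([text_ids, prompt_speech_tokens.to(torch.int32)], dim=1)

# Usage (paths and variables are placeholders):
# tokenizer = AutoTokenizer.from_pretrained("./cosyvoice2_llm")
# end_id = tokenizer.convert_tokens_to_ids("<|eos1|>")  # stop token used in test_llm.py
# input_ids = build_llm_inputs(tokenizer, ref_text, tgt_text, prompt_speech_tokens)
```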