Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 09:29:25 +08:00)

Merge pull request #1507 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
@@ -29,6 +29,10 @@

## Roadmap

- [x] 2025/08

    - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support

- [x] 2025/07

    - [x] release cosyvoice 3.0 eval set

8  runtime/triton_trtllm/Dockerfile.server  Normal file

@@ -0,0 +1,8 @@
FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
LABEL maintainer="zhangyuekai@foxmail.com"

RUN apt-get update && apt-get install -y cmake
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
COPY ./requirements.txt /workspace/requirements.txt
RUN pip install -r /workspace/requirements.txt
WORKDIR /workspace

91  runtime/triton_trtllm/README.md  Normal file

@@ -0,0 +1,91 @@
## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server

Thanks to the contribution from NVIDIA Yuekai Zhang.

### Quick Start

Launch the service directly with Docker Compose:

```sh
docker compose up
```
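
The `docker-compose.yml` added in this PR forwards a `MODEL_ID` environment variable into the container. A minimal sketch, assuming you export that variable in your shell first (the value below is a placeholder; check `docker-compose.yml` and `run.sh` for the exact identifier they expect):

```sh
export MODEL_ID=<huggingface_model_id>   # placeholder, not a confirmed value
docker compose up
```
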
### Build the Docker Image

Build the image from scratch:

```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```

### Run a Docker Container

```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
```

### Understanding `run.sh`

The `run.sh` script orchestrates the entire workflow through numbered stages.

Run a subset of stages with:

```sh
bash run.sh <start_stage> <stop_stage> [service_type]
```

- `<start_stage>` – stage to start from (0-5).
- `<stop_stage>` – stage to stop after (0-5).

Stages (example invocations follow the list):

- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace.
- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
- **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later).
- **Stage 3** – Launch the Triton Inference Server.
- **Stage 4** – Run the single-utterance HTTP client.
- **Stage 5** – Run the gRPC benchmark client.
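
Both stage bounds are inclusive, so you can target any contiguous slice of the pipeline. Two illustrative invocations, following the stage numbering above (the optional third argument is only relevant where a service type matters):

```sh
bash run.sh 0 1   # only download the model and build the TensorRT-LLM engines
bash run.sh 3 3   # only (re)launch the Triton Inference Server
```
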
### Export Models to TensorRT-LLM and Launch the Server

Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:

```sh
# Runs stages 0, 1, 2, and 3
bash run.sh 0 3
```

*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
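
For example, to move an existing deployment to the decoupled (streaming) service type, you could rerun only stage 2 and then restart the server. This is a sketch; it assumes `run.sh` accepts the same `streaming`/`offline` service-type argument here that the benchmark stage documents below:

```sh
bash run.sh 2 2 streaming   # rebuild the model repository for decoupled serving (argument value assumed)
bash run.sh 3 3             # relaunch the Triton Inference Server
```
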
### Single-Utterance HTTP Client

Send a single HTTP inference request:

```sh
bash run.sh 4 4
```
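
Stage 4 presumably wraps `client_http.py`, which is added in this PR. If you prefer to call it directly, a minimal sketch using that script's command-line arguments (the texts below are placeholders, not values from the source):

```sh
python3 client_http.py \
    --server-url localhost:8000 \
    --reference-audio ../../example/prompt_audio.wav \
    --reference-text "<transcript of the prompt audio>" \
    --target-text "<text to synthesize>" \
    --model-name cosyvoice2 \
    --output-audio output.wav
```
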
### Benchmark with a Dataset

Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.

```sh
bash run.sh 5 5

# You can also customise parameters such as num_task and dataset split directly:
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
```

> [!TIP]
> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready.

### Benchmark Results

Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio). RTF is the real-time factor reported by `client_grpc.py`: total processing time divided by seconds of audio synthesized (lower is better).

| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|------|------|-------------|------------------|------------------|-----|
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 |

### OpenAI-Compatible Server

To launch an OpenAI-compatible service, run:

```sh
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
cd Triton-OpenAI-Speech
pip install -r requirements.txt

# After the Triton service is up, start the FastAPI bridge:
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000

# Test with curl
bash test/test_cosyvoice.sh
```
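
The exact request format is defined by `test/test_cosyvoice.sh` in that repository. If the bridge follows the usual OpenAI speech convention, a request might look roughly like the sketch below; the endpoint path, JSON field names, and voice value are assumptions, not confirmed by this PR:

```sh
# hypothetical request shape; consult test/test_cosyvoice.sh for the real one
curl http://localhost:10086/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"model": "cosyvoice2", "input": "Text to synthesize.", "voice": "<reference voice id>"}' \
  --output output.wav
```
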
### Acknowledgements

This section originates from the NVIDIA CISI project. We also provide other multimodal resources; see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details.

834  runtime/triton_trtllm/client_grpc.py  Normal file

@@ -0,0 +1,834 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2023 Nvidia (authors: Yuekai Zhang)
|
||||
# 2023 Recurrent.ai (authors: Songtao Shi)
|
||||
# See LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a dataset from HuggingFace and sends it to the server
for decoding in parallel.
|
||||
|
||||
Usage:
|
||||
num_task=2
|
||||
|
||||
# For offline F5-TTS
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name f5_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name test_zh \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
|
||||
# For offline Spark-TTS-0.5B
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name spark_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name wenetspeech4tts \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
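
# For offline CosyVoice2 (example added for the model introduced in this PR;
# dataset and split names follow the README in this directory)
python3 client_grpc.py \
--server-addr localhost \
--model-name cosyvoice2 \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}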
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import queue # Added
|
||||
import uuid # Added
|
||||
import functools # Added
|
||||
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import tritonclient
|
||||
import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
|
||||
import tritonclient.grpc as grpcclient_sync # Added sync client import
|
||||
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
|
||||
|
||||
|
||||
# --- Added UserData and callback ---
|
||||
class UserData:
|
||||
def __init__(self):
|
||||
self._completed_requests = queue.Queue()
|
||||
self._first_chunk_time = None
|
||||
self._start_time = None
|
||||
|
||||
def record_start_time(self):
|
||||
self._start_time = time.time()
|
||||
|
||||
def get_first_chunk_latency(self):
|
||||
if self._first_chunk_time and self._start_time:
|
||||
return self._first_chunk_time - self._start_time
|
||||
return None
|
||||
|
||||
|
||||
def callback(user_data, result, error):
|
||||
if user_data._first_chunk_time is None and not error:
|
||||
user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
||||
if error:
|
||||
user_data._completed_requests.put(error)
|
||||
else:
|
||||
user_data._completed_requests.put(result)
|
||||
# --- End Added UserData and callback ---
|
||||
|
||||
|
||||
def write_triton_stats(stats, summary_file):
|
||||
with open(summary_file, "w") as summary_f:
|
||||
model_stats = stats["model_stats"]
|
||||
# Write a note explaining that this log is parsed from triton_client.get_inference_statistics() for better human readability.
summary_f.write(
"This log is parsed from triton_client.get_inference_statistics() for better human readability. \n"
)
|
||||
summary_f.write("To learn more about the log, please refer to: \n")
|
||||
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
||||
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
||||
summary_f.write(
|
||||
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
||||
)
|
||||
summary_f.write(
|
||||
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
||||
)
|
||||
summary_f.write(
|
||||
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
||||
)
|
||||
summary_f.write(
|
||||
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
||||
)
|
||||
for model_state in model_stats:
|
||||
if "last_inference" not in model_state:
|
||||
continue
|
||||
summary_f.write(f"model name is {model_state['name']} \n")
|
||||
model_inference_stats = model_state["inference_stats"]
|
||||
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
||||
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
||||
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
||||
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
||||
summary_f.write(
|
||||
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
||||
)
|
||||
model_batch_stats = model_state["batch_stats"]
|
||||
for batch in model_batch_stats:
|
||||
batch_size = int(batch["batch_size"])
|
||||
compute_input = batch["compute_input"]
|
||||
compute_output = batch["compute_output"]
|
||||
compute_infer = batch["compute_infer"]
|
||||
batch_count = int(compute_infer["count"])
|
||||
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
||||
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
||||
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
||||
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
||||
summary_f.write(
|
||||
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
||||
)
|
||||
summary_f.write(
|
||||
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
||||
)
|
||||
summary_f.write(
|
||||
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-addr",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Address of the server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=8001,
|
||||
help="Grpc port of the triton server, default is 8001",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-audio",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-text",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-text",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--huggingface-dataset",
|
||||
type=str,
|
||||
default="yuekai/seed_tts",
|
||||
help="dataset name in huggingface dataset hub",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="dataset split name, default is 'test'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=[
|
||||
"f5_tts",
|
||||
"spark_tts",
|
||||
"cosyvoice2"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-tasks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of concurrent tasks for sending",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Controls how frequently we print the log.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-wer",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="""True to compute WER.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=False,
|
||||
default="./tmp",
|
||||
help="log directory",
|
||||
)
|
||||
|
||||
# --- Added arguments ---
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="offline",
|
||||
choices=["offline", "streaming"],
|
||||
help="Select offline or streaming benchmark mode."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-overlap-duration",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
||||
)
|
||||
# --- End Added arguments ---
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
else:
|
||||
waveform, sample_rate = sf.read(wav_path)
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||
waveform = resample(waveform, num_samples)
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
def prepare_request_input_output(
|
||||
protocol_client, # Can be grpcclient_aio or grpcclient_sync
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
padding_duration: int = None # Optional padding for offline mode
|
||||
):
|
||||
"""Prepares inputs for Triton inference (offline or streaming)."""
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
|
||||
# Apply padding only if padding_duration is provided (for offline)
|
||||
if padding_duration:
|
||||
duration = len(waveform) / sample_rate
|
||||
# Estimate target duration based on text length ratio (crude estimation)
|
||||
# Avoid division by zero if reference_text is empty
|
||||
if reference_text:
|
||||
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
||||
else:
|
||||
estimated_target_duration = duration # Assume target duration similar to reference if no text
|
||||
|
||||
# Calculate required samples based on estimated total duration
|
||||
required_total_samples = padding_duration * sample_rate * (
|
||||
(int(estimated_target_duration + duration) // padding_duration) + 1
|
||||
)
|
||||
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
||||
samples[0, : len(waveform)] = waveform
|
||||
else:
|
||||
# No padding for streaming or if padding_duration is None
|
||||
samples = waveform.reshape(1, -1).astype(np.float32)
|
||||
|
||||
# Common input creation logic
|
||||
inputs = [
|
||||
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
||||
protocol_client.InferInput(
|
||||
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
||||
),
|
||||
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
||||
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
||||
]
|
||||
inputs[0].set_data_from_numpy(samples)
|
||||
inputs[1].set_data_from_numpy(lengths)
|
||||
|
||||
input_data_numpy = np.array([reference_text], dtype=object)
|
||||
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||
inputs[2].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
input_data_numpy = np.array([target_text], dtype=object)
|
||||
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def run_sync_streaming_inference(
|
||||
sync_triton_client: tritonclient.grpc.InferenceServerClient,
|
||||
model_name: str,
|
||||
inputs: list,
|
||||
outputs: list,
|
||||
request_id: str,
|
||||
user_data: UserData,
|
||||
chunk_overlap_duration: float,
|
||||
save_sample_rate: int,
|
||||
audio_save_path: str,
|
||||
):
|
||||
"""Helper function to run the blocking sync streaming call."""
|
||||
start_time_total = time.time()
|
||||
user_data.record_start_time() # Record start time for first chunk latency calculation
|
||||
|
||||
# Establish stream
|
||||
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
|
||||
|
||||
# Send request
|
||||
sync_triton_client.async_stream_infer(
|
||||
model_name,
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
outputs=outputs,
|
||||
enable_empty_final_response=True,
|
||||
)
|
||||
|
||||
# Process results
|
||||
audios = []
|
||||
while True:
|
||||
try:
|
||||
result = user_data._completed_requests.get(timeout=600)  # timeout (seconds, arbitrary safeguard) so the queue.Empty handler below can fire
|
||||
if isinstance(result, InferenceServerException):
|
||||
print(f"Received InferenceServerException: {result}")
|
||||
sync_triton_client.stop_stream()
|
||||
return None, None, None # Indicate error
|
||||
# Get response metadata
|
||||
response = result.get_response()
|
||||
final = response.parameters["triton_final_response"].bool_param
|
||||
if final is True:
|
||||
break
|
||||
|
||||
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
||||
if audio_chunk.size > 0: # Only append non-empty chunks
|
||||
audios.append(audio_chunk)
|
||||
else:
|
||||
print("Warning: received empty audio chunk.")
|
||||
|
||||
except queue.Empty:
|
||||
print(f"Timeout waiting for response for request id {request_id}")
|
||||
sync_triton_client.stop_stream()
|
||||
return None, None, None # Indicate error
|
||||
|
||||
sync_triton_client.stop_stream()
|
||||
end_time_total = time.time()
|
||||
total_request_latency = end_time_total - start_time_total
|
||||
first_chunk_latency = user_data.get_first_chunk_latency()
|
||||
|
||||
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
|
||||
actual_duration = 0
|
||||
if audios:
|
||||
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
||||
fade_out = np.linspace(1, 0, cross_fade_samples)
|
||||
fade_in = np.linspace(0, 1, cross_fade_samples)
|
||||
reconstructed_audio = None
|
||||
|
||||
# Simplified reconstruction based on client_grpc_streaming.py
|
||||
if not audios:
|
||||
print("Warning: No audio chunks received.")
|
||||
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
|
||||
elif len(audios) == 1:
|
||||
reconstructed_audio = audios[0]
|
||||
else:
|
||||
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
|
||||
for i in range(1, len(audios)):
|
||||
# Cross-fade section
|
||||
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
||||
audios[i - 1][-cross_fade_samples:] * fade_out)
|
||||
# Middle section of the current chunk
|
||||
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
||||
# Concatenate
|
||||
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
||||
# Add the last part of the final chunk
|
||||
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
||||
|
||||
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
||||
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||
# Save reconstructed audio
|
||||
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
|
||||
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||
else:
|
||||
print("Warning: No audio chunks received or reconstructed.")
|
||||
actual_duration = 0 # Set duration to 0 if no audio
|
||||
|
||||
else:
|
||||
print("Warning: No audio chunks received.")
|
||||
actual_duration = 0
|
||||
|
||||
return total_request_latency, first_chunk_latency, actual_duration
|
||||
|
||||
|
||||
async def send_streaming(
|
||||
manifest_item_list: list,
|
||||
name: str,
|
||||
server_url: str, # Changed from sync_triton_client
|
||||
protocol_client: types.ModuleType,
|
||||
log_interval: int,
|
||||
model_name: str,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
chunk_overlap_duration: float = 0.1,
|
||||
padding_duration: int = None,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
task_id = int(name[5:])
|
||||
sync_triton_client = None # Initialize client variable
|
||||
|
||||
try: # Wrap in try...finally to ensure client closing
|
||||
print(f"{name}: Initializing sync client for streaming...")
|
||||
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
|
||||
|
||||
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
|
||||
|
||||
try:
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||
|
||||
inputs, outputs = prepare_request_input_output(
|
||||
protocol_client,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration
|
||||
)
|
||||
request_id = str(uuid.uuid4())
|
||||
user_data = UserData()
|
||||
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
|
||||
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
|
||||
run_sync_streaming_inference,
|
||||
sync_triton_client,
|
||||
model_name,
|
||||
inputs,
|
||||
outputs,
|
||||
request_id,
|
||||
user_data,
|
||||
chunk_overlap_duration,
|
||||
save_sample_rate,
|
||||
audio_save_path
|
||||
)
|
||||
|
||||
if total_request_latency is not None:
|
||||
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
|
||||
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
|
||||
total_duration += actual_duration
|
||||
else:
|
||||
print(f"{name}: Item {i} failed.")
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
finally: # Ensure client is closed
|
||||
if sync_triton_client:
|
||||
try:
|
||||
print(f"{name}: Closing sync client...")
|
||||
sync_triton_client.close()
|
||||
except Exception as e:
|
||||
print(f"{name}: Error closing sync client: {e}")
|
||||
|
||||
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
|
||||
return total_duration, latency_data
|
||||
|
||||
|
||||
async def send(
|
||||
manifest_item_list: list,
|
||||
name: str,
|
||||
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
||||
protocol_client: types.ModuleType,
|
||||
log_interval: int,
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
task_id = int(name[5:])
|
||||
|
||||
print(f"manifest_item_list: {manifest_item_list}")
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||
|
||||
inputs, outputs = prepare_request_input_output(
|
||||
protocol_client,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration
|
||||
)
|
||||
sequence_id = 100000000 + i + task_id * 10
|
||||
start = time.time()
|
||||
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
||||
|
||||
audio = response.as_numpy("waveform").reshape(-1)
|
||||
actual_duration = len(audio) / save_sample_rate
|
||||
|
||||
end = time.time() - start
|
||||
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, actual_duration))
|
||||
total_duration += actual_duration
|
||||
|
||||
return total_duration, latency_data
|
||||
|
||||
|
||||
def load_manifests(manifest_path):
|
||||
with open(manifest_path, "r") as f:
|
||||
manifest_list = []
|
||||
for line in f:
|
||||
assert len(line.strip().split("|")) == 4
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
utt = Path(utt).stem
|
||||
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
||||
manifest_list.append(
|
||||
{
|
||||
"audio_filepath": prompt_wav,
|
||||
"reference_text": prompt_text,
|
||||
"target_text": gt_text,
|
||||
"target_audio_path": utt,
|
||||
}
|
||||
)
|
||||
return manifest_list
|
||||
|
||||
|
||||
def split_data(data, k):
|
||||
n = len(data)
|
||||
if n < k:
|
||||
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
||||
k = n
|
||||
|
||||
quotient = n // k
|
||||
remainder = n % k
|
||||
|
||||
result = []
|
||||
start = 0
|
||||
for i in range(k):
|
||||
if i < remainder:
|
||||
end = start + quotient + 1
|
||||
else:
|
||||
end = start + quotient
|
||||
|
||||
result.append(data[start:end])
|
||||
start = end
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def main():
|
||||
args = get_args()
|
||||
url = f"{args.server_addr}:{args.server_port}"
|
||||
|
||||
# --- Client Initialization based on mode ---
|
||||
triton_client = None
|
||||
protocol_client = None
|
||||
if args.mode == "offline":
|
||||
print("Initializing gRPC client for offline mode...")
|
||||
# Use the async client for offline tasks
|
||||
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||
protocol_client = grpcclient_aio
|
||||
elif args.mode == "streaming":
|
||||
print("Initializing gRPC client for streaming mode...")
|
||||
# Use the sync client for streaming tasks, handled via asyncio.to_thread
|
||||
# We will create one sync client instance PER TASK inside send_streaming.
|
||||
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
|
||||
protocol_client = grpcclient_sync # protocol client for input prep
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {args.mode}")
|
||||
# --- End Client Initialization ---
|
||||
|
||||
if args.reference_audio:
|
||||
args.num_tasks = 1
|
||||
args.log_interval = 1
|
||||
manifest_item_list = [
|
||||
{
|
||||
"reference_text": args.reference_text,
|
||||
"target_text": args.target_text,
|
||||
"audio_filepath": args.reference_audio,
|
||||
"target_audio_path": "test",
|
||||
}
|
||||
]
|
||||
elif args.huggingface_dataset:
|
||||
import datasets
|
||||
|
||||
dataset = datasets.load_dataset(
|
||||
args.huggingface_dataset,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
manifest_item_list = []
|
||||
for i in range(len(dataset)):
|
||||
manifest_item_list.append(
|
||||
{
|
||||
"audio_filepath": dataset[i]["prompt_audio"],
|
||||
"reference_text": dataset[i]["prompt_text"],
|
||||
"target_audio_path": dataset[i]["id"],
|
||||
"target_text": dataset[i]["target_text"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
manifest_item_list = load_manifests(args.manifest_path)
|
||||
|
||||
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
||||
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
||||
|
||||
os.makedirs(args.log_dir, exist_ok=True)
|
||||
tasks = []
|
||||
start_time = time.time()
|
||||
for i in range(num_tasks):
|
||||
# --- Task Creation based on mode ---
|
||||
if args.mode == "offline":
|
||||
task = asyncio.create_task(
|
||||
send(
|
||||
manifest_item_list[i],
|
||||
name=f"task-{i}",
|
||||
triton_client=triton_client,
|
||||
protocol_client=protocol_client,
|
||||
log_interval=args.log_interval,
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
)
|
||||
)
|
||||
elif args.mode == "streaming":
|
||||
task = asyncio.create_task(
|
||||
send_streaming(
|
||||
manifest_item_list[i],
|
||||
name=f"task-{i}",
|
||||
server_url=url, # Pass URL instead of client
|
||||
protocol_client=protocol_client,
|
||||
log_interval=args.log_interval,
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=10,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
chunk_overlap_duration=args.chunk_overlap_duration,
|
||||
)
|
||||
)
|
||||
# --- End Task Creation ---
|
||||
tasks.append(task)
|
||||
|
||||
ans_list = await asyncio.gather(*tasks)
|
||||
|
||||
end_time = time.time()
|
||||
elapsed = end_time - start_time
|
||||
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
for ans in ans_list:
|
||||
if ans:
|
||||
total_duration += ans[0]
|
||||
latency_data.extend(ans[1]) # Use extend for list of lists
|
||||
else:
|
||||
print("Warning: A task returned None, possibly due to an error.")
|
||||
|
||||
if total_duration == 0:
|
||||
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
|
||||
rtf = float('inf')
|
||||
else:
|
||||
rtf = elapsed / total_duration
|
||||
|
||||
s = f"Mode: {args.mode}\n"
|
||||
s += f"RTF: {rtf:.4f}\n"
|
||||
s += f"total_duration: {total_duration:.3f} seconds\n"
|
||||
s += f"({total_duration / 3600:.2f} hours)\n"
|
||||
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
||||
|
||||
# --- Statistics Reporting based on mode ---
|
||||
if latency_data:
|
||||
if args.mode == "offline":
|
||||
# Original offline latency calculation
|
||||
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
||||
if latency_list:
|
||||
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
||||
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"latency_variance: {latency_variance:.2f}\n"
|
||||
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No latency data collected for offline mode.\n"
|
||||
|
||||
elif args.mode == "streaming":
|
||||
# Calculate stats for total request latency and first chunk latency
|
||||
total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
|
||||
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
|
||||
|
||||
s += "\n--- Total Request Latency ---\n"
|
||||
if total_latency_list:
|
||||
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
|
||||
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
|
||||
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No total request latency data collected.\n"
|
||||
|
||||
s += "\n--- First Chunk Latency ---\n"
|
||||
if first_chunk_latency_list:
|
||||
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
|
||||
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
|
||||
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
||||
else:
|
||||
s += "No latency data collected.\n"
|
||||
# --- End Statistics Reporting ---
|
||||
|
||||
print(s)
|
||||
if args.manifest_path:
|
||||
name = Path(args.manifest_path).stem
|
||||
elif args.split_name:
|
||||
name = args.split_name
|
||||
elif args.reference_audio:
|
||||
name = Path(args.reference_audio).stem
|
||||
else:
|
||||
name = "results" # Default name if no manifest/split/audio provided
|
||||
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
||||
f.write(s)
|
||||
|
||||
# --- Statistics Fetching using temporary Async Client ---
|
||||
# Use a separate async client for fetching stats regardless of mode
|
||||
stats_client = None
|
||||
try:
|
||||
print("Initializing temporary async client for fetching stats...")
|
||||
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||
print("Fetching inference statistics...")
|
||||
# Fetching for all models, filtering might be needed depending on server setup
|
||||
stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||
print("Fetching model config...")
|
||||
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
||||
|
||||
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
||||
|
||||
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
||||
json.dump(metadata, f, indent=4)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Could not retrieve statistics or config: {e}")
|
||||
finally:
|
||||
if stats_client:
|
||||
try:
|
||||
print("Closing temporary async stats client...")
|
||||
await stats_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing async stats client: {e}")
|
||||
# --- End Statistics Fetching ---
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# asyncio.run(main()) # Use TaskGroup for better exception handling if needed
|
||||
async def run_main():
|
||||
try:
|
||||
await main()
|
||||
except Exception as e:
|
||||
print(f"An error occurred in main: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
asyncio.run(run_main())
|
||||
173  runtime/triton_trtllm/client_http.py  Normal file

@@ -0,0 +1,173 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import requests
|
||||
import soundfile as sf
|
||||
import json
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-url",
|
||||
type=str,
|
||||
default="localhost:8000",
|
||||
help="Address of the server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-audio",
|
||||
type=str,
|
||||
default="../../example/prompt_audio.wav",
|
||||
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-text",
|
||||
type=str,
|
||||
default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-text",
|
||||
type=str,
|
||||
default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="spark_tts",
|
||||
choices=[
|
||||
"f5_tts",
|
||||
"spark_tts",
|
||||
"cosyvoice2"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-audio",
|
||||
type=str,
|
||||
default="output.wav",
|
||||
help="Path to save the output audio",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prepare_request(
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
):
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
if padding_duration:
|
||||
# pad to the next multiple of padding_duration seconds
|
||||
samples = np.zeros(
|
||||
(
|
||||
1,
|
||||
padding_duration
|
||||
* sample_rate
|
||||
* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
samples[0, : len(waveform)] = waveform
|
||||
else:
|
||||
samples = waveform
|
||||
|
||||
samples = samples.reshape(1, -1).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "reference_wav",
|
||||
"shape": samples.shape,
|
||||
"datatype": "FP32",
|
||||
"data": samples.tolist()
|
||||
},
|
||||
{
|
||||
"name": "reference_wav_len",
|
||||
"shape": lengths.shape,
|
||||
"datatype": "INT32",
|
||||
"data": lengths.tolist(),
|
||||
},
|
||||
{
|
||||
"name": "reference_text",
|
||||
"shape": [1, 1],
|
||||
"datatype": "BYTES",
|
||||
"data": [reference_text]
|
||||
},
|
||||
{
|
||||
"name": "target_text",
|
||||
"shape": [1, 1],
|
||||
"datatype": "BYTES",
|
||||
"data": [target_text]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
server_url = args.server_url
|
||||
if not server_url.startswith(("http://", "https://")):
|
||||
server_url = f"http://{server_url}"
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
waveform, sr = sf.read(args.reference_audio)
|
||||
assert sr == 16000, "sample rate hardcoded in server"
|
||||
|
||||
samples = np.array(waveform, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
|
||||
rsp = requests.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=data,
|
||||
verify=False,
|
||||
params={"request_id": '0'}
|
||||
)
|
||||
result = rsp.json()
|
||||
audio = result["outputs"][0]["data"]
|
||||
audio = np.array(audio, dtype=np.float32)
|
||||
if args.model_name == "spark_tts":
|
||||
sample_rate = 16000
|
||||
else:
|
||||
sample_rate = 24000
|
||||
sf.write(args.output_audio, audio, sample_rate, "PCM_16")
|
||||
20  runtime/triton_trtllm/docker-compose.yml  Normal file

@@ -0,0 +1,20 @@
services:
  tts:
    image: soar97/triton-cosyvoice:25.06
    shm_size: '1gb'
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
      - MODEL_ID=${MODEL_ID}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    command: >
      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"

97  runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py  Normal file

@@ -0,0 +1,97 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

import os
import numpy as np
import s3tokenizer

ORIGINAL_VOCAB_SIZE = 151663


class TritonPythonModel:
    """Triton Python model for audio tokenization.

    This model takes reference audio input and extracts semantic tokens
    using s3tokenizer.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}

        self.device = torch.device("cuda")
        model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
        self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing tokenized outputs
        """
        mels = []

        # Process each request in batch
        for request in requests:
            # Extract input tensors
            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(
                request, "reference_wav_len").as_numpy().item()

            wav_array = torch.from_numpy(wav_array).to(self.device)
            # Prepare inputs
            wav = wav_array[:, :wav_len].squeeze(0)
            mels.append(s3tokenizer.log_mel_spectrogram(wav))

        mels, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
        codes = codes.clone() + ORIGINAL_VOCAB_SIZE

        responses = []
        for i in range(len(requests)):
            prompt_speech_tokens = codes[i, :codes_lens[i].item()]
            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
                "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[prompt_speech_tokens_tensor])
            responses.append(inference_response)

        return responses

53  runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt  Normal file

@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "audio_tokenizer"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
    {
        key: "model_dir",
        value: {string_value:"${model_dir}"}
    }
]

input [
    {
        name: "reference_wav"
        data_type: TYPE_FP32
        dims: [-1]
    },
    {
        name: "reference_wav_len"
        data_type: TYPE_INT32
        dims: [1]
    }
]
output [
    {
        name: "prompt_speech_tokens"
        data_type: TYPE_INT32
        dims: [-1]
    }
]

instance_group [
    {
        count: 1
        kind: KIND_CPU
    }
]

346  runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py  Normal file

@@ -0,0 +1,346 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from transformers import AutoTokenizer
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import torchaudio
|
||||
import onnxruntime
|
||||
|
||||
|
||||
from matcha.utils.audio import mel_spectrogram
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for Spark TTS.
|
||||
|
||||
This model orchestrates the end-to-end TTS pipeline by coordinating
|
||||
between audio tokenizer, LLM, and vocoder components.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
self.logger = pb_utils.Logger
|
||||
# Parse model parameters
|
||||
self.model_config = json.loads(args['model_config'])
|
||||
parameters = self.model_config['parameters']
|
||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||
self.logger.log_info(f"model_params:{model_params}")
|
||||
|
||||
# Initialize tokenizer
|
||||
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
||||
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
|
||||
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
||||
|
||||
campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
||||
|
||||
def forward_llm(self, input_ids):
|
||||
"""
|
||||
Prepares the response from the language model based on the provided
|
||||
inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
||||
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
||||
For each response from the language model:
|
||||
- Checks for errors and raises an exception if any are found.
|
||||
- Extracts the "output_ids" tensor from the response.
|
||||
- Determines the finish reason based on the presence of the
|
||||
end-of-sequence token or reaching the maximum length.
|
||||
- Appends the generated token IDs to `output_ids`.
|
||||
- If the finish reason is determined, decodes the output IDs to text
|
||||
and prepares the final response.
|
||||
|
||||
The final response includes the generated text, finish reason,
|
||||
completion tokens, prompt tokens, and total tokens.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
||||
"""
|
||||
# convert input_ids to numpy, with shape [1, sequence_length]
|
||||
input_ids = input_ids.cpu().numpy()
|
||||
max_tokens = 1024
|
||||
input_dict = {
|
||||
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
||||
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
||||
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
||||
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
||||
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
||||
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
||||
"temperature": np.array([[0.8]], dtype=np.float32),
|
||||
"input_ids": input_ids,
|
||||
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
||||
}
|
||||
|
||||
# Convert inputs to Triton tensors
|
||||
input_tensor_list = [
|
||||
pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
||||
]
|
||||
|
||||
# Create and execute inference request
|
||||
llm_request = pb_utils.InferenceRequest(
|
||||
model_name="tensorrt_llm",
|
||||
requested_output_names=["output_ids", "sequence_length"],
|
||||
inputs=input_tensor_list,
|
||||
)
|
||||
|
||||
llm_responses = llm_request.exec(decoupled=self.decoupled)
|
||||
if self.decoupled:
|
||||
for llm_response in llm_responses:
|
||||
if llm_response.has_error():
|
||||
raise pb_utils.TritonModelException(llm_response.error().message())
|
||||
|
||||
# Extract and process output
|
||||
output_ids = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "output_ids").as_numpy()
|
||||
seq_lens = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "sequence_length").as_numpy()
|
||||
|
||||
# Get actual output IDs up to the sequence length
|
||||
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
||||
|
||||
yield actual_output_ids
|
||||
else:
|
||||
llm_response = llm_responses
|
||||
if llm_response.has_error():
|
||||
raise pb_utils.TritonModelException(llm_response.error().message())
|
||||
|
||||
# Extract and process output
|
||||
output_ids = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "output_ids").as_numpy()
|
||||
seq_lens = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "sequence_length").as_numpy()
|
||||
|
||||
# Get actual output IDs up to the sequence length
|
||||
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
||||
|
||||
yield actual_output_ids
|
||||
|
||||
def forward_audio_tokenizer(self, wav, wav_len):
|
||||
"""Forward pass through the audio tokenizer component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
wav_len: Waveform length tensor
|
||||
|
||||
Returns:
|
||||
Prompt speech tokens tensor
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='audio_tokenizer',
|
||||
requested_output_names=['prompt_speech_tokens'],
|
||||
inputs=[wav, wav_len]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
|
||||
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
|
||||
|
||||
return prompt_speech_tokens
|
||||
|
||||
def forward_token2wav(
|
||||
self,
|
||||
prompt_speech_tokens: torch.Tensor,
|
||||
prompt_speech_feat: torch.Tensor,
|
||||
prompt_spk_embedding: torch.Tensor,
|
||||
target_speech_tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass through the vocoder component.
|
||||
|
||||
Args:
|
||||
prompt_speech_tokens: Prompt speech tokens tensor
|
||||
prompt_speech_feat: Prompt speech feat tensor
|
||||
prompt_spk_embedding: Prompt spk embedding tensor
|
||||
target_speech_tokens: Target speech tokens tensor
|
||||
|
||||
Returns:
|
||||
Generated waveform tensor
|
||||
"""
|
||||
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
||||
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
||||
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||
|
||||
# Create and execute inference request
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='token2wav',
|
||||
requested_output_names=['waveform'],
|
||||
inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output waveform
|
||||
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
||||
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
||||
|
||||
return waveform
|
||||
|
||||
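# Build the LLM prompt: text tokens for "<|sos|>{prompt_text}{target_text}<|task_id|>"
# followed by the prompt speech tokens, so the LLM continues with the target speech tokens.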
def parse_input(self, text, prompt_text, prompt_speech_tokens):
|
||||
total_text = f"{prompt_text}{text}"
|
||||
prompt = self.prompt_template.format(input_text=total_text)
|
||||
input_ids = self.tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor([input_ids], dtype=torch.int32)
|
||||
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
||||
return input_ids
|
||||
|
||||
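# Speaker embedding: 80-dim fbank at 16 kHz, mean-normalized, then passed through the
# CAM++ ONNX model; the embedding conditions the flow decoder on the prompt speaker.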
def _extract_spk_embedding(self, speech):
|
||||
feat = kaldi.fbank(speech,
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
embedding = self.campplus_session.run(None,
|
||||
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||
embedding = torch.tensor([embedding]).to(self.device).half()
|
||||
return embedding
|
||||
|
||||
def _extract_speech_feat(self, speech):
|
||||
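# Prompt mel features: 80-bin mel spectrogram at 24 kHz with hop 480 (50 frames/s),
# matching what the flow decoder expects as prompt_feat.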
speech_feat = mel_spectrogram(
    speech, n_fft=1920, num_mels=80, sampling_rate=24000,
    hop_size=480, win_size=1920, fmin=0, fmax=8000,
).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated audio
|
||||
"""
|
||||
responses = []
|
||||
|
||||
for request in requests:
|
||||
# Extract input tensors
|
||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||
|
||||
# Process reference audio through audio tokenizer
|
||||
|
||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||
|
||||
wav_tensor = wav.as_numpy()
|
||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||
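# Mel frames arrive at roughly twice the speech-token rate (50 fps vs. ~25 tokens/s),
# so clip features and tokens to a consistent 2:1 length before building the prompt.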
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||
|
||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||
reference_text = reference_text[0][0].decode('utf-8')
|
||||
|
||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||
target_text = target_text[0][0].decode('utf-8')
|
||||
|
||||
# Prepare prompt for LLM
|
||||
input_ids = self.parse_input(
|
||||
text=target_text,
|
||||
prompt_text=reference_text,
|
||||
prompt_speech_tokens=prompt_speech_tokens,
|
||||
)
|
||||
|
||||
# Generate semantic tokens with LLM
|
||||
generated_ids_iter = self.forward_llm(input_ids)
|
||||
|
||||
if self.decoupled:
|
||||
response_sender = request.get_response_sender()
|
||||
request_id = request.request_id()
|
||||
generated_ids = []
|
||||
for generated_id in generated_ids_iter:
|
||||
# collect generated token IDs; they are converted to an int32 tensor after the loop
|
||||
generated_id = generated_id.tolist()
|
||||
if len(generated_id) > 0:
|
||||
assert len(generated_id) == 1, "Generated ID is not a single integer"
|
||||
generated_ids.append(generated_id[0])
|
||||
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
|
||||
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
|
||||
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
|
||||
|
||||
# Prepare response
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||
self.logger.log_info("send tritonserver_response_complete_final to end")
|
||||
else:
|
||||
generated_ids = next(generated_ids_iter)
|
||||
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
|
||||
if generated_ids is None or len(generated_ids) == 0:
|
||||
raise pb_utils.TritonModelException("Generated IDs is None or empty")
|
||||
|
||||
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
|
||||
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
|
||||
|
||||
# Prepare response
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
responses.append(inference_response)
|
||||
|
||||
if not self.decoupled:
|
||||
return responses
|
||||
70
runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "cosyvoice2"
|
||||
backend: "python"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
}
|
||||
model_transaction_policy {
|
||||
decoupled: ${decoupled_mode}
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
key: "llm_tokenizer_dir",
|
||||
value: {string_value:"${llm_tokenizer_dir}"}
|
||||
},
|
||||
{
|
||||
key: "model_dir",
|
||||
value: {string_value:"${model_dir}"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "reference_wav"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "reference_wav_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
},
|
||||
{
|
||||
name: "reference_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
},
|
||||
{
|
||||
name: "target_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "waveform"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: ${bls_instance_num}
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
857
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt
Normal file
@@ -0,0 +1,857 @@
|
||||
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
name: "tensorrt_llm"
|
||||
backend: "${triton_backend}"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
|
||||
model_transaction_policy {
|
||||
decoupled: ${decoupled_mode}
|
||||
}
|
||||
|
||||
dynamic_batching {
|
||||
preferred_batch_size: [ ${triton_max_batch_size} ]
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
default_queue_policy: { max_queue_size: ${max_queue_size} }
|
||||
}
|
||||
|
||||
input [
|
||||
{
|
||||
name: "input_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
allow_ragged_batch: true
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "encoder_input_features"
|
||||
data_type: ${encoder_input_features_data_type}
|
||||
dims: [ -1, -1 ]
|
||||
allow_ragged_batch: true
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "encoder_output_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "input_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "request_output_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "num_return_sequences"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "draft_input_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "decoder_input_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "decoder_input_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "draft_logits"
|
||||
data_type: ${logits_datatype}
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "draft_acceptance_threshold"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "end_id"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "pad_id"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "stop_words_list"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 2, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "bad_words_list"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 2, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "embedding_bias"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "beam_width"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "temperature"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_k"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p_min"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p_decay"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p_reset_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "len_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "early_stopping"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "repetition_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "min_length"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "beam_search_diversity_rate"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "presence_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "frequency_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "random_seed"
|
||||
data_type: TYPE_UINT64
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_log_probs"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_context_logits"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_generation_logits"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_perf_metrics"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "exclude_input_in_output"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "stop"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "streaming"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "prompt_embedding_table"
|
||||
data_type: TYPE_FP16
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "prompt_table_extra_ids"
|
||||
data_type: TYPE_UINT64
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "prompt_vocab_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
|
||||
{
|
||||
name: "cross_attention_mask"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
# Mrope param when mrope is used
|
||||
{
|
||||
name: "mrope_rotary_cos_sin"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "mrope_position_deltas"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
},
|
||||
# the unique task ID for the given LoRA.
|
||||
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
|
||||
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
|
||||
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
|
||||
{
|
||||
name: "lora_task_id"
|
||||
data_type: TYPE_UINT64
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
|
||||
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
|
||||
# each of the in / out tensors are first flattened and then concatenated together in the format above.
|
||||
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
|
||||
{
|
||||
name: "lora_weights"
|
||||
data_type: TYPE_FP16
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
# module identifier (same size as the first dimension of lora_weights)
|
||||
# See LoraModule::ModuleType for model id mapping
|
||||
#
|
||||
# "attn_qkv": 0 # compbined qkv adapter
|
||||
# "attn_q": 1 # q adapter
|
||||
# "attn_k": 2 # k adapter
|
||||
# "attn_v": 3 # v adapter
|
||||
# "attn_dense": 4 # adapter for the dense layer in attention
|
||||
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
|
||||
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
|
||||
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
|
||||
#
|
||||
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
|
||||
{
|
||||
name: "lora_config"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1, 3 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "context_phase_params"
|
||||
data_type: TYPE_UINT8
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
|
||||
{
|
||||
name: "skip_cross_attn_blocks"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_starts"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_ends"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_priorities"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_durations_ms"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_decode_priority"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_decode_duration_ms"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "guided_decoding_guide_type"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "guided_decoding_guide"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "lookahead_window_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "lookahead_ngram_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "lookahead_verification_set_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "output_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "sequence_length"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
},
|
||||
{
|
||||
name: "cum_log_probs"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
},
|
||||
{
|
||||
name: "output_log_probs"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "context_logits"
|
||||
data_type: ${logits_datatype}
|
||||
dims: [ -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "generation_logits"
|
||||
data_type: ${logits_datatype}
|
||||
dims: [ -1, -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "batch_index"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "sequence_index"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "context_phase_params"
|
||||
data_type: TYPE_UINT8
|
||||
dims: [ -1 ]
|
||||
},
|
||||
{
|
||||
name: "kv_cache_alloc_new_blocks"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "kv_cache_reused_blocks"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "kv_cache_alloc_total_blocks"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "arrival_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "first_scheduled_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "first_token_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "last_token_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "acceptance_rate"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "total_accepted_draft_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "total_draft_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
}
|
||||
]
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind : KIND_CPU
|
||||
}
|
||||
]
|
||||
parameters: {
|
||||
key: "max_beam_width"
|
||||
value: {
|
||||
string_value: "${max_beam_width}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
|
||||
value: {
|
||||
string_value: "no"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpt_model_type"
|
||||
value: {
|
||||
string_value: "${batching_strategy}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpt_model_path"
|
||||
value: {
|
||||
string_value: "${engine_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "encoder_model_path"
|
||||
value: {
|
||||
string_value: "${encoder_engine_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "max_tokens_in_paged_kv_cache"
|
||||
value: {
|
||||
string_value: "${max_tokens_in_paged_kv_cache}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "max_attention_window_size"
|
||||
value: {
|
||||
string_value: "${max_attention_window_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "sink_token_length"
|
||||
value: {
|
||||
string_value: "${sink_token_length}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "batch_scheduler_policy"
|
||||
value: {
|
||||
string_value: "${batch_scheduler_policy}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "kv_cache_free_gpu_mem_fraction"
|
||||
value: {
|
||||
string_value: "${kv_cache_free_gpu_mem_fraction}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cross_kv_cache_fraction"
|
||||
value: {
|
||||
string_value: "${cross_kv_cache_fraction}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "kv_cache_host_memory_bytes"
|
||||
value: {
|
||||
string_value: "${kv_cache_host_memory_bytes}"
|
||||
}
|
||||
}
|
||||
# kv_cache_onboard_blocks is for internal implementation.
|
||||
parameters: {
|
||||
key: "kv_cache_onboard_blocks"
|
||||
value: {
|
||||
string_value: "${kv_cache_onboard_blocks}"
|
||||
}
|
||||
}
|
||||
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
|
||||
# parameters: {
|
||||
# key: "enable_trt_overlap"
|
||||
# value: {
|
||||
# string_value: "${enable_trt_overlap}"
|
||||
# }
|
||||
# }
|
||||
parameters: {
|
||||
key: "exclude_input_in_output"
|
||||
value: {
|
||||
string_value: "${exclude_input_in_output}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cancellation_check_period_ms"
|
||||
value: {
|
||||
string_value: "${cancellation_check_period_ms}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "stats_check_period_ms"
|
||||
value: {
|
||||
string_value: "${stats_check_period_ms}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "iter_stats_max_iterations"
|
||||
value: {
|
||||
string_value: "${iter_stats_max_iterations}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "request_stats_max_iterations"
|
||||
value: {
|
||||
string_value: "${request_stats_max_iterations}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "enable_kv_cache_reuse"
|
||||
value: {
|
||||
string_value: "${enable_kv_cache_reuse}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "normalize_log_probs"
|
||||
value: {
|
||||
string_value: "${normalize_log_probs}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "enable_chunked_context"
|
||||
value: {
|
||||
string_value: "${enable_chunked_context}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpu_device_ids"
|
||||
value: {
|
||||
string_value: "${gpu_device_ids}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "participant_ids"
|
||||
value: {
|
||||
string_value: "${participant_ids}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_optimal_adapter_size"
|
||||
value: {
|
||||
string_value: "${lora_cache_optimal_adapter_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_max_adapter_size"
|
||||
value: {
|
||||
string_value: "${lora_cache_max_adapter_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_gpu_memory_fraction"
|
||||
value: {
|
||||
string_value: "${lora_cache_gpu_memory_fraction}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_host_memory_bytes"
|
||||
value: {
|
||||
string_value: "${lora_cache_host_memory_bytes}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_prefetch_dir"
|
||||
value: {
|
||||
string_value: "${lora_prefetch_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "decoding_mode"
|
||||
value: {
|
||||
string_value: "${decoding_mode}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "executor_worker_path"
|
||||
value: {
|
||||
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lookahead_window_size"
|
||||
value: {
|
||||
string_value: "${lookahead_window_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lookahead_ngram_size"
|
||||
value: {
|
||||
string_value: "${lookahead_ngram_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lookahead_verification_set_size"
|
||||
value: {
|
||||
string_value: "${lookahead_verification_set_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "medusa_choices"
|
||||
value: {
|
||||
string_value: "${medusa_choices}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "eagle_choices"
|
||||
value: {
|
||||
string_value: "${eagle_choices}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpu_weights_percent"
|
||||
value: {
|
||||
string_value: "${gpu_weights_percent}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "enable_context_fmha_fp32_acc"
|
||||
value: {
|
||||
string_value: "${enable_context_fmha_fp32_acc}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "multi_block_mode"
|
||||
value: {
|
||||
string_value: "${multi_block_mode}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cuda_graph_mode"
|
||||
value: {
|
||||
string_value: "${cuda_graph_mode}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cuda_graph_cache_size"
|
||||
value: {
|
||||
string_value: "${cuda_graph_cache_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "speculative_decoding_fast_logits"
|
||||
value: {
|
||||
string_value: "${speculative_decoding_fast_logits}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "tokenizer_dir"
|
||||
value: {
|
||||
string_value: "${tokenizer_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "guided_decoding_backend"
|
||||
value: {
|
||||
string_value: "${guided_decoding_backend}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "xgrammar_tokenizer_info_path"
|
||||
value: {
|
||||
string_value: "${xgrammar_tokenizer_info_path}"
|
||||
}
|
||||
}
|
||||
195
runtime/triton_trtllm/model_repo/token2wav/1/model.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
|
||||
import torch
|
||||
from torch.utils.dlpack import to_dlpack
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
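# Offset of the speech-token IDs inside the combined LLM vocabulary: the LLM appends
# speech tokens after the text vocabulary, so subtracting this offset maps generated
# IDs back into the flow model's speech-token space.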
ORIGINAL_VOCAB_SIZE = 151663
|
||||
|
||||
|
||||
class CosyVoice2:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
|
||||
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
|
||||
if not os.path.exists(hyper_yaml_path):
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
|
||||
|
||||
class CosyVoice2Model:
|
||||
|
||||
def __init__(self,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.flow.half()
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load(self, flow_model, hift_model):
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
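# TensorRT optimization profiles for the flow decoder estimator inputs (x, mask, mu, cond):
# 80 mel channels with a time dimension from 4 to 3000 frames; the leading dimension of 2
# presumably batches the conditional and unconditional passes of classifier-free guidance.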
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||
input_names = ["x", "mask", "mu", "cond"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for vocoder.
|
||||
|
||||
This model takes global and semantic tokens as input and generates audio waveforms
|
||||
using the BiCodec vocoder.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
# Parse model parameters
|
||||
parameters = json.loads(args['model_config'])['parameters']
|
||||
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
||||
model_dir = model_params["model_dir"]
|
||||
|
||||
# Initialize device and vocoder
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||
|
||||
self.token2wav_model = CosyVoice2(
|
||||
model_dir, load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated waveforms
|
||||
"""
|
||||
responses = []
|
||||
# Process each request in batch
|
||||
for request in requests:
|
||||
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
|
||||
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
||||
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
||||
|
||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
||||
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
||||
|
||||
# shift the speech tokens according to the original vocab size
|
||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
|
||||
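# Flow-matching decoder: generate the target mel spectrogram from the speech tokens,
# conditioned on the prompt tokens, prompt mel features, and speaker embedding.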
tts_mel, _ = self.token2wav_model.model.flow.inference(
|
||||
token=target_speech_tokens,
|
||||
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
|
||||
self.device
|
||||
),
|
||||
prompt_token=prompt_speech_tokens,
|
||||
prompt_token_len=torch.tensor(
|
||||
[prompt_speech_tokens.shape[1]], dtype=torch.int32
|
||||
).to(self.device),
|
||||
prompt_feat=prompt_speech_feat,
|
||||
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=prompt_spk_embedding,
|
||||
streaming=False,
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
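# HiFT vocoder: convert the mel spectrogram into a waveform (offline, no streaming cache).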
audio_hat, _ = self.token2wav_model.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||
|
||||
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
|
||||
responses.append(inference_response)
|
||||
|
||||
return responses
|
||||
63
runtime/triton_trtllm/model_repo/token2wav/config.pbtxt
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "token2wav"
|
||||
backend: "python"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
key: "model_dir",
|
||||
value: {string_value:"${model_dir}"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "target_speech_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "prompt_speech_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "prompt_speech_feat"
|
||||
data_type: TYPE_FP16
|
||||
dims: [-1, 80]
|
||||
},
|
||||
{
|
||||
name: "prompt_spk_embedding"
|
||||
data_type: TYPE_FP16
|
||||
dims: [-1]
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "waveform"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
14
runtime/triton_trtllm/requirements.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
hyperpyyaml
|
||||
s3tokenizer
|
||||
onnxruntime-gpu
|
||||
omegaconf
|
||||
conformer
|
||||
hydra-core
|
||||
lightning
|
||||
gdown
|
||||
wget
|
||||
librosa
|
||||
pyworld
|
||||
openai-whisper
|
||||
tritonclient
|
||||
modelscope
|
||||
106
runtime/triton_trtllm/run.sh
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
cosyvoice_path=/workspace/CosyVoice
|
||||
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
|
||||
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
|
||||
huggingface_model_local_dir=./cosyvoice2_llm
|
||||
model_scope_model_local_dir=./CosyVoice2-0.5B
|
||||
trt_dtype=bfloat16
|
||||
trt_weights_dir=./trt_weights_${trt_dtype}
|
||||
trt_engines_dir=./trt_engines_${trt_dtype}
|
||||
|
||||
model_repo=./model_repo_cosyvoice2
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
echo "Cloning CosyVoice"
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||
cd $cosyvoice_path
|
||||
git submodule update --init --recursive
|
||||
cd runtime/triton_trtllm
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading CosyVoice2-0.5B"
|
||||
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Converting checkpoint to TensorRT weights"
|
||||
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
|
||||
--output_dir $trt_weights_dir \
|
||||
--dtype $trt_dtype || exit 1
|
||||
|
||||
echo "Building TensorRT engines"
|
||||
trtllm-build --checkpoint_dir $trt_weights_dir \
|
||||
--output_dir $trt_engines_dir \
|
||||
--max_batch_size 16 \
|
||||
--max_num_tokens 32768 \
|
||||
--gemm_plugin $trt_dtype || exit 1
|
||||
|
||||
echo "Testing TensorRT engines"
|
||||
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
|
||||
--tokenizer_dir $huggingface_model_local_dir \
|
||||
--top_k 50 --top_p 0.95 --temperature 0.8 \
|
||||
--engine_dir=$trt_engines_dir || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
echo "Creating model repository"
|
||||
rm -rf $model_repo
|
||||
mkdir -p $model_repo
|
||||
cosyvoice2_dir="cosyvoice2"
|
||||
|
||||
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||
cp -r ./model_repo/tensorrt_llm $model_repo
|
||||
cp -r ./model_repo/token2wav $model_repo
|
||||
|
||||
ENGINE_PATH=$trt_engines_dir
|
||||
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||
MODEL_DIR=$model_scope_model_local_dir
|
||||
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
||||
BLS_INSTANCE_NUM=4
|
||||
TRITON_MAX_BATCH_SIZE=16
|
||||
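# DECOUPLED_MODE=True enables streaming (decoupled) responses; keep False for offline serving.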
DECOUPLED_MODE=False
|
||||
|
||||
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
echo "Starting Triton server"
|
||||
tritonserver --model-repository $model_repo
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
echo "Single request test http"
|
||||
python3 client_http.py \
|
||||
--reference-audio ./assets/prompt_audio.wav \
|
||||
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
||||
--model-name cosyvoice2
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Running benchmark client grpc"
|
||||
num_task=4
|
||||
# set mode=streaming when decoupled=True
# set mode=offline when decoupled=False
|
||||
mode=offline
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name cosyvoice2 \
|
||||
--num-tasks $num_task \
|
||||
--mode $mode \
|
||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
|
||||
fi
|
||||
330
runtime/triton_trtllm/scripts/convert_checkpoint.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import release_gc
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models import QWenForCausalLM
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_dir', type=str, default=None, required=True)
|
||||
parser.add_argument('--tp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='N-way tensor parallelism size')
|
||||
parser.add_argument('--pp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='N-way pipeline parallelism size')
|
||||
parser.add_argument('--cp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='N-way context parallelism size')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'float16', 'bfloat16', 'float32'],
|
||||
help="The data type for the model weights and activations if not quantized. "
|
||||
"If 'auto', the data type is automatically inferred from the source model; "
|
||||
"however, if the source dtype is float32, it is converted to float16.")
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
parser.add_argument(
|
||||
'--disable_weight_only_quant_plugin',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='By default, the plugin implementation is used for weight quantization. Enabling the disable_weight_only_quant_plugin flag uses the out-of-the-box (OOTB) implementation instead. '
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4', 'int4_gptq'],
|
||||
help='Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--calib_dataset',
|
||||
type=str,
|
||||
default='ccdv/cnn_dailymail',
|
||||
help="The huggingface dataset name or the local directory of the dataset for calibration."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smoothquant",
|
||||
"-sq",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
||||
" to Smoothquant the model, and output int8 weights."
|
||||
" A good first try is 0.5. Must be in [0, 1]")
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The flag is built for GPTQ/AWQ quantization.')
|
||||
|
||||
parser.add_argument('--group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size used in GPTQ quantization.')
|
||||
|
||||
parser.add_argument("--load_model_on_cpu", action="store_true")
|
||||
parser.add_argument(
|
||||
'--use_parallel_embedding',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--embedding_sharding_dim',
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
||||
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
||||
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
||||
)
|
||||
parser.add_argument('--output_dir',
|
||||
type=str,
|
||||
default='tllm_checkpoint',
|
||||
help='The path to save the TensorRT-LLM checkpoint')
|
||||
parser.add_argument(
|
||||
'--workers',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of workers for converting checkpoint in parallel')
|
||||
parser.add_argument(
|
||||
'--moe_tp_size',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--moe_ep_size',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''Return a QuantConfig with quantization info based on the command line args.'''
    quant_config = QuantConfig()
    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    elif args.smoothquant:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    if args.weight_only_precision == 'int4_gptq':
        quant_config.group_size = args.group_size
        quant_config.has_zero_point = True
        quant_config.pre_quant_scale = False
        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ

    return quant_config

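The flag-to-algorithm mapping above can be exercised directly from the shell. The sketch below is illustrative only: the script path and the `--model_dir`/`--dtype`/`--smoothquant` flags are assumed to match the upstream TensorRT-LLM Qwen converter this file is based on, and the directories are placeholders rather than the values used by `run.sh`.

```sh
# Hypothetical int8 SmoothQuant conversion with per-channel/per-token scaling
# and an int8 KV cache; paths are placeholders.
python3 scripts/convert_checkpoint.py \
    --model_dir /path/to/hf_llm \
    --output_dir ./tllm_checkpoint_int8_sq \
    --dtype float16 \
    --smoothquant 0.5 --per_channel --per_token \
    --int8_kv_cache
```

With these flags the code above selects `QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN` for the GEMMs and `QuantAlgo.INT8` for the KV cache.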
def update_quant_config_from_hf(quant_config, hf_config,
                                override_fields) -> tuple[QuantConfig, dict]:
    hf_config_dict = hf_config.to_dict()
    if hf_config_dict.get('quantization_config'):
        # update the quant_algo, and clamp_val.
        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('zero_point', False)
            override_fields.update({"use_autoawq": True})
        elif hf_config_dict['quantization_config'].get(
                'quant_method') == 'gptq':
            logger.info(
                "Load quantization configs from huggingface model_config.")
            desc_act = hf_config_dict['quantization_config'].get(
                'desc_act', False)
            if desc_act:
                raise ValueError("GPTQ with desc_act=True is not implemented!")
            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
            quant_config.group_size = hf_config_dict['quantization_config'].get(
                'group_size', 128)
            quant_config.has_zero_point = hf_config_dict[
                'quantization_config'].get('sym', False)
    return quant_config, override_fields


def args_to_build_options(args):
    return {
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'disable_weight_only_quant_plugin':
        args.disable_weight_only_quant_plugin
    }


def convert_and_save_hf(args):
    model_dir = args.model_dir
    world_size = args.tp_size * args.pp_size
    # Need to convert the cli args to key-value pairs and override them in the generated config dict.
    # Ideally these fields will be moved out of the config and passed into the build API; keep them
    # here for compatibility for now, until the refactor is done.
    override_fields = {}
    override_fields.update(args_to_build_options(args))
    quant_config = args_to_quant_config(args)

    try:
        hf_config = AutoConfig.from_pretrained(model_dir,
                                               trust_remote_code=True)
        quant_config, override_fields = update_quant_config_from_hf(
            quant_config, hf_config, override_fields)
    except BaseException:
        logger.warning("AutoConfig cannot load the huggingface config.")

    if args.smoothquant is not None or args.int8_kv_cache:
        mapping = Mapping(world_size=world_size,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size,
                          moe_tp_size=args.moe_tp_size,
                          moe_ep_size=args.moe_ep_size,
                          cp_size=args.cp_size)
        QWenForCausalLM.quantize(args.model_dir,
                                 args.output_dir,
                                 dtype=args.dtype,
                                 mapping=mapping,
                                 quant_config=quant_config,
                                 calib_dataset=args.calib_dataset,
                                 **override_fields)
    else:

        def convert_and_save_rank(args, rank):
            mapping = Mapping(world_size=world_size,
                              rank=rank,
                              tp_size=args.tp_size,
                              pp_size=args.pp_size,
                              moe_tp_size=args.moe_tp_size,
                              moe_ep_size=args.moe_ep_size)
            qwen = QWenForCausalLM.from_hugging_face(model_dir,
                                                     args.dtype,
                                                     mapping=mapping,
                                                     quant_config=quant_config,
                                                     **override_fields)
            qwen.config.mapping.cp_size = args.cp_size
            qwen.config.mapping.attn_tp_size = -1
            qwen.config.mapping.attn_cp_size = -1
            qwen.config.mapping.world_size *= args.cp_size
            qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
            del qwen

        execute(args.workers, [convert_and_save_rank] * world_size, args)
        release_gc()


def execute(workers, func, args):
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
    else:
        with ThreadPoolExecutor(max_workers=workers) as p:
            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(exceptions) == 0, \
                "Checkpoint conversion failed, please check error log."


def main():
    print(tensorrt_llm.__version__)
    args = parse_arguments()

    if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
        # MoE defaults to TP-only
        args.moe_tp_size = args.tp_size
        args.moe_ep_size = 1
    elif (args.moe_tp_size == -1):
        args.moe_tp_size = args.tp_size // args.moe_ep_size
    elif (args.moe_ep_size == -1):
        args.moe_ep_size = args.tp_size // args.moe_tp_size
    assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
            ), "moe_tp_size * moe_ep_size must equal tp_size"

    tik = time.time()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    assert args.model_dir is not None
    convert_and_save_hf(args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')


if __name__ == '__main__':
    main()
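Stage 1 of `run.sh` drives this converter before building engines. A minimal, non-quantized sketch is shown below; the directory names, dtype, and `trtllm-build` options are placeholders rather than the exact values used by `run.sh`.

```sh
# Sketch: convert the HF checkpoint, then build TensorRT engines from it.
# Paths are placeholders; stage 1 of run.sh performs the real conversion/build.
python3 scripts/convert_checkpoint.py \
    --model_dir /path/to/hf_llm \
    --output_dir ./tllm_checkpoint_1gpu \
    --dtype bfloat16 \
    --workers 1

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu \
             --output_dir ./trt_engines_1gpu \
             --max_batch_size 16
```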
69
runtime/triton_trtllm/scripts/fill_template.py
Normal file
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
from argparse import ArgumentParser
from string import Template


def split(string, delimiter):
    """Split a string using delimiter. Supports escaping.

    Args:
        string (str): The string to split.
        delimiter (str): The delimiter to split the string with.

    Returns:
        list: A list of strings.
    """
    result = []
    current = ""
    escape = False
    for char in string:
        if escape:
            current += char
            escape = False
        elif char == delimiter:
            result.append(current)
            current = ""
        elif char == "\\":
            escape = True
        else:
            current += char
    result.append(current)
    return result


def main(file_path, substitutions, in_place):
    with open(file_path) as f:
        pbtxt = Template(f.read())

    sub_dict = {
        "max_queue_size": 0,
        'max_queue_delay_microseconds': 0,
    }
    for sub in split(substitutions, ","):
        key, value = split(sub, ":")
        sub_dict[key] = value

        assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."

    pbtxt = pbtxt.safe_substitute(sub_dict)

    if in_place:
        with open(file_path, "w") as f:
            f.write(pbtxt)
    else:
        print(pbtxt)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("file_path", help="path of the .pbtxt to modify")
    parser.add_argument(
        "substitutions",
        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
    )
    parser.add_argument("--in_place",
                        "-i",
                        action="store_true",
                        help="do the operation in-place")
    args = parser.parse_args()
    main(**vars(args))
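Stage 2 of `run.sh` uses this helper to fill the `$`-style placeholders in each Triton `config.pbtxt`. A hypothetical in-place substitution is sketched below; the variable names are examples only, and the real keys are whatever `$variables` appear in this repository's templates.

```sh
# Sketch: substitute two template variables in place; each key must exist as a
# $-placeholder in the target config.pbtxt (names here are illustrative).
python3 scripts/fill_template.py -i \
    model_repo/tensorrt_llm/config.pbtxt \
    decoupled_mode:True,engine_dir:/workspace/trt_engines_1gpu
```

Values containing a literal `,` or `:` can be escaped with a backslash, which the `split` helper above honors.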
143
runtime/triton_trtllm/scripts/test_llm.py
Normal file
@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ast
import csv
import os
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch

import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import ModelRunnerCpp
from transformers import AutoTokenizer


def parse_arguments(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_text',
        type=str,
        nargs='+',
        default=["Born in north-east France, Soyer trained as a"])
    parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument('--log_level', type=str, default="debug")
    parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6)
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top_k', type=int, default=50)
    parser.add_argument('--top_p', type=float, default=0.95)

    return parser.parse_args(args=args)


def parse_input(tokenizer,
                input_text=None,
                prompt_template=None):
    batch_input_ids = []
    for curr_text in input_text:
        if prompt_template is not None:
            curr_text = prompt_template.format(input_text=curr_text)
        input_ids = tokenizer.encode(curr_text)
        batch_input_ids.append(input_ids)

    batch_input_ids = [
        torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
    ]

    logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):")
    for i, input_ids in enumerate(batch_input_ids):
        logger.debug(f"Request {i}: {input_ids.tolist()}")

    return batch_input_ids


def main(args):
    runtime_rank = tensorrt_llm.mpi_rank()
    logger.set_level(args.log_level)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
    prompt_template = "<|sos|>{input_text}<|task_id|>"
    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>")

    batch_input_ids = parse_input(tokenizer=tokenizer,
                                  input_text=args.input_text,
                                  prompt_template=prompt_template)

    input_lengths = [x.size(0) for x in batch_input_ids]

    runner_kwargs = dict(
        engine_dir=args.engine_dir,
        rank=runtime_rank,
        max_output_len=1024,
        enable_context_fmha_fp32_acc=False,
        max_batch_size=len(batch_input_ids),
        max_input_len=max(input_lengths),
        kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
        cuda_graph_mode=False,
        gather_generation_logits=False,
    )

    runner = ModelRunnerCpp.from_dir(**runner_kwargs)

    with torch.no_grad():
        outputs = runner.generate(
            batch_input_ids=batch_input_ids,
            max_new_tokens=1024,
            end_id=end_id,
            pad_id=end_id,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            num_return_sequences=1,
            repetition_penalty=1.1,
            random_seed=42,
            streaming=False,
            output_sequence_lengths=True,
            output_generation_logits=False,
            return_dict=True,
            return_all_generated_tokens=False)
        torch.cuda.synchronize()

    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
    num_output_sents, num_beams, _ = output_ids.size()
    assert num_beams == 1
    beam = 0
    batch_size = len(input_lengths)
    num_return_sequences = num_output_sents // batch_size
    assert num_return_sequences == 1
    for i in range(batch_size * num_return_sequences):
        batch_idx = i // num_return_sequences
        seq_idx = i % num_return_sequences
        inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
        input_text = tokenizer.decode(inputs)
        print(f'Input [Text {batch_idx}]: "{input_text}"')
        output_begin = input_lengths[batch_idx]
        output_end = sequence_lengths[i][beam]
        outputs = output_ids[i][beam][output_begin:output_end].tolist()
        output_text = tokenizer.decode(outputs)
        print(f'Output [Text {batch_idx}]: "{output_text}"')
        logger.debug(str(outputs))


if __name__ == '__main__':
    args = parse_arguments()
    main(args)
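This script sanity-checks the built LLM engines outside Triton, using the prompt template `<|sos|>{input_text}<|task_id|>` and the end token `<|eos1|>` defined above. A hypothetical invocation is sketched below; the paths are placeholders, and the tokenizer directory must define those special tokens.

```sh
# Sketch: run the engines directly; adjust paths to the tokenizer and engine
# directories produced by the conversion/build steps.
python3 scripts/test_llm.py \
    --tokenizer_dir /path/to/cosyvoice2_llm_tokenizer \
    --engine_dir ./trt_engines_1gpu \
    --input_text "Hello, this is a quick smoke test." \
    --temperature 0.8 --top_k 50 --top_p 0.95
```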