Merge pull request #1561 from yuekaizhang/streaming

[Runtime] Support Streaming TTS for Triton + TensorRT-LLM runtime
This commit is contained in:
Xiang Lyu
2025-09-04 09:48:43 +08:00
committed by GitHub
11 changed files with 544 additions and 134 deletions

View File

@@ -1,15 +1,17 @@
## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server
## Serving CosyVoice with NVIDIA Triton Inference Server
Thanks to the contribution from NVIDIA Yuekai Zhang.
Contributed by Yuekai Zhang (NVIDIA).
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose up
```
### Build the Docker Image
Build the image from scratch:
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
@@ -21,71 +23,89 @@ docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_di
```
### Understanding `run.sh`
The `run.sh` script orchestrates the entire workflow through numbered stages.
Run a subset of stages with:
You can run a subset of stages with:
```sh
bash run.sh <start_stage> <stop_stage> [service_type]
```
- `<start_stage>` stage to start from (0-5).
- `<stop_stage>` stage to stop after (0-5).
- `<start_stage>`: The stage to start from (0-5).
- `<stop_stage>`: The stage to stop after (0-5).
Stages:
- **Stage 0** Download the cosyvoice-2 0.5B model from HuggingFace.
- **Stage 1** Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
- **Stage 2** Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later).
- **Stage 3** Launch the Triton Inference Server.
- **Stage 4** Run the single-utterance HTTP client.
- **Stage 5** Run the gRPC benchmark client.
**Stages:**
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
- **Stage 3**: Launches the Triton Inference Server.
- **Stage 4**: Runs the single-utterance HTTP client for testing.
- **Stage 5**: Runs the gRPC benchmark client.
### Export Models and Launch Server
### Export Models to TensorRT-LLM and Launch the Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# Runs stages 0, 1, 2, and 3
# This command runs stages 0, 1, 2, and 3
bash run.sh 0 3
```
*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
> [!TIP]
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
### Single-Utterance HTTP Client
Send a single HTTP inference request:
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
```sh
bash run.sh 4 4
```
### Benchmark with a Dataset
Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.
```sh
bash run.sh 5 5
# You can also customise parameters such as num_task and dataset split directly:
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
```sh
bash run.sh 5 5 # [streaming|offline]
# You can also customize parameters such as the number of tasks and the dataset split:
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
```
> [!TIP]
> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Tritons decoupled mode so that responses are returned as soon as they are ready.
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
### Benchmark Results
Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio):
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
**Streaming TTS (First Chunk Latency)**
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|
| Streaming, Decoupled=True | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, Decoupled=True | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, Decoupled=True | 4 | 1107.34 | 1243.75| 0.0922 |
**Offline TTS (Full Sentence Latency)**
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|------|------|-------------|------------------|------------------|-----|
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 |
|---|---|---|---|---|---|
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
### OpenAI-Compatible Server
To launch an OpenAI-compatible service, run:
To launch an OpenAI-compatible API service, run the following commands:
```sh
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
cd Triton-OpenAI-Speech
pip install -r requirements.txt
# After the Triton service is up, start the FastAPI bridge:
# After the Triton service is running, start the FastAPI bridge:
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
# Test with curl
# Test the service with curl:
bash test/test_cosyvoice.sh
```
> [!NOTE]
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
### Acknowledgements
This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details.
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@@ -395,38 +395,45 @@ def run_sync_streaming_inference(
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
actual_duration = 0
if audios:
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
reconstructed_audio = None
# Only spark_tts model uses cross-fade
if model_name == "spark_tts":
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
reconstructed_audio = None
# Simplified reconstruction based on client_grpc_streaming.py
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
elif len(audios) == 1:
reconstructed_audio = audios[0]
# Simplified reconstruction based on client_grpc_streaming.py
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
elif len(audios) == 1:
reconstructed_audio = audios[0]
else:
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
for i in range(1, len(audios)):
# Cross-fade section
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
# Middle section of the current chunk
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
# Concatenate
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
# Add the last part of the final chunk
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
if reconstructed_audio is not None and reconstructed_audio.size > 0:
actual_duration = len(reconstructed_audio) / save_sample_rate
# Save reconstructed audio
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0 # Set duration to 0 if no audio
else:
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
for i in range(1, len(audios)):
# Cross-fade section
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
# Middle section of the current chunk
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
# Concatenate
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
# Add the last part of the final chunk
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
if reconstructed_audio is not None and reconstructed_audio.size > 0:
reconstructed_audio = np.concatenate(audios)
print(f"reconstructed_audio: {reconstructed_audio.shape}")
actual_duration = len(reconstructed_audio) / save_sample_rate
# Save reconstructed audio
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0 # Set duration to 0 if no audio
else:
print("Warning: No audio chunks received.")
@@ -667,6 +674,7 @@ async def main():
manifest_item_list = split_data(manifest_item_list, num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
tasks = []
start_time = time.time()
for i in range(num_tasks):

View File

@@ -32,7 +32,7 @@ import triton_python_backend_utils as pb_utils
import os
import numpy as np
import s3tokenizer
torch.set_num_threads(1)
ORIGINAL_VOCAB_SIZE = 151663

View File

@@ -20,7 +20,7 @@ dynamic_batching {
}
parameters [
{
key: "model_dir",
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]

View File

@@ -28,6 +28,8 @@ import json
import math
import os
import re
import threading
import time
from typing import Dict, List, Tuple, Optional, Union
import numpy as np
@@ -35,13 +37,14 @@ import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import onnxruntime
from matcha.utils.audio import mel_spectrogram
torch.set_num_threads(1)
class TritonPythonModel:
"""Triton Python model for Spark TTS.
@@ -62,6 +65,8 @@ class TritonPythonModel:
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
@@ -72,11 +77,9 @@ class TritonPythonModel:
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
def forward_llm(self, input_ids):
"""
@@ -105,7 +108,7 @@ class TritonPythonModel:
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 1024
max_tokens = 750
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
@@ -114,6 +117,8 @@ class TritonPythonModel:
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
"random_seed": np.array([[42]], dtype=np.uint64),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
@@ -188,12 +193,40 @@ class TritonPythonModel:
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
def forward_token2wav(
self,
prompt_speech_tokens: torch.Tensor,
prompt_speech_feat: torch.Tensor,
prompt_spk_embedding: torch.Tensor,
target_speech_tokens: torch.Tensor) -> torch.Tensor:
target_speech_tokens: torch.Tensor,
request_id: str,
token_offset: int = None,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
@@ -210,11 +243,21 @@ class TritonPythonModel:
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
if token_offset is not None:
assert finalize is not None
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor.append(token_offset_tensor)
inputs_tensor.append(finalize_tensor)
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
requested_output_names=['waveform'],
inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
inputs=inputs_tensor,
request_id=request_id,
)
inference_response = inference_request.exec()
@@ -235,17 +278,6 @@ class TritonPythonModel:
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
return input_ids
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device).half()
return embedding
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
@@ -263,6 +295,14 @@ class TritonPythonModel:
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
for generated_ids in generated_ids_iter:
generated_ids = generated_ids.tolist()
if len(generated_ids) == 0:
break
semantic_token_ids_arr.extend(generated_ids)
llm_is_done_flag[0] = True
def execute(self, requests):
"""Execute inference on the batched requests.
@@ -275,6 +315,7 @@ class TritonPythonModel:
responses = []
for request in requests:
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
@@ -292,6 +333,8 @@ class TritonPythonModel:
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
@@ -307,25 +350,76 @@ class TritonPythonModel:
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
if self.decoupled:
response_sender = request.get_response_sender()
request_id = request.request_id()
generated_ids = []
for generated_id in generated_ids_iter:
# convert the numpy array into a int32 tensor
generated_id = generated_id.tolist()
if len(generated_id) > 0:
assert len(generated_id) == 1, "Generated ID is not a single integer"
generated_ids.append(generated_id[0])
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
semantic_token_ids_arr = []
llm_is_done_flag = [False]
llm_thread = threading.Thread(
target=self._llm_gen_thread,
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
)
llm_thread.start()
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if llm_is_done_flag[0]:
break
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
this_tts_speech_token, request_id, token_offset, False)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
time.sleep(0.02)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
llm_thread.join()
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
@@ -334,8 +428,7 @@ class TritonPythonModel:
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

View File

@@ -23,11 +23,11 @@ model_transaction_policy {
}
parameters [
{
key: "llm_tokenizer_dir",
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]

View File

@@ -0,0 +1,153 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.utils.common import TrtContextWrapper
import onnxruntime
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_dir = model_params["model_dir"]
gpu = "l20"
enable_trt = True
if enable_trt:
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
else:
campplus_model = f'{model_dir}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
spk_feat = feat - feat.mean(dim=0, keepdim=True)
if isinstance(self.spk_model, onnxruntime.InferenceSession):
embedding = self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
embedding = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
embedding.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return embedding.half()
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
responses = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_array = torch.from_numpy(wav_array).to(self.device)
embedding = self._extract_spk_embedding(wav_array)
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
"prompt_spk_embedding", to_dlpack(embedding))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_spk_embedding_tensor])
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "speaker_embedding"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
}
]
output [
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -32,22 +32,27 @@ from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
self.model_dir = model_dir
self.fp16 = fp16
@@ -57,7 +62,7 @@ class CosyVoice2:
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
@@ -73,14 +78,22 @@ class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.flow.half()
# streaming tts config
self.token_hop_len = 25
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
self.speech_window = np.hamming(2 * self.source_cache_len)
self.hift_cache_dict = defaultdict(lambda: None)
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
@@ -111,6 +124,42 @@ class CosyVoice2Model:
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
class TritonPythonModel:
"""Triton Python model for vocoder.
@@ -131,11 +180,11 @@ class TritonPythonModel:
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)
logger.info("Token2Wav initialized successfully")
@@ -166,25 +215,47 @@ class TritonPythonModel:
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
if token_offset is not None:
token_offset = token_offset.as_numpy().item()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
if not finalize:
stream = True
else:
stream = False
request_id = request.request_id()
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
embedding=prompt_spk_embedding,
token_offset=token_offset,
uuid=request_id,
stream=stream,
finalize=finalize)
if finalize:
self.token2wav_model.model.hift_cache_dict.pop(request_id)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
else:
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
generated_wave = audio_hat.squeeze(0).cpu().numpy()

View File

@@ -20,7 +20,7 @@ dynamic_batching {
}
parameters [
{
key: "model_dir",
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
@@ -45,6 +45,20 @@ input [
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
},
{
name: "token_offset"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "finalize"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
}
]
output [

View File

@@ -60,6 +60,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/tensorrt_llm $model_repo
cp -r ./model_repo/token2wav $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
@@ -67,11 +68,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=4
TRITON_MAX_BATCH_SIZE=16
DECOUPLED_MODE=False
DECOUPLED_MODE=True # True for streaming, False for offline
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
fi
@@ -82,7 +84,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Single request test http"
echo "Single request test http, only work for offline TTS mode"
python3 client_http.py \
--reference-audio ./assets/prompt_audio.wav \
--reference-text "吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。" \
@@ -92,15 +94,16 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Running benchmark client grpc"
num_task=4
# set mode=streaming, when decoupled=True
# set mode=offline, when decoupled=False
mode=offline
num_task=1
mode=streaming
BLS_INSTANCE_NUM=4
python3 client_grpc.py \
--server-addr localhost \
--model-name cosyvoice2 \
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
fi
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
fi