Merge branch 'main' into dev/lyuxiang.lx

This commit is contained in:
lyuxiang.lx
2025-12-04 18:00:17 +08:00
35 changed files with 3584 additions and 311 deletions

View File

@@ -180,7 +180,7 @@ Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torc
``` sh
conda create -n cosyvoice_vllm --clone cosyvoice
conda activate cosyvoice_vllm
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
pip install vllm==v0.9.0 transformers==4.51.3 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
python vllm_example.py
```
@@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```
#### Using Nvidia TensorRT-LLM for deployment
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
To quick start:
``` sh
cd runtime/triton_trtllm
docker compose up -d
```
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
## Discussion & Communication
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).

View File

@@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
RUN conda activate ${VENV} && cd CosyVoice && \
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
WORKDIR /workspace/CosyVoice

View File

@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
```bash
bash run.sh 0 0
```
Create two JSONL files—`train.jsonl` and `test.jsonl`.
Create two JSONL files—`train.jsonl` and `test.jsonl`.
The script will then generate two Parquet files:
```
@@ -111,7 +111,7 @@ bash run.sh 5 5
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
## Results

View File

@@ -53,7 +53,7 @@ except RuntimeError:
pass
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
def audio_decode_cosyvoice2(

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

View File

@@ -33,7 +33,7 @@ fi
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
python3 pretrained_to_huggingface.py \
--pretrained-cosyvoice2-path $model_scope_model_path \
--save-path $sft_model_path
@@ -61,7 +61,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: start token2wav asr server for reward function"
python3 token2wav_asr_server.py --number-of-devices 8
fi
fi
exp_name=official_llm_aishell3_grpo
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
--backend fsdp \
--local_dir $llm_path/actor \
--target_dir $llm_path/merged_hf_model || exit 1
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Test the model"

View File

@@ -1,5 +1,3 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
@@ -195,7 +193,7 @@ def write_error_stats(
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
for _cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:

View File

@@ -295,7 +295,7 @@ def main():
metrics_port=8002,
)
device_ids = [i for i in range(args.number_of_devices)]
device_ids = list(range(args.number_of_devices))
device_ids = device_ids * args.number_of_instances_per_device
with Triton(config=triton_config) as triton:

View File

@@ -34,7 +34,7 @@ tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
torch==2.3.1
torchaudio==2.3.1
transformers==4.40.1
transformers==4.51.3
uvicorn==0.30.0
wetext==0.0.4
wget==3.2

View File

@@ -9,5 +9,5 @@ RUN apt-get -y install git unzip git-lfs g++
RUN git lfs install
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto

View File

@@ -0,0 +1,141 @@
## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
Contributed by Yuekai Zhang (NVIDIA).
This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose -f docker-compose.dit.yml up
```
### Build the Docker Image
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
### Run a Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
```
### Understanding `run_stepaudio2_dit_token2wav.sh`
The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
You can run a subset of stages with:
```sh
bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
```
- `<start_stage>`: The stage to start from.
- `<stop_stage>`: The stage to stop after.
**Stages:**
- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM.
- **Stage 4**: Runs the gRPC benchmark client for performance testing.
- **Stage 5**: Runs the offline TTS inference benchmark test.
- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
### Export Models and Launch Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# This command runs stages 0, 1, 2, and 3
bash run_stepaudio2_dit_token2wav.sh 0 3
```
### Benchmark with client-server mode
To benchmark the running Triton server, run stage 4:
```sh
bash run_stepaudio2_dit_token2wav.sh 4 4
# You can customize parameters such as the number of tasks inside the script.
```
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
#### Total Request Latency
| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 |
| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 |
| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 |
| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 |
| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 |
| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 |
#### First Chunk Latency
| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 |
| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 |
| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 |
| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 |
| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 |
| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 |
### Benchmark with offline inference mode
For offline inference mode benchmark, please run stage 5:
```sh
bash run_stepaudio2_dit_token2wav.sh 5 5
```
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|-----------------------|--|
| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 |
### Disaggregated Server
When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|
| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
The `concurrent_task_per_gpu` is calculated as:
`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
### Acknowledgements
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@@ -1,15 +1,17 @@
## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
Thanks to the contribution from NVIDIA Yuekai Zhang.
Contributed by Yuekai Zhang (NVIDIA).
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose up
```
### Build the Docker Image
Build the image from scratch:
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
@@ -21,71 +23,124 @@ docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_di
```
### Understanding `run.sh`
The `run.sh` script orchestrates the entire workflow through numbered stages.
Run a subset of stages with:
You can run a subset of stages with:
```sh
bash run.sh <start_stage> <stop_stage> [service_type]
```
- `<start_stage>` stage to start from (0-5).
- `<stop_stage>` stage to stop after (0-5).
- `<start_stage>`: The stage to start from (0-5).
- `<stop_stage>`: The stage to stop after (0-5).
Stages:
- **Stage 0** Download the cosyvoice-2 0.5B model from HuggingFace.
- **Stage 1** Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
- **Stage 2** Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later).
- **Stage 3** Launch the Triton Inference Server.
- **Stage 4** Run the single-utterance HTTP client.
- **Stage 5** Run the gRPC benchmark client.
**Stages:**
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
- **Stage 3**: Launches the Triton Inference Server.
- **Stage 4**: Runs the single-utterance HTTP client for testing.
- **Stage 5**: Runs the gRPC benchmark client.
- **Stage 6**: Runs the offline inference benchmark test.
### Export Models and Launch Server
### Export Models to TensorRT-LLM and Launch the Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# Runs stages 0, 1, 2, and 3
# This command runs stages 0, 1, 2, and 3
bash run.sh 0 3
```
*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
> [!TIP]
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
### Single-Utterance HTTP Client
Send a single HTTP inference request:
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
```sh
bash run.sh 4 4
```
### Benchmark with a Dataset
Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.
```sh
bash run.sh 5 5
### Benchmark with client-server mode
# You can also customise parameters such as num_task and dataset split directly:
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
```sh
bash run.sh 5 5 # [streaming|offline]
# You can also customize parameters such as the number of tasks and the dataset split:
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
```
> [!TIP]
> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Tritons decoupled mode so that responses are returned as soon as they are ready.
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
### Benchmark Results
Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio):
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|------|------|-------------|------------------|------------------|-----|
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 |
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 |
### OpenAI-Compatible Server
To launch an OpenAI-compatible service, run:
### Benchmark with offline inference mode
For offline inference mode benchmark, please check the below command:
```sh
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
pip install -r requirements.txt
# After the Triton service is up, start the FastAPI bridge:
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
# Test with curl
bash test/test_cosyvoice.sh
# install FlashCosyVoice for token2wav batching
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
# cd /workspace/FlashCosyVoice
# pip install -e .
# cd -
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
bash run.sh 6 6
# You can also switch to huggingface backend by setting backend=hf
```
### Acknowledgements
This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details.
### Benchmark Results
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
**Client-Server Mode: Offline TTS (Full Sentence Latency)**
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|---|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM**
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|-----------------------|--|
| HF | 1 | 39.26 | 44.31 | 0.2494 |
| HF | 2 | 30.54 | 35.62 | 0.2064 |
| HF | 4 | 18.63 | 23.90 | 0.1421 |
| HF | 8 | 11.22 | 16.45 | 0.0947 |
| HF | 16 | 8.42 | 13.78 | 0.0821 |
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
### OpenAI-Compatible Server
To launch an OpenAI-compatible API service, run the following commands:
```sh
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
cd Triton-OpenAI-Speech
pip install -r requirements.txt
# After the Triton service is running, start the FastAPI bridge:
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
# Test the service with curl:
bash test/test_cosyvoice.sh
```
> [!NOTE]
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
### Acknowledgements
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@@ -43,9 +43,9 @@ python3 client_grpc.py \
import argparse
import asyncio
import json
import queue # Added
import uuid # Added
import functools # Added
import queue
import uuid
import functools
import os
import time
@@ -55,16 +55,16 @@ from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
import tritonclient.grpc as grpcclient_sync # Added sync client import
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
import tritonclient.grpc.aio as grpcclient_aio
import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
# --- Added UserData and callback ---
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
self._first_chunk_time = None
self._second_chunk_time = None
self._start_time = None
def record_start_time(self):
@@ -75,39 +75,43 @@ class UserData:
return self._first_chunk_time - self._start_time
return None
def get_second_chunk_latency(self):
if self._first_chunk_time and self._second_chunk_time:
return self._second_chunk_time - self._first_chunk_time
return None
def callback(user_data, result, error):
if user_data._first_chunk_time is None and not error:
user_data._first_chunk_time = time.time() # Record time of first successful chunk
if not error:
if user_data._first_chunk_time is None:
user_data._first_chunk_time = time.time()
elif user_data._second_chunk_time is None:
user_data._second_chunk_time = time.time()
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)
# --- End Added UserData and callback ---
def stream_callback(user_data_map, result, error):
request_id = None
if error:
print(f"An error occurred in the stream callback: {error}")
else:
request_id = result.get_response().id
if request_id:
user_data = user_data_map.get(request_id)
if user_data:
callback(user_data, result, error)
else:
print(f"Warning: Could not find user_data for request_id {request_id}")
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
summary_f.write(
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
)
summary_f.write("To learn more about the log, please refer to: \n")
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
summary_f.write(
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
)
summary_f.write(
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
)
summary_f.write(
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
)
summary_f.write(
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
)
for model_state in model_stats:
if "last_inference" not in model_state:
continue
@@ -118,7 +122,10 @@ def write_triton_stats(stats, summary_file):
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
f"queue time {total_queue_time_s:<5.2f} s, "
f"compute infer time {total_infer_time_s:<5.2f} s, "
f"compute input time {total_input_time_s:<5.2f} s, "
f"compute output time {total_output_time_s:<5.2f} s \n"
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
@@ -127,21 +134,86 @@ def write_triton_stats(stats, summary_file):
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
if batch_count == 0:
continue
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
f"{compute_infer_time_ms / batch_count:.2f} ms, "
f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
)
def subtract_stats(stats_after, stats_before):
"""Subtracts two Triton inference statistics objects."""
stats_diff = json.loads(json.dumps(stats_after))
model_stats_before_map = {
s["name"]: {
"version": s["version"],
"last_inference": s.get("last_inference", 0),
"inference_count": s.get("inference_count", 0),
"execution_count": s.get("execution_count", 0),
"inference_stats": s.get("inference_stats", {}),
"batch_stats": s.get("batch_stats", []),
}
for s in stats_before["model_stats"]
}
for model_stat_after in stats_diff["model_stats"]:
model_name = model_stat_after["name"]
if model_name in model_stats_before_map:
model_stat_before = model_stats_before_map[model_name]
model_stat_after["inference_count"] = str(
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
)
model_stat_after["execution_count"] = str(
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
)
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
if "ns" in model_stat_after["inference_stats"][key]:
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
if "count" in model_stat_after["inference_stats"][key]:
count_after = int(model_stat_after["inference_stats"][key]["count"])
count_before = int(model_stat_before["inference_stats"][key]["count"])
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
for batch_stat_after in model_stat_after["batch_stats"]:
bs = batch_stat_after["batch_size"]
if bs in batch_stats_before_map:
batch_stat_before = batch_stats_before_map[bs]
for key in ["compute_input", "compute_infer", "compute_output"]:
if key in batch_stat_after and key in batch_stat_before:
count_after = int(batch_stat_after[key]["count"])
count_before = int(batch_stat_before[key]["count"])
batch_stat_after[key]["count"] = str(count_after - count_before)
ns_after = int(batch_stat_after[key]["ns"])
ns_before = int(batch_stat_before[key]["ns"])
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
return stats_diff
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -209,7 +281,8 @@ def get_args():
choices=[
"f5_tts",
"spark_tts",
"cosyvoice2"],
"cosyvoice2",
"cosyvoice2_dit"],
help="triton model_repo module name to request",
)
@@ -243,7 +316,6 @@ def get_args():
help="log directory",
)
# --- Added arguments ---
parser.add_argument(
"--mode",
type=str,
@@ -257,7 +329,13 @@ def get_args():
default=0.1,
help="Chunk overlap duration for streaming reconstruction (in seconds)."
)
# --- End Added arguments ---
parser.add_argument(
"--use-spk2info-cache",
type=str,
default="False",
help="Use spk2info cache for reference audio.",
)
return parser.parse_args()
@@ -278,38 +356,33 @@ def load_audio(wav_path, target_sample_rate=16000):
def prepare_request_input_output(
protocol_client, # Can be grpcclient_aio or grpcclient_sync
protocol_client,
waveform,
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None # Optional padding for offline mode
padding_duration: int = None,
use_spk2info_cache: bool = False
):
"""Prepares inputs for Triton inference (offline or streaming)."""
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
# Apply padding only if padding_duration is provided (for offline)
if padding_duration:
duration = len(waveform) / sample_rate
# Estimate target duration based on text length ratio (crude estimation)
# Avoid division by zero if reference_text is empty
if reference_text:
estimated_target_duration = duration / len(reference_text) * len(target_text)
else:
estimated_target_duration = duration # Assume target duration similar to reference if no text
estimated_target_duration = duration
# Calculate required samples based on estimated total duration
required_total_samples = padding_duration * sample_rate * (
(int(estimated_target_duration + duration) // padding_duration) + 1
)
samples = np.zeros((1, required_total_samples), dtype=np.float32)
samples[0, : len(waveform)] = waveform
else:
# No padding for streaming or if padding_duration is None
samples = waveform.reshape(1, -1).astype(np.float32)
# Common input creation logic
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput(
@@ -330,7 +403,8 @@ def prepare_request_input_output(
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
if use_spk2info_cache:
inputs = inputs[-1:]
return inputs, outputs
@@ -347,12 +421,8 @@ def run_sync_streaming_inference(
):
"""Helper function to run the blocking sync streaming call."""
start_time_total = time.time()
user_data.record_start_time() # Record start time for first chunk latency calculation
user_data.record_start_time()
# Establish stream
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
# Send request
sync_triton_client.async_stream_infer(
model_name,
inputs,
@@ -361,84 +431,76 @@ def run_sync_streaming_inference(
enable_empty_final_response=True,
)
# Process results
audios = []
while True:
try:
result = user_data._completed_requests.get() # Add timeout
result = user_data._completed_requests.get(timeout=200)
if isinstance(result, InferenceServerException):
print(f"Received InferenceServerException: {result}")
sync_triton_client.stop_stream()
return None, None, None # Indicate error
# Get response metadata
return None, None, None, None
response = result.get_response()
final = response.parameters["triton_final_response"].bool_param
if final is True:
break
audio_chunk = result.as_numpy("waveform").reshape(-1)
if audio_chunk.size > 0: # Only append non-empty chunks
if audio_chunk.size > 0:
audios.append(audio_chunk)
else:
print("Warning: received empty audio chunk.")
except queue.Empty:
print(f"Timeout waiting for response for request id {request_id}")
sync_triton_client.stop_stream()
return None, None, None # Indicate error
return None, None, None, None
sync_triton_client.stop_stream()
end_time_total = time.time()
total_request_latency = end_time_total - start_time_total
first_chunk_latency = user_data.get_first_chunk_latency()
second_chunk_latency = user_data.get_second_chunk_latency()
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
actual_duration = 0
if audios:
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
reconstructed_audio = None
if model_name == "spark_tts":
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
reconstructed_audio = None
# Simplified reconstruction based on client_grpc_streaming.py
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
elif len(audios) == 1:
reconstructed_audio = audios[0]
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32)
elif len(audios) == 1:
reconstructed_audio = audios[0]
else:
reconstructed_audio = audios[0][:-cross_fade_samples]
for i in range(1, len(audios)):
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
if reconstructed_audio is not None and reconstructed_audio.size > 0:
actual_duration = len(reconstructed_audio) / save_sample_rate
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0
else:
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
for i in range(1, len(audios)):
# Cross-fade section
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
# Middle section of the current chunk
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
# Concatenate
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
# Add the last part of the final chunk
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
if reconstructed_audio is not None and reconstructed_audio.size > 0:
reconstructed_audio = np.concatenate(audios)
actual_duration = len(reconstructed_audio) / save_sample_rate
# Save reconstructed audio
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0 # Set duration to 0 if no audio
else:
print("Warning: No audio chunks received.")
actual_duration = 0
return total_request_latency, first_chunk_latency, actual_duration
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
async def send_streaming(
manifest_item_list: list,
name: str,
server_url: str, # Changed from sync_triton_client
server_url: str,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
@@ -446,15 +508,18 @@ async def send_streaming(
save_sample_rate: int = 16000,
chunk_overlap_duration: float = 0.1,
padding_duration: int = None,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
sync_triton_client = None # Initialize client variable
sync_triton_client = None
user_data_map = {}
try: # Wrap in try...finally to ensure client closing
try:
print(f"{name}: Initializing sync client for streaming...")
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
for i, item in enumerate(manifest_item_list):
@@ -471,14 +536,16 @@ async def send_streaming(
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
request_id = str(uuid.uuid4())
user_data = UserData()
user_data_map[request_id] = user_data
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
run_sync_streaming_inference,
sync_triton_client,
model_name,
@@ -492,12 +559,18 @@ async def send_streaming(
)
if total_request_latency is not None:
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
print(
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
)
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
total_duration += actual_duration
else:
print(f"{name}: Item {i} failed.")
del user_data_map[request_id]
except FileNotFoundError:
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
except Exception as e:
@@ -505,10 +578,11 @@ async def send_streaming(
import traceback
traceback.print_exc()
finally: # Ensure client is closed
finally:
if sync_triton_client:
try:
print(f"{name}: Closing sync client...")
print(f"{name}: Closing stream and sync client...")
sync_triton_client.stop_stream()
sync_triton_client.close()
except Exception as e:
print(f"{name}: Error closing sync client: {e}")
@@ -527,12 +601,12 @@ async def send(
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
print(f"manifest_item_list: {manifest_item_list}")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
@@ -545,7 +619,8 @@ async def send(
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
sequence_id = 100000000 + i + task_id * 10
start = time.time()
@@ -572,7 +647,6 @@ def load_manifests(manifest_path):
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
@@ -613,23 +687,17 @@ async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
# --- Client Initialization based on mode ---
triton_client = None
protocol_client = None
if args.mode == "offline":
print("Initializing gRPC client for offline mode...")
# Use the async client for offline tasks
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient_aio
elif args.mode == "streaming":
print("Initializing gRPC client for streaming mode...")
# Use the sync client for streaming tasks, handled via asyncio.to_thread
# We will create one sync client instance PER TASK inside send_streaming.
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
protocol_client = grpcclient_sync # protocol client for input prep
protocol_client = grpcclient_sync
else:
raise ValueError(f"Invalid mode: {args.mode}")
# --- End Client Initialization ---
if args.reference_audio:
args.num_tasks = 1
@@ -663,14 +731,24 @@ async def main():
else:
manifest_item_list = load_manifests(args.manifest_path)
stats_client = None
stats_before = None
try:
print("Initializing temporary async client for fetching stats...")
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
print("Fetching inference statistics before running tasks...")
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
except Exception as e:
print(f"Could not retrieve statistics before running tasks: {e}")
num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
tasks = []
start_time = time.time()
for i in range(num_tasks):
# --- Task Creation based on mode ---
if args.mode == "offline":
task = asyncio.create_task(
send(
@@ -683,6 +761,7 @@ async def main():
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
use_spk2info_cache=args.use_spk2info_cache,
)
)
elif args.mode == "streaming":
@@ -690,7 +769,7 @@ async def main():
send_streaming(
manifest_item_list[i],
name=f"task-{i}",
server_url=url, # Pass URL instead of client
server_url=url,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
@@ -698,9 +777,9 @@ async def main():
padding_duration=10,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
chunk_overlap_duration=args.chunk_overlap_duration,
use_spk2info_cache=args.use_spk2info_cache,
)
)
# --- End Task Creation ---
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
@@ -713,7 +792,7 @@ async def main():
for ans in ans_list:
if ans:
total_duration += ans[0]
latency_data.extend(ans[1]) # Use extend for list of lists
latency_data.extend(ans[1])
else:
print("Warning: A task returned None, possibly due to an error.")
@@ -729,10 +808,8 @@ async def main():
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
# --- Statistics Reporting based on mode ---
if latency_data:
if args.mode == "offline":
# Original offline latency calculation
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
if latency_list:
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
@@ -747,9 +824,9 @@ async def main():
s += "No latency data collected for offline mode.\n"
elif args.mode == "streaming":
# Calculate stats for total request latency and first chunk latency
total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
s += "\n--- Total Request Latency ---\n"
if total_latency_list:
@@ -776,9 +853,21 @@ async def main():
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
else:
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
s += "\n--- Second Chunk Latency ---\n"
if second_chunk_latency_list:
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
else:
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
else:
s += "No latency data collected.\n"
# --- End Statistics Reporting ---
print(s)
if args.manifest_path:
@@ -788,26 +877,27 @@ async def main():
elif args.reference_audio:
name = Path(args.reference_audio).stem
else:
name = "results" # Default name if no manifest/split/audio provided
name = "results"
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
# --- Statistics Fetching using temporary Async Client ---
# Use a separate async client for fetching stats regardless of mode
stats_client = None
try:
print("Initializing temporary async client for fetching stats...")
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
print("Fetching inference statistics...")
# Fetching for all models, filtering might be needed depending on server setup
stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
print("Fetching model config...")
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
if stats_client and stats_before:
print("Fetching inference statistics after running tasks...")
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
print("Calculating statistics difference...")
stats = subtract_stats(stats_after, stats_before)
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
print("Fetching model config...")
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
else:
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
except Exception as e:
print(f"Could not retrieve statistics or config: {e}")
@@ -818,11 +908,9 @@ async def main():
await stats_client.close()
except Exception as e:
print(f"Error closing async stats client: {e}")
# --- End Statistics Fetching ---
if __name__ == "__main__":
# asyncio.run(main()) # Use TaskGroup for better exception handling if needed
async def run_main():
try:
await main()

View File

@@ -25,7 +25,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import json
import numpy as np
import argparse

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-cosyvoice:25.06
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"

View File

@@ -32,7 +32,7 @@ import triton_python_backend_utils as pb_utils
import os
import numpy as np
import s3tokenizer
torch.set_num_threads(1)
ORIGINAL_VOCAB_SIZE = 151663

View File

@@ -20,7 +20,7 @@ dynamic_batching {
}
parameters [
{
key: "model_dir",
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]

View File

@@ -25,23 +25,24 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
from typing import Dict, List, Tuple, Optional, Union
import threading
import time
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import onnxruntime
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class TritonPythonModel:
"""Triton Python model for Spark TTS.
@@ -62,6 +63,8 @@ class TritonPythonModel:
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
@@ -72,11 +75,15 @@ class TritonPythonModel:
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
def forward_llm(self, input_ids):
"""
@@ -105,7 +112,7 @@ class TritonPythonModel:
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 1024
max_tokens = 750
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
@@ -114,6 +121,8 @@ class TritonPythonModel:
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
"random_seed": np.array([[42]], dtype=np.uint64),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
@@ -188,12 +197,40 @@ class TritonPythonModel:
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
def forward_token2wav(
self,
prompt_speech_tokens: torch.Tensor,
prompt_speech_feat: torch.Tensor,
prompt_spk_embedding: torch.Tensor,
target_speech_tokens: torch.Tensor) -> torch.Tensor:
target_speech_tokens: torch.Tensor,
request_id: str,
prompt_speech_tokens: torch.Tensor = None,
prompt_speech_feat: torch.Tensor = None,
prompt_spk_embedding: torch.Tensor = None,
token_offset: int = None,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
@@ -205,16 +242,30 @@ class TritonPythonModel:
Returns:
Generated waveform tensor
"""
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
inputs_tensor = [target_speech_tokens_tensor]
if token_offset is not None:
assert finalize is not None
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor.append(token_offset_tensor)
inputs_tensor.append(finalize_tensor)
if prompt_spk_embedding is not None:
assert prompt_speech_feat is not None
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
requested_output_names=['waveform'],
inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
inputs=inputs_tensor,
request_id=request_id,
)
inference_response = inference_request.exec()
@@ -235,17 +286,6 @@ class TritonPythonModel:
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
return input_ids
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device).half()
return embedding
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
@@ -263,6 +303,14 @@ class TritonPythonModel:
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
for generated_ids in generated_ids_iter:
generated_ids = generated_ids.tolist()
if len(generated_ids) == 0:
break
semantic_token_ids_arr.extend(generated_ids)
llm_is_done_flag[0] = True
def execute(self, requests):
"""Execute inference on the batched requests.
@@ -275,25 +323,33 @@ class TritonPythonModel:
responses = []
for request in requests:
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
else:
# using pre-cached reference text
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
@@ -310,22 +366,73 @@ class TritonPythonModel:
if self.decoupled:
response_sender = request.get_response_sender()
request_id = request.request_id()
generated_ids = []
for generated_id in generated_ids_iter:
# convert the numpy array into a int32 tensor
generated_id = generated_id.tolist()
if len(generated_id) > 0:
assert len(generated_id) == 1, "Generated ID is not a single integer"
generated_ids.append(generated_id[0])
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
semantic_token_ids_arr = []
llm_is_done_flag = [False]
llm_thread = threading.Thread(
target=self._llm_gen_thread,
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
)
llm_thread.start()
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if llm_is_done_flag[0]:
break
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
this_tts_speech_token, request_id, prompt_speech_tokens,
prompt_speech_feat, prompt_spk_embedding, token_offset, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
time.sleep(0.02)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
llm_thread.join()
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
@@ -334,8 +441,7 @@ class TritonPythonModel:
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

View File

@@ -23,11 +23,11 @@ model_transaction_policy {
}
parameters [
{
key: "llm_tokenizer_dir",
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
@@ -37,16 +37,19 @@ input [
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"

View File

@@ -0,0 +1,394 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
self.http_client = httpx.AsyncClient()
self.api_base = "http://localhost:8000/v1/chat/completions"
self.speaker_cache = {}
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
buffer = ""
async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
async def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
index: Index of the request
target_speech_tokens: Target speech tokens tensor
request_id: Request ID
reference_wav: Reference waveform tensor
reference_wav_len: Reference waveform length tensor
finalize: Whether to finalize the request
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=[
"waveform",
],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index + 1},
)
inference_response = await inference_request.async_exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
async def _process_request(self, request):
request_id = request.request_id()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
if reference_text not in self.speaker_cache:
self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
prompt_speech_tokens = self.speaker_cache[reference_text]
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "equal":
this_token_hop_len = self.token_hop_len
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
else:
raise NotImplementedError("Offline TTS mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None
def finalize(self):
self.logger.log_info("Finalizing CosyVoice DIT model")
if hasattr(self, "http_client"):
asyncio.run(self.http_client.aclose())

View File

@@ -0,0 +1,73 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
parameters [
{
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: ${bls_instance_num}
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,153 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.utils.common import TrtContextWrapper
import onnxruntime
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_dir = model_params["model_dir"]
gpu = "l20"
enable_trt = True
if enable_trt:
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
else:
campplus_model = f'{model_dir}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
spk_feat = feat - feat.mean(dim=0, keepdim=True)
if isinstance(self.spk_model, onnxruntime.InferenceSession):
embedding = self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
embedding = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
embedding.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return embedding.half()
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
responses = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_array = torch.from_numpy(wav_array).to(self.device)
embedding = self._extract_spk_embedding(wav_array)
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
"prompt_spk_embedding", to_dlpack(embedding))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_spk_embedding_tensor])
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "speaker_embedding"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
}
]
output [
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -28,26 +28,30 @@ import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
self.model_dir = model_dir
self.fp16 = fp16
@@ -57,7 +61,7 @@ class CosyVoice2:
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
@@ -73,14 +77,22 @@ class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.flow.half()
# streaming tts config
self.token_hop_len = 25
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
self.speech_window = np.hamming(2 * self.source_cache_len)
self.hift_cache_dict = defaultdict(lambda: None)
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
@@ -111,6 +123,42 @@ class CosyVoice2Model:
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
class TritonPythonModel:
"""Triton Python model for vocoder.
@@ -131,13 +179,19 @@ class TritonPythonModel:
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)
spk_info_path = os.path.join(model_dir, "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_dir}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
@@ -153,38 +207,66 @@ class TritonPythonModel:
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
if prompt_speech_tokens_tensor is not None:
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
else:
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
# shift the speech tokens according to the original vocab size
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
if token_offset is not None:
token_offset = token_offset.as_numpy().item()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
if not finalize:
stream = True
else:
stream = False
request_id = request.request_id()
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
embedding=prompt_spk_embedding,
token_offset=token_offset,
uuid=request_id,
stream=stream,
finalize=finalize)
if finalize:
self.token2wav_model.model.hift_cache_dict.pop(request_id)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
else:
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
generated_wave = audio_hat.squeeze(0).cpu().numpy()

View File

@@ -20,7 +20,7 @@ dynamic_batching {
}
parameters [
{
key: "model_dir",
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
@@ -35,16 +35,33 @@ input [
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
optional: true
},
{
name: "prompt_speech_feat"
data_type: TYPE_FP16
dims: [-1, 80]
optional: true
},
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
optional: true
},
{
name: "token_offset"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "finalize"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
}
]
output [

View File

@@ -0,0 +1,142 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
from .token2wav_dit import CosyVoice2_Token2Wav
import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
Tensors with the same elements and properties will have the same ID.
"""
# Convert tensor to a byte string
tensor_bytes = tensor.numpy().tobytes()
# Create a SHA-256 hash of the byte string
hasher = hashlib.sha256()
hasher.update(tensor_bytes)
return hasher.hexdigest()
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
# FIXME: device id settings
self.token2wav_model = CosyVoice2_Token2Wav(
model_dir, enable_trt=True, streaming=True
)
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens.squeeze().tolist()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
request_id = request.request_id()
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array)
wav = wav_array[:, :wav_len].squeeze(0)
spk_id = get_spk_id_from_prompt_audio(wav)
audio_hat = self.token2wav_model.forward_streaming(
target_speech_tokens, finalize, request_id=request_id,
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
)
outputs = []
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
outputs.append(wav_tensor)
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,510 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 token2wav.py --enable-trt || exit 1
"""
import torch
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time
import numpy as np
from hyperpyyaml import load_hyperpyyaml
def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel = fade_in_mel.clone()
fade_in_mel[..., :mel_overlap_len] = \
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if dtype == torch.float16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes
for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
if dtype == torch.float16:
tensor_dtype = trt.DataType.HALF
elif dtype == torch.bfloat16:
tensor_dtype = trt.DataType.BF16
elif dtype == torch.float32:
tensor_dtype = trt.DataType.FLOAT
else:
raise ValueError('invalid dtype {}'.format(dtype))
# set input and output data type
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_tensor.dtype = tensor_dtype
for i in range(network.num_outputs):
output_tensor = network.get_output(i)
output_tensor.dtype = tensor_dtype
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
# save trt engine
with open(trt_model, "wb") as f:
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
self.device = device
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()
self.device_id = device_id
self.device = f"cuda:{device_id}"
with open(f"{model_dir}/flow.yaml", "r") as f:
configs = load_hyperpyyaml(f)
self.flow = configs['flow']
self.dtype = dtype
self.flow.to(self.dtype)
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
self.flow.to(self.device).eval()
self.hift = HiFTGenerator()
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(
f"{model_dir}/campplus.onnx", sess_options=option,
providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
gpu = "l20"
if enable_trt:
if streaming:
self.load_trt(
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
1,
self.dtype, streaming
)
else:
self.load_trt(
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
self.dtype
)
self.load_spk_trt(
f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False
)
self.streaming_flow_cache = {}
self.speaker_cache = {}
self.mel_cache_len = 8 # hard-coded, 160ms
self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave
self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
# hifigan cache for streaming tts
self.hift_cache_dict = {}
def forward_spk_embedding(self, spk_feat):
if isinstance(self.spk_model, onnxruntime.InferenceSession):
return self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device_id):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
output_tensor.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return output_tensor.cpu().numpy().flatten().tolist()
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
opt_batch_size = 2
max_batch_size = 16
if streaming:
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
opt_shape = [
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
(opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
(16, opt_batch_size * 2, 8, 100, 128)
]
max_shape = [
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
(max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
(16, max_batch_size * 2, 8, 1000, 128)
]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
opt_shape = [
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
]
max_shape = [
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
(max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
prompt_speech_mels_list.append(log_mel)
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
)
for i in range(len(prompt_speech_tokens)):
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
prompt_speech_tokens_list.append(speech_tokens_i)
return prompt_speech_tokens_list
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
spk_emb_for_flow = []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
spk_emb = self.forward_spk_embedding(spk_feat)
spk_emb_for_flow.append(spk_emb)
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
if self.dtype != torch.float32:
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
return spk_emb_for_flow
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
prompt_mels_for_flow = []
prompt_mels_lens_for_flow = []
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
assert len(audio.shape) == 1
audio = audio.unsqueeze(0)
if sample_rate != 24000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
prompt_mels_for_flow, batch_first=True, padding_value=0
) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
generated_speech_tokens_list: list[list[int]],
prompt_mels_for_flow: torch.Tensor,
prompt_mels_lens_for_flow: torch.Tensor,
spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
flow_inputs_lens = torch.tensor(flow_inputs_lens)
with torch.amp.autocast(self.device, dtype=torch.float16):
generated_mels, generated_mels_lens = self.flow.inference(
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
)
return generated_mels, generated_mels_lens
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
batch_size = generated_mels.shape[0]
generated_wavs = []
for i in range(batch_size):
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
wav, _ = self.hift(speech_feat=mel)
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
generated_mels, generated_mels_lens = self.forward_flow(
prompt_speech_tokens_list, generated_speech_tokens_list,
prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
def prepare_prompt_audio(
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
def get_prompt_audio_cache_for_streaming_tts(
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):
assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
cache = self.flow.setup_cache(
prompt_speech_tokens_tensor.to(self.device),
prompt_mels_for_flow.to(self.device),
spk_emb_for_flow.to(self.device),
n_timesteps=10
)
new_cache = {k: v.clone() for k, v in cache.items()}
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache
@torch.inference_mode()
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
):
if speaker_id not in self.speaker_cache:
assert prompt_audio is not None, "prompt_audio is required for new speaker"
assert prompt_audio_sample_rate == 16000
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
mel=torch.zeros(1, 80, 0, device='cuda'),
source=torch.zeros(1, 1, 0, device='cuda'),
speech=torch.zeros(1, 0, device='cuda'),
)
current_request_cache = self.streaming_flow_cache[request_id]
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
cache=current_request_cache,
last_chunk=last_chunk,
n_timesteps=10,
)
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
speech, source = self.hift(mel, hift_cache_source)
# overlap speech smooth
if hift_cache_speech.shape[-1] > 0:
speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
# update vocoder cache
self.hift_cache_dict[request_id] = dict(
mel=mel[..., -self.mel_cache_len:].clone().detach(),
source=source[:, :, -self.source_cache_len:].clone().detach(),
speech=speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
if last_chunk:
assert request_id in self.streaming_flow_cache
self.streaming_flow_cache.pop(request_id)
self.hift_cache_dict.pop(request_id)
return speech
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for item in batch:
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for _ in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "token2wav_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
priority_levels: 10
default_priority_level: 10
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "target_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
},
{
name: "finalize"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,652 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
import time
import requests
import asyncio
import httpx
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
async def send_request_async(client, url, payload):
response = await client.post(url, json=payload, timeout=None)
response.raise_for_status()
response_json = response.json()
return response_json['choices'][0]['message']['content']
async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
async with httpx.AsyncClient() as client:
tasks = []
for chat in chats:
payload = {
"model": model_name,
"messages": chat,
"max_tokens": 2048,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": False,
}
tasks.append(send_request_async(client, api_base, payload))
return await asyncio.gather(*tasks)
def extract_speech_ids(speech_tokens_str):
"""Extract speech IDs from token strings like <|s_23456|>"""
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
for token in cosy2_tokens:
speech_id_str += f"<|s_{token}|>"
return speech_id_str
def get_args():
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--token2wav-batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=0, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=None, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="LLM model path (includes both model and tokenizer)",
)
parser.add_argument(
"--token2wav-path",
required=True,
type=str,
help="CosyVoice2 token2wav model path",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The prompt text for CosyVoice2",
)
parser.add_argument(
"--prompt-speech-path",
type=str,
default=None,
help="The path to the prompt speech for CosyVoice2",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
parser.add_argument(
"--backend",
type=str,
default="hf",
choices=["hf", "trtllm", "vllm", "trtllm-serve"],
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
)
parser.add_argument(
"--engine-dir",
type=str,
default=None,
help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
)
parser.add_argument(
"--kv-cache-free-gpu-memory-fraction",
type=float,
default=0.6,
help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
)
parser.add_argument(
"--openai-api-base",
type=str,
default="http://localhost:8000/v1/chat/completions",
help="OpenAI API base URL (for trtllm-serve backend)",
)
parser.add_argument(
"--openai-model-name",
type=str,
default="trt_engines_bfloat16",
help="Model name to use with OpenAI API (for trtllm-serve backend)",
)
args = parser.parse_args()
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
collator_start_time = time.time()
total_audio_processing_time = 0
total_speech_tokenization_time = 0
total_text_tokenization_time = 0
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
prompt_text_after_apply_template_list = []
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
chat_list = []
for _, item in enumerate(batch):
audio_processing_start_time = time.time()
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
)
prompt_text_list.append(prompt_text)
full_text = prompt_text + target_text
full_text_list.append(full_text)
# remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
puncts = ['"', '(', ')', '', '', '', '', '', '\'']
for p in puncts:
if p in full_text:
full_text = full_text.replace(p, '')
print(f"removed {p} from {full_text}")
# get prompt audio for CosyVoice2 (convert to 16kHz)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
print(ref_audio_org.shape)
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
prompt_audio_list.append(ref_audio)
audio_processing_end_time = time.time()
total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
speech_tokenization_start_time = time.time()
if "prompt_audio_cosy2_tokens" in item:
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
else:
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
if len(mels) > 0:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
speech_tokenization_end_time = time.time()
total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
text_tokenization_start_time = time.time()
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
{"role": "user", "content": full_text_list[i]},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
chat_list.append(chat)
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids_list.append(input_ids.squeeze(0))
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
text_tokenization_end_time = time.time()
total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
ids = [item["id"] for item in batch]
return {
"input_ids": input_ids_list,
"ids": ids,
"prompt_text": prompt_text_list,
"prompt_audio_list": prompt_audio_list,
"prompt_text_after_apply_template": prompt_text_after_apply_template_list,
"audio_processing_time": total_audio_processing_time,
"speech_tokenization_time": total_speech_tokenization_time,
"text_tokenization_time": total_text_tokenization_time,
"chat_list": chat_list
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main(args):
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
local_rank, world_size, rank = 0, 1, 0
device = torch.device(f"cuda:{local_rank}")
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
if args.backend == "hf":
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
runner = None
elif args.backend == "trtllm":
if args.engine_dir is None:
raise ValueError("--engine-dir is required when backend is 'trtllm'")
runtime_rank = tensorrt_llm.mpi_rank()
model = None
runner_kwargs = dict(
engine_dir=args.engine_dir,
rank=runtime_rank,
max_output_len=2048,
enable_context_fmha_fp32_acc=False,
max_batch_size=args.batch_size,
max_input_len=512,
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
cuda_graph_mode=False,
gather_generation_logits=False,
)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
elif args.backend == "vllm":
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
runner = None
elif args.backend == "trtllm-serve":
model = None
runner = None
else:
raise ValueError(f"Unsupported backend: {args.backend}")
if 'Step-Audio-2-mini' in args.token2wav_path:
from token2wav_dit import CosyVoice2_Token2Wav
else:
assert 'CosyVoice2-0.5B' in args.token2wav_path
from token2wav import CosyVoice2_Token2Wav
token2wav_model = CosyVoice2_Token2Wav(
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
)
if args.prompt_speech_path:
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
else:
prompt_speech_16k = None
s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
)
for _ in range(3):
print(f"Running {_} times")
total_llm_time = 0
total_token2wav_time = 0
total_data_load_time = 0
total_llm_post_processing_time = 0
total_audio_save_time = 0
total_audio_processing_time_in_collator = 0
total_speech_tokenization_time_in_collator = 0
total_text_tokenization_time_in_collator = 0
total_audio_samples = 0
start_time = time.time()
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
last_batch_end_time = time.time()
for batch in dataloader:
data_loaded_time = time.time()
total_data_load_time += data_loaded_time - last_batch_end_time
total_audio_processing_time_in_collator += batch["audio_processing_time"]
total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
with torch.no_grad():
llm_start_time = time.time()
if args.backend == "hf":
input_ids_list = batch["input_ids"]
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
attention_mask = torch.ones_like(input_ids)
else:
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list_new = [
torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list_new)
attention_mask = torch.zeros_like(input_ids)
for i in range(len(input_ids_list)):
attention_mask[i, :len(input_ids_list[i])] = 1
input_ids = input_ids.to(device)
outputs = model.generate(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
max_new_tokens=2048,
do_sample=True,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=1.1,
top_k=args.top_k,
)
torch.cuda.synchronize()
elif args.backend == "trtllm":
batch_input_ids = list(batch["input_ids"])
input_lengths = [x.size(0) for x in batch_input_ids]
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
outputs = runner.generate(
batch_input_ids=batch_input_ids,
max_new_tokens=2048,
end_id=end_id,
pad_id=end_id,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=1.1,
num_return_sequences=1,
streaming=False,
output_sequence_lengths=True,
output_generation_logits=False,
return_dict=True,
return_all_generated_tokens=False
)
torch.cuda.synchronize()
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
num_output_sents, num_beams, _ = output_ids.size()
assert num_beams == 1
beam = 0
batch_size = len(batch["input_ids"])
num_return_sequences = num_output_sents // batch_size
assert num_return_sequences == 1
outputs = []
for i in range(batch_size * num_return_sequences):
batch_idx = i // num_return_sequences
seq_idx = i % num_return_sequences
output_begin = input_lengths[batch_idx]
output_end = sequence_lengths[i][beam]
outputs_i = output_ids[i][beam][:output_end].tolist()
outputs.append(outputs_i)
elif args.backend == "vllm":
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=1.1,
max_tokens=2048,
)
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
print(outputs)
for j, output in enumerate(outputs):
outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
elif args.backend == "trtllm-serve":
if args.batch_size > 1:
outputs = asyncio.run(send_batch_requests_async(
args.openai_api_base,
args.openai_model_name,
batch["chat_list"],
args.temperature,
args.top_p,
args.top_k,
))
else:
outputs = []
for chat in batch["chat_list"]:
payload = {
"model": args.openai_model_name,
"messages": chat,
"max_tokens": 2048,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": False,
}
response = requests.post(args.openai_api_base, json=payload)
response.raise_for_status()
response_json = response.json()
generated_content = response_json['choices'][0]['message']['content']
outputs.append(generated_content)
llm_end_time = time.time()
total_llm_time += (llm_end_time - llm_start_time)
items_for_token_2wav = []
for i in range(len(batch["ids"])):
llm_post_processing_start_time = time.time()
if args.backend == "trtllm-serve":
speech_tokens_str = outputs[i].strip().split('><')
if len(speech_tokens_str) > 1:
speech_tokens_str = [
t if t.startswith('<') else '<' + t for t in speech_tokens_str
]
speech_tokens_str = [
t if t.endswith('>') else t + '>' for t in speech_tokens_str
]
speech_ids = extract_speech_ids(speech_tokens_str)
else:
input_length = len(batch["input_ids"][i])
generated_ids = outputs[i][input_length:]
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
speech_ids = extract_speech_ids(speech_tokens_str)
print(i, speech_ids)
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
llm_post_processing_end_time = time.time()
total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
if current_prompt_audio is not None:
items_for_token_2wav.append({
"speech_ids": speech_ids,
"prompt_audio": current_prompt_audio.squeeze(0),
"id": batch["ids"][i]
})
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
if not t2w_batch:
continue
t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
t2w_ids = [item["id"] for item in t2w_batch]
token2wav_start_time = time.time()
generated_wavs = token2wav_model(
t2w_generated_speech_tokens_list,
t2w_prompt_audios_list,
t2w_prompt_audios_sample_rate,
)
token2wav_end_time = time.time()
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
audio_save_start_time = time.time()
for j, audio_hat in enumerate(generated_wavs):
generated_wave = audio_hat.squeeze().cpu().numpy()
total_audio_samples += len(generated_wave)
target_sample_rate = 24000
utt = t2w_ids[j]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
audio_save_end_time = time.time()
total_audio_save_time += audio_save_end_time - audio_save_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
last_batch_end_time = time.time()
if rank == 0:
progress_bar.close()
end_time = time.time()
target_sample_rate = 24000
total_audio_duration_seconds = total_audio_samples / target_sample_rate
log_file_path = os.path.join(args.output_dir, "log.txt")
with open(log_file_path, 'w') as f:
args_dict = vars(args)
log_data = {
"args": args_dict,
"data_load_time_seconds": total_data_load_time,
"audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
"speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
"text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
"llm_time_seconds": total_llm_time,
"llm_post_processing_time_seconds": total_llm_post_processing_time,
"token2wav_time_seconds": total_token2wav_time,
"audio_save_time_seconds": total_audio_save_time,
"total_audio_duration_seconds": total_audio_duration_seconds,
"pipeline_time_seconds": end_time - start_time,
}
print(log_data)
f.write(json.dumps(log_data, indent=4))
print(f"Metrics logged to {log_file_path}")
if __name__ == "__main__":
args = get_args()
if args.backend == "vllm":
from vllm import LLM, SamplingParams
elif args.backend == "trtllm":
import tensorrt_llm
from tensorrt_llm.runtime import ModelRunnerCpp
elif args.backend == "hf":
from transformers import AutoModelForCausalLM
elif args.backend == "trtllm-serve":
pass
else:
raise ValueError(f"Unsupported backend: {args.backend}")
main(args)

View File

@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2
use_spk2info_cache=False
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "Cloning CosyVoice"
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
@@ -25,8 +27,11 @@ fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading CosyVoice2-0.5B"
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
fi
@@ -57,9 +62,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
cosyvoice2_dir="cosyvoice2"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/tensorrt_llm $model_repo
cp -r ./model_repo/token2wav $model_repo
if [ $use_spk2info_cache == "False" ]; then
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
fi
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
@@ -67,13 +75,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=4
TRITON_MAX_BATCH_SIZE=16
DECOUPLED_MODE=False
DECOUPLED_MODE=True # True for streaming, False for offline
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
if [ $use_spk2info_cache == "False" ]; then
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -82,7 +92,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Single request test http"
echo "Single request test http, only work for offline TTS mode"
python3 client_http.py \
--reference-audio ./assets/prompt_audio.wav \
--reference-text "吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。" \
@@ -93,14 +103,40 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Running benchmark client grpc"
num_task=4
# set mode=streaming, when decoupled=True
# set mode=offline, when decoupled=False
mode=offline
mode=streaming
BLS_INSTANCE_NUM=4
python3 client_grpc.py \
--server-addr localhost \
--model-name cosyvoice2 \
--num-tasks $num_task \
--mode $mode \
--use-spk2info-cache $use_spk2info_cache \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
fi
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "stage 6: Offline inference benchmark"
n_gpus=1
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
backend=trtllm # hf, trtllm, vllm
batch_sizes=(16 8 4 2 1)
token2wav_batch_size=1
for batch_size in ${batch_sizes[@]}; do
for dataset in ${datasets[@]}; do
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
done
done
fi

View File

@@ -0,0 +1,225 @@
#!/bin/bash
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=0
cosyvoice_path=/workspace/CosyVoice
stepaudio2_path=/workspace/Step-Audio2
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
stage=$1
stop_stage=$2
huggingface_model_local_dir=./cosyvoice2_llm
model_scope_model_local_dir=./CosyVoice2-0.5B
step_audio_model_dir=./Step-Audio-2-mini
trt_dtype=bfloat16
trt_weights_dir=./trt_weights_${trt_dtype}
trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2_dit
bls_instance_num=10
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "Cloning Step-Audio2-mini"
git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
echo "Cloning CosyVoice"
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
cd $cosyvoice_path
git submodule update --init --recursive
cd runtime/triton_trtllm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading CosyVoice2-0.5B"
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
echo "Step-Audio2-mini"
huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
cd $step_audio_model_dir/token2wav
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
cd -
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint to TensorRT weights"
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
--output_dir $trt_weights_dir \
--dtype $trt_dtype || exit 1
echo "Building TensorRT engines"
trtllm-build --checkpoint_dir $trt_weights_dir \
--output_dir $trt_engines_dir \
--max_batch_size 64 \
--max_num_tokens 32768 \
--gemm_plugin $trt_dtype || exit 1
echo "Testing TensorRT engines"
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
--tokenizer_dir $huggingface_model_local_dir \
--top_k 50 --top_p 0.95 --temperature 0.8 \
--engine_dir=$trt_engines_dir || exit 1
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Creating model repository async mode"
rm -rf $model_repo
mkdir -p $model_repo
cosyvoice2_dir="cosyvoice2_dit"
token2wav_dir="token2wav_dit"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/${token2wav_dir} $model_repo
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
MODEL_DIR=$model_scope_model_local_dir
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=$bls_instance_num
TRITON_MAX_BATCH_SIZE=1
DECOUPLED_MODE=True # Only streaming TTS mode is supported using Nvidia Triton for now
STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
tritonserver --model-repository $model_repo --http-port 18000 &
wait
# Test using curl
# curl http://localhost:8000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "",
# "messages":[{"role": "user", "content": "Where is New York?"},
# {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
# "max_tokens": 512,
# "temperature": 0.8,
# "top_p": 0.95,
# "top_k": 50,
# "stop": ["<|eos1|>"],
# "repetition_penalty": 1.2,
# "stream": false
# }'
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Running benchmark client"
num_task=4
mode=streaming
BLS_INSTANCE_NUM=$bls_instance_num
python3 client_grpc.py \
--server-addr localhost \
--server-port 8001 \
--model-name cosyvoice2_dit \
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
backend=trtllm # hf, trtllm, vllm, trtllm-serve
batch_sizes=(16)
token2wav_batch_size=1
for batch_size in ${batch_sizes[@]}; do
for dataset in ${datasets[@]}; do
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
CUDA_VISIBLE_DEVICES=1 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $step_audio_model_dir/token2wav \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
done
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
export CUDA_VISIBLE_DEVICES=1
# Note: Using pre-computed cosyvoice2 tokens
python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
# Offline Token2wav inference
python3 token2wav_dit.py --enable-trt
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Disaggregated Server: LLM and Token2wav on different GPUs"
echo "Starting LLM server on GPU 0"
export CUDA_VISIBLE_DEVICES=0
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
echo "Starting Token2wav server on GPUs 1-3"
Token2wav_num_gpus=3
http_port=17000
grpc_port=18000
metrics_port=16000
for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
echo "Starting server on GPU $i"
http_port=$((http_port + 1))
grpc_port=$((grpc_port + 1))
metrics_port=$((metrics_port + 1))
# Two instances of Token2wav server on the same GPU
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
http_port=$((http_port + 1))
grpc_port=$((grpc_port + 1))
metrics_port=$((metrics_port + 1))
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
done
wait
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Running benchmark client for Disaggregated Server"
per_gpu_instances=2
mode=streaming
BLS_INSTANCE_NUM=$bls_instance_num
Token2wav_num_gpus=(1 2 3)
concurrent_tasks=(1 2 3 4 5 6)
for n_gpu in ${Token2wav_num_gpus[@]}; do
echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
for concurrent_task in ${concurrent_tasks[@]}; do
num_instances=$((per_gpu_instances * n_gpu))
for i in $(seq 1 $num_instances); do
port=$(($i + 18000))
python3 client_grpc.py \
--server-addr localhost \
--server-port $port \
--model-name cosyvoice2_dit \
--num-tasks $concurrent_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
done
wait
done
done
fi

View File

@@ -15,11 +15,6 @@
# limitations under the License.
import argparse
import ast
import csv
import os
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch

View File

@@ -0,0 +1,122 @@
import torch
import os
import argparse
from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
import torchaudio
import time
from token2wav_dit import CosyVoice2_Token2Wav
import soundfile as sf
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
prompt_speech_tokens_list, prompt_text_list = [], []
for item in batch:
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
prompt_text_list.append(item['prompt_text'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = args.dataset_name
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
CHUNK_SIZE = 25
token_frame_rate = 25
OVERLAP_SIZE = 0
warmup_times = 3
for _ in range(warmup_times):
start_time = time.time()
total_forward_count = 0
for batch in data_loader:
tts_speech_list = []
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
assert prompt_audio_sample_rate == 16000
prompt_text = prompt_text_list[0]
prompt_speech_tokens = prompt_speech_tokens_list[0]
semantic_token_ids_arr, token_offset = [], 0
flow_prompt_speech_token_len = len(prompt_speech_tokens)
buffer = generated_speech_tokens
output_wavs = []
chunk_index = 0
while True:
if args.strategy == "equal":
this_chunk_size = CHUNK_SIZE
elif args.strategy == "exponential":
this_chunk_size = token_frame_rate * (2 ** chunk_index)
if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
wavs = token2wav_model.forward_streaming(
buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
prompt_audio_sample_rate=prompt_audio_sample_rate
)
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
output_wavs.append(wavs)
total_forward_count += 1
chunk_index += 1
else:
wavs = token2wav_model.forward_streaming(
buffer, True, request_id=id, speaker_id=f"{id}",
prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
)
output_wavs.append(wavs)
total_forward_count += 1
# chunk_index += 1
break
for i, wav in enumerate(output_wavs):
output_wavs[i] = wav.cpu().numpy().squeeze()
audios = output_wavs
reconstructed_audio = np.concatenate(audios)
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
end_time = time.time()
if _ == 0:
token2wav_model.speaker_cache = {}
print(f"Warmup time: {end_time - start_time} seconds")
print("clear speaker cache")
elif _ == 1:
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
else:
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
print(f"Total flow matching forward calls: {total_forward_count}")

View File

@@ -0,0 +1,335 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 token2wav.py --enable-trt || exit 1
"""
import torch
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
import tensorrt as trt
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes
for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
# set input and output data type
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_tensor.dtype = tensor_dtype
for i in range(network.num_outputs):
output_tensor = network.get_output(i)
output_tensor.dtype = tensor_dtype
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
# save trt engine
with open(trt_model, "wb") as f:
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
self.device = device
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
super().__init__()
self.device_id = device_id
self.device = f"cuda:{device_id}"
self.flow = CausalMaskedDiffWithXvec()
self.flow.half()
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
self.flow.to(self.device).eval()
self.hift = HiFTGenerator()
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
gpu = "l20"
if enable_trt:
self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
True)
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
def forward_spk_embedding(self, spk_feat):
if isinstance(self.spk_model, onnxruntime.InferenceSession):
return self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device_id):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
output_tensor.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return output_tensor.cpu().numpy().flatten().tolist()
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
(max_batch_size * 2, 80)]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
prompt_speech_mels_list.append(log_mel)
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
)
for i in range(len(prompt_speech_tokens)):
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
prompt_speech_tokens_list.append(speech_tokens_i)
return prompt_speech_tokens_list
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
spk_emb_for_flow = []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
spk_emb = self.forward_spk_embedding(spk_feat)
spk_emb_for_flow.append(spk_emb)
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
return spk_emb_for_flow
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
prompt_mels_for_flow = []
prompt_mels_lens_for_flow = []
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
assert len(audio.shape) == 1
audio = audio.unsqueeze(0)
if sample_rate != 24000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
flow_inputs_lens = torch.tensor(flow_inputs_lens)
with torch.amp.autocast(self.device, dtype=torch.float16):
generated_mels, generated_mels_lens = self.flow(
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
streaming=False, finalize=True
)
return generated_mels, generated_mels_lens
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
batch_size = generated_mels.shape[0]
generated_wavs = []
for i in range(batch_size):
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
wav, _ = self.hift(speech_feat=mel)
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
# assert all item in prompt_audios_sample_rate is 16000
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
generated_mels, generated_mels_lens = self.forward_flow(
prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for _, item in enumerate(batch):
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
# mkdir output_dir if not exists
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for _ in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -0,0 +1 @@
model_repo/token2wav_dit/1/token2wav_dit.py