mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Merge branch 'main' into dev/lyuxiang.lx
This commit is contained in:
13
README.md
13
README.md
@@ -180,7 +180,7 @@ Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torc
|
|||||||
``` sh
|
``` sh
|
||||||
conda create -n cosyvoice_vllm --clone cosyvoice
|
conda create -n cosyvoice_vllm --clone cosyvoice
|
||||||
conda activate cosyvoice_vllm
|
conda activate cosyvoice_vllm
|
||||||
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
pip install vllm==v0.9.0 transformers==4.51.3 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||||
python vllm_example.py
|
python vllm_example.py
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
|
|||||||
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using Nvidia TensorRT-LLM for deployment
|
||||||
|
|
||||||
|
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
|
||||||
|
To quick start:
|
||||||
|
|
||||||
|
``` sh
|
||||||
|
cd runtime/triton_trtllm
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
|
||||||
|
|
||||||
## Discussion & Communication
|
## Discussion & Communication
|
||||||
|
|
||||||
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
||||||
|
|||||||
@@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
|||||||
|
|
||||||
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
|
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
|
||||||
RUN conda activate ${VENV} && cd CosyVoice && \
|
RUN conda activate ${VENV} && cd CosyVoice && \
|
||||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||||
|
|
||||||
WORKDIR /workspace/CosyVoice
|
WORKDIR /workspace/CosyVoice
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
|
|||||||
```bash
|
```bash
|
||||||
bash run.sh 0 0
|
bash run.sh 0 0
|
||||||
```
|
```
|
||||||
Create two JSONL files—`train.jsonl` and `test.jsonl`.
|
Create two JSONL files—`train.jsonl` and `test.jsonl`.
|
||||||
The script will then generate two Parquet files:
|
The script will then generate two Parquet files:
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -111,7 +111,7 @@ bash run.sh 5 5
|
|||||||
|
|
||||||
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
|
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
|
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
|
||||||
|
|
||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ except RuntimeError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
|
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
def audio_decode_cosyvoice2(
|
def audio_decode_cosyvoice2(
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ fi
|
|||||||
|
|
||||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
|
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
|
||||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
|
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
|
||||||
python3 pretrained_to_huggingface.py \
|
python3 pretrained_to_huggingface.py \
|
||||||
--pretrained-cosyvoice2-path $model_scope_model_path \
|
--pretrained-cosyvoice2-path $model_scope_model_path \
|
||||||
--save-path $sft_model_path
|
--save-path $sft_model_path
|
||||||
@@ -61,7 +61,7 @@ fi
|
|||||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
log "stage 1: start token2wav asr server for reward function"
|
log "stage 1: start token2wav asr server for reward function"
|
||||||
python3 token2wav_asr_server.py --number-of-devices 8
|
python3 token2wav_asr_server.py --number-of-devices 8
|
||||||
fi
|
fi
|
||||||
|
|
||||||
exp_name=official_llm_aishell3_grpo
|
exp_name=official_llm_aishell3_grpo
|
||||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
--backend fsdp \
|
--backend fsdp \
|
||||||
--local_dir $llm_path/actor \
|
--local_dir $llm_path/actor \
|
||||||
--target_dir $llm_path/merged_hf_model || exit 1
|
--target_dir $llm_path/merged_hf_model || exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
log "stage 4: Test the model"
|
log "stage 4: Test the model"
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright (c) 2023 by manyeyes
|
# Copyright (c) 2023 by manyeyes
|
||||||
# Copyright (c) 2023 Xiaomi Corporation
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
@@ -195,7 +193,7 @@ def write_error_stats(
|
|||||||
hyp = list("".join(hyp))
|
hyp = list("".join(hyp))
|
||||||
results[i] = (cut_id, ref, hyp)
|
results[i] = (cut_id, ref, hyp)
|
||||||
|
|
||||||
for cut_id, ref, hyp in results:
|
for _cut_id, ref, hyp in results:
|
||||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||||
for ref_word, hyp_word in ali:
|
for ref_word, hyp_word in ali:
|
||||||
if ref_word == ERR:
|
if ref_word == ERR:
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ def main():
|
|||||||
metrics_port=8002,
|
metrics_port=8002,
|
||||||
)
|
)
|
||||||
|
|
||||||
device_ids = [i for i in range(args.number_of_devices)]
|
device_ids = list(range(args.number_of_devices))
|
||||||
device_ids = device_ids * args.number_of_instances_per_device
|
device_ids = device_ids * args.number_of_instances_per_device
|
||||||
|
|
||||||
with Triton(config=triton_config) as triton:
|
with Triton(config=triton_config) as triton:
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
|
|||||||
tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
|
tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
|
||||||
torch==2.3.1
|
torch==2.3.1
|
||||||
torchaudio==2.3.1
|
torchaudio==2.3.1
|
||||||
transformers==4.40.1
|
transformers==4.51.3
|
||||||
uvicorn==0.30.0
|
uvicorn==0.30.0
|
||||||
wetext==0.0.4
|
wetext==0.0.4
|
||||||
wget==3.2
|
wget==3.2
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ RUN apt-get -y install git unzip git-lfs g++
|
|||||||
RUN git lfs install
|
RUN git lfs install
|
||||||
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||||
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
|
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
|
||||||
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||||
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
|
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
|
||||||
141
runtime/triton_trtllm/README.DIT.md
Normal file
141
runtime/triton_trtllm/README.DIT.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
|
||||||
|
|
||||||
|
Contributed by Yuekai Zhang (NVIDIA).
|
||||||
|
|
||||||
|
This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
Launch the service directly with Docker Compose:
|
||||||
|
```sh
|
||||||
|
docker compose -f docker-compose.dit.yml up
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build the Docker Image
|
||||||
|
|
||||||
|
To build the image from scratch:
|
||||||
|
```sh
|
||||||
|
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run a Docker Container
|
||||||
|
```sh
|
||||||
|
your_mount_dir=/mnt:/mnt
|
||||||
|
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
|
||||||
|
```
|
||||||
|
|
||||||
|
### Understanding `run_stepaudio2_dit_token2wav.sh`
|
||||||
|
|
||||||
|
The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
|
||||||
|
|
||||||
|
You can run a subset of stages with:
|
||||||
|
```sh
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
|
||||||
|
```
|
||||||
|
- `<start_stage>`: The stage to start from.
|
||||||
|
- `<stop_stage>`: The stage to stop after.
|
||||||
|
|
||||||
|
**Stages:**
|
||||||
|
|
||||||
|
- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
|
||||||
|
- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
|
||||||
|
- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
|
||||||
|
- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
|
||||||
|
- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM.
|
||||||
|
- **Stage 4**: Runs the gRPC benchmark client for performance testing.
|
||||||
|
- **Stage 5**: Runs the offline TTS inference benchmark test.
|
||||||
|
- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
|
||||||
|
- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
|
||||||
|
- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
|
||||||
|
### Export Models and Launch Server
|
||||||
|
|
||||||
|
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
||||||
|
```sh
|
||||||
|
# This command runs stages 0, 1, 2, and 3
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh 0 3
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark with client-server mode
|
||||||
|
|
||||||
|
To benchmark the running Triton server, run stage 4:
|
||||||
|
```sh
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh 4 4
|
||||||
|
|
||||||
|
# You can customize parameters such as the number of tasks inside the script.
|
||||||
|
```
|
||||||
|
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
|
||||||
|
|
||||||
|
#### Total Request Latency
|
||||||
|
|
||||||
|
| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
|
||||||
|
| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
|
||||||
|
| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 |
|
||||||
|
| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 |
|
||||||
|
| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 |
|
||||||
|
| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 |
|
||||||
|
| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 |
|
||||||
|
| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 |
|
||||||
|
|
||||||
|
#### First Chunk Latency
|
||||||
|
|
||||||
|
| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
|
||||||
|
| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
|
||||||
|
| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 |
|
||||||
|
| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 |
|
||||||
|
| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 |
|
||||||
|
| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 |
|
||||||
|
| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 |
|
||||||
|
| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 |
|
||||||
|
|
||||||
|
### Benchmark with offline inference mode
|
||||||
|
For offline inference mode benchmark, please run stage 5:
|
||||||
|
```sh
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh 5 5
|
||||||
|
```
|
||||||
|
|
||||||
|
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
|
||||||
|
|
||||||
|
#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
|
||||||
|
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|
||||||
|
|---------|------------|------------------|-----------------------|--|
|
||||||
|
| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 |
|
||||||
|
|
||||||
|
|
||||||
|
### Disaggregated Server
|
||||||
|
When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
|
||||||
|
|
||||||
|
The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
|
||||||
|
|
||||||
|
| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
|
||||||
|
|---|---|---|---|---|---|---|
|
||||||
|
| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
|
||||||
|
| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
|
||||||
|
| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
|
||||||
|
| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
|
||||||
|
| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
|
||||||
|
| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
|
||||||
|
| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
|
||||||
|
| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
|
||||||
|
| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
|
||||||
|
| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
|
||||||
|
| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
|
||||||
|
| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
|
||||||
|
| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
|
||||||
|
| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
|
||||||
|
| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
|
||||||
|
| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
|
||||||
|
| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
|
||||||
|
| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
|
||||||
|
| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
|
||||||
|
| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
|
||||||
|
| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
|
||||||
|
| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
|
||||||
|
| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
|
||||||
|
| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
|
||||||
|
|
||||||
|
The `concurrent_task_per_gpu` is calculated as:
|
||||||
|
`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
|
||||||
|
|
||||||
|
### Acknowledgements
|
||||||
|
|
||||||
|
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
|
||||||
@@ -1,15 +1,17 @@
|
|||||||
## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server
|
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
|
||||||
|
|
||||||
Thanks to the contribution from NVIDIA Yuekai Zhang.
|
Contributed by Yuekai Zhang (NVIDIA).
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
||||||
Launch the service directly with Docker Compose:
|
Launch the service directly with Docker Compose:
|
||||||
```sh
|
```sh
|
||||||
docker compose up
|
docker compose up
|
||||||
```
|
```
|
||||||
|
|
||||||
### Build the Docker Image
|
### Build the Docker Image
|
||||||
Build the image from scratch:
|
|
||||||
|
To build the image from scratch:
|
||||||
```sh
|
```sh
|
||||||
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
||||||
```
|
```
|
||||||
@@ -21,71 +23,124 @@ docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_di
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Understanding `run.sh`
|
### Understanding `run.sh`
|
||||||
|
|
||||||
The `run.sh` script orchestrates the entire workflow through numbered stages.
|
The `run.sh` script orchestrates the entire workflow through numbered stages.
|
||||||
|
|
||||||
Run a subset of stages with:
|
You can run a subset of stages with:
|
||||||
```sh
|
```sh
|
||||||
bash run.sh <start_stage> <stop_stage> [service_type]
|
bash run.sh <start_stage> <stop_stage> [service_type]
|
||||||
```
|
```
|
||||||
- `<start_stage>` – stage to start from (0-5).
|
- `<start_stage>`: The stage to start from (0-5).
|
||||||
- `<stop_stage>` – stage to stop after (0-5).
|
- `<stop_stage>`: The stage to stop after (0-5).
|
||||||
|
|
||||||
Stages:
|
**Stages:**
|
||||||
- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace.
|
|
||||||
- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
|
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
|
||||||
- **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later).
|
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
|
||||||
- **Stage 3** – Launch the Triton Inference Server.
|
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
|
||||||
- **Stage 4** – Run the single-utterance HTTP client.
|
- **Stage 3**: Launches the Triton Inference Server.
|
||||||
- **Stage 5** – Run the gRPC benchmark client.
|
- **Stage 4**: Runs the single-utterance HTTP client for testing.
|
||||||
|
- **Stage 5**: Runs the gRPC benchmark client.
|
||||||
|
- **Stage 6**: Runs the offline inference benchmark test.
|
||||||
|
|
||||||
|
### Export Models and Launch Server
|
||||||
|
|
||||||
### Export Models to TensorRT-LLM and Launch the Server
|
|
||||||
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
||||||
```sh
|
```sh
|
||||||
# Runs stages 0, 1, 2, and 3
|
# This command runs stages 0, 1, 2, and 3
|
||||||
bash run.sh 0 3
|
bash run.sh 0 3
|
||||||
```
|
```
|
||||||
*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
|
> [!TIP]
|
||||||
|
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
|
||||||
|
|
||||||
### Single-Utterance HTTP Client
|
### Single-Utterance HTTP Client
|
||||||
Send a single HTTP inference request:
|
|
||||||
|
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
|
||||||
```sh
|
```sh
|
||||||
bash run.sh 4 4
|
bash run.sh 4 4
|
||||||
```
|
```
|
||||||
|
|
||||||
### Benchmark with a Dataset
|
### Benchmark with client-server mode
|
||||||
Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.
|
|
||||||
```sh
|
|
||||||
bash run.sh 5 5
|
|
||||||
|
|
||||||
# You can also customise parameters such as num_task and dataset split directly:
|
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
|
||||||
|
```sh
|
||||||
|
bash run.sh 5 5 # [streaming|offline]
|
||||||
|
|
||||||
|
# You can also customize parameters such as the number of tasks and the dataset split:
|
||||||
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
|
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
|
||||||
```
|
```
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready.
|
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
|
||||||
|
|
||||||
### Benchmark Results
|
### Benchmark with offline inference mode
|
||||||
Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio):
|
For offline inference mode benchmark, please check the below command:
|
||||||
|
|
||||||
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
|
||||||
|------|------|-------------|------------------|------------------|-----|
|
|
||||||
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
|
|
||||||
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
|
|
||||||
| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
|
|
||||||
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 |
|
|
||||||
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 |
|
|
||||||
| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 |
|
|
||||||
|
|
||||||
### OpenAI-Compatible Server
|
|
||||||
To launch an OpenAI-compatible service, run:
|
|
||||||
```sh
|
```sh
|
||||||
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
|
# install FlashCosyVoice for token2wav batching
|
||||||
pip install -r requirements.txt
|
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
|
||||||
# After the Triton service is up, start the FastAPI bridge:
|
# cd /workspace/FlashCosyVoice
|
||||||
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
|
# pip install -e .
|
||||||
# Test with curl
|
# cd -
|
||||||
bash test/test_cosyvoice.sh
|
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
|
||||||
|
|
||||||
|
bash run.sh 6 6
|
||||||
|
|
||||||
|
# You can also switch to huggingface backend by setting backend=hf
|
||||||
```
|
```
|
||||||
|
|
||||||
### Acknowledgements
|
|
||||||
This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details.
|
### Benchmark Results
|
||||||
|
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
|
||||||
|
|
||||||
|
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
|
||||||
|
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
|
||||||
|
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
|
||||||
|
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
|
||||||
|
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
|
||||||
|
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
|
||||||
|
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
|
||||||
|
|
||||||
|
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
|
||||||
|
|
||||||
|
**Client-Server Mode: Offline TTS (Full Sentence Latency)**
|
||||||
|
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
|
||||||
|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
|
||||||
|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
|
||||||
|
|
||||||
|
**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM**
|
||||||
|
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|
||||||
|
|---------|------------|------------------|-----------------------|--|
|
||||||
|
| HF | 1 | 39.26 | 44.31 | 0.2494 |
|
||||||
|
| HF | 2 | 30.54 | 35.62 | 0.2064 |
|
||||||
|
| HF | 4 | 18.63 | 23.90 | 0.1421 |
|
||||||
|
| HF | 8 | 11.22 | 16.45 | 0.0947 |
|
||||||
|
| HF | 16 | 8.42 | 13.78 | 0.0821 |
|
||||||
|
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
|
||||||
|
| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
|
||||||
|
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
|
||||||
|
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
|
||||||
|
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
|
||||||
|
### OpenAI-Compatible Server
|
||||||
|
|
||||||
|
To launch an OpenAI-compatible API service, run the following commands:
|
||||||
|
```sh
|
||||||
|
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
|
||||||
|
cd Triton-OpenAI-Speech
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# After the Triton service is running, start the FastAPI bridge:
|
||||||
|
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
|
||||||
|
|
||||||
|
# Test the service with curl:
|
||||||
|
bash test/test_cosyvoice.sh
|
||||||
|
```
|
||||||
|
> [!NOTE]
|
||||||
|
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
|
||||||
|
|
||||||
|
### Acknowledgements
|
||||||
|
|
||||||
|
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
|
||||||
|
|
||||||
|
|||||||
@@ -43,9 +43,9 @@ python3 client_grpc.py \
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import queue # Added
|
import queue
|
||||||
import uuid # Added
|
import uuid
|
||||||
import functools # Added
|
import functools
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@@ -55,16 +55,16 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import tritonclient
|
import tritonclient
|
||||||
import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
|
import tritonclient.grpc.aio as grpcclient_aio
|
||||||
import tritonclient.grpc as grpcclient_sync # Added sync client import
|
import tritonclient.grpc as grpcclient_sync
|
||||||
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
|
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
|
||||||
|
|
||||||
|
|
||||||
# --- Added UserData and callback ---
|
|
||||||
class UserData:
|
class UserData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._completed_requests = queue.Queue()
|
self._completed_requests = queue.Queue()
|
||||||
self._first_chunk_time = None
|
self._first_chunk_time = None
|
||||||
|
self._second_chunk_time = None
|
||||||
self._start_time = None
|
self._start_time = None
|
||||||
|
|
||||||
def record_start_time(self):
|
def record_start_time(self):
|
||||||
@@ -75,39 +75,43 @@ class UserData:
|
|||||||
return self._first_chunk_time - self._start_time
|
return self._first_chunk_time - self._start_time
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_second_chunk_latency(self):
|
||||||
|
if self._first_chunk_time and self._second_chunk_time:
|
||||||
|
return self._second_chunk_time - self._first_chunk_time
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def callback(user_data, result, error):
|
def callback(user_data, result, error):
|
||||||
if user_data._first_chunk_time is None and not error:
|
if not error:
|
||||||
user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
if user_data._first_chunk_time is None:
|
||||||
|
user_data._first_chunk_time = time.time()
|
||||||
|
elif user_data._second_chunk_time is None:
|
||||||
|
user_data._second_chunk_time = time.time()
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
user_data._completed_requests.put(error)
|
user_data._completed_requests.put(error)
|
||||||
else:
|
else:
|
||||||
user_data._completed_requests.put(result)
|
user_data._completed_requests.put(result)
|
||||||
# --- End Added UserData and callback ---
|
|
||||||
|
|
||||||
|
def stream_callback(user_data_map, result, error):
|
||||||
|
request_id = None
|
||||||
|
if error:
|
||||||
|
print(f"An error occurred in the stream callback: {error}")
|
||||||
|
else:
|
||||||
|
request_id = result.get_response().id
|
||||||
|
|
||||||
|
if request_id:
|
||||||
|
user_data = user_data_map.get(request_id)
|
||||||
|
if user_data:
|
||||||
|
callback(user_data, result, error)
|
||||||
|
else:
|
||||||
|
print(f"Warning: Could not find user_data for request_id {request_id}")
|
||||||
|
|
||||||
|
|
||||||
def write_triton_stats(stats, summary_file):
|
def write_triton_stats(stats, summary_file):
|
||||||
with open(summary_file, "w") as summary_f:
|
with open(summary_file, "w") as summary_f:
|
||||||
model_stats = stats["model_stats"]
|
model_stats = stats["model_stats"]
|
||||||
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
|
||||||
summary_f.write(
|
|
||||||
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
|
||||||
)
|
|
||||||
summary_f.write("To learn more about the log, please refer to: \n")
|
|
||||||
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
|
||||||
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
|
||||||
summary_f.write(
|
|
||||||
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
|
||||||
)
|
|
||||||
summary_f.write(
|
|
||||||
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
|
||||||
)
|
|
||||||
summary_f.write(
|
|
||||||
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
|
||||||
)
|
|
||||||
summary_f.write(
|
|
||||||
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
|
||||||
)
|
|
||||||
for model_state in model_stats:
|
for model_state in model_stats:
|
||||||
if "last_inference" not in model_state:
|
if "last_inference" not in model_state:
|
||||||
continue
|
continue
|
||||||
@@ -118,7 +122,10 @@ def write_triton_stats(stats, summary_file):
|
|||||||
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
||||||
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
||||||
summary_f.write(
|
summary_f.write(
|
||||||
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
f"queue time {total_queue_time_s:<5.2f} s, "
|
||||||
|
f"compute infer time {total_infer_time_s:<5.2f} s, "
|
||||||
|
f"compute input time {total_input_time_s:<5.2f} s, "
|
||||||
|
f"compute output time {total_output_time_s:<5.2f} s \n"
|
||||||
)
|
)
|
||||||
model_batch_stats = model_state["batch_stats"]
|
model_batch_stats = model_state["batch_stats"]
|
||||||
for batch in model_batch_stats:
|
for batch in model_batch_stats:
|
||||||
@@ -127,21 +134,86 @@ def write_triton_stats(stats, summary_file):
|
|||||||
compute_output = batch["compute_output"]
|
compute_output = batch["compute_output"]
|
||||||
compute_infer = batch["compute_infer"]
|
compute_infer = batch["compute_infer"]
|
||||||
batch_count = int(compute_infer["count"])
|
batch_count = int(compute_infer["count"])
|
||||||
|
if batch_count == 0:
|
||||||
|
continue
|
||||||
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
||||||
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
||||||
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
||||||
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
||||||
summary_f.write(
|
summary_f.write(
|
||||||
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
|
||||||
|
f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
|
||||||
|
f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
|
||||||
|
f"{compute_infer_time_ms / batch_count:.2f} ms, "
|
||||||
|
f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
|
||||||
|
f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
|
||||||
)
|
)
|
||||||
summary_f.write(
|
summary_f.write(
|
||||||
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
|
||||||
)
|
)
|
||||||
summary_f.write(
|
summary_f.write(
|
||||||
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def subtract_stats(stats_after, stats_before):
|
||||||
|
"""Subtracts two Triton inference statistics objects."""
|
||||||
|
stats_diff = json.loads(json.dumps(stats_after))
|
||||||
|
|
||||||
|
model_stats_before_map = {
|
||||||
|
s["name"]: {
|
||||||
|
"version": s["version"],
|
||||||
|
"last_inference": s.get("last_inference", 0),
|
||||||
|
"inference_count": s.get("inference_count", 0),
|
||||||
|
"execution_count": s.get("execution_count", 0),
|
||||||
|
"inference_stats": s.get("inference_stats", {}),
|
||||||
|
"batch_stats": s.get("batch_stats", []),
|
||||||
|
}
|
||||||
|
for s in stats_before["model_stats"]
|
||||||
|
}
|
||||||
|
|
||||||
|
for model_stat_after in stats_diff["model_stats"]:
|
||||||
|
model_name = model_stat_after["name"]
|
||||||
|
if model_name in model_stats_before_map:
|
||||||
|
model_stat_before = model_stats_before_map[model_name]
|
||||||
|
|
||||||
|
model_stat_after["inference_count"] = str(
|
||||||
|
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
|
||||||
|
)
|
||||||
|
model_stat_after["execution_count"] = str(
|
||||||
|
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
|
||||||
|
)
|
||||||
|
|
||||||
|
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
|
||||||
|
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
|
||||||
|
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
|
||||||
|
if "ns" in model_stat_after["inference_stats"][key]:
|
||||||
|
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
|
||||||
|
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
|
||||||
|
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
|
||||||
|
if "count" in model_stat_after["inference_stats"][key]:
|
||||||
|
count_after = int(model_stat_after["inference_stats"][key]["count"])
|
||||||
|
count_before = int(model_stat_before["inference_stats"][key]["count"])
|
||||||
|
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
|
||||||
|
|
||||||
|
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
|
||||||
|
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
|
||||||
|
for batch_stat_after in model_stat_after["batch_stats"]:
|
||||||
|
bs = batch_stat_after["batch_size"]
|
||||||
|
if bs in batch_stats_before_map:
|
||||||
|
batch_stat_before = batch_stats_before_map[bs]
|
||||||
|
for key in ["compute_input", "compute_infer", "compute_output"]:
|
||||||
|
if key in batch_stat_after and key in batch_stat_before:
|
||||||
|
count_after = int(batch_stat_after[key]["count"])
|
||||||
|
count_before = int(batch_stat_before[key]["count"])
|
||||||
|
batch_stat_after[key]["count"] = str(count_after - count_before)
|
||||||
|
|
||||||
|
ns_after = int(batch_stat_after[key]["ns"])
|
||||||
|
ns_before = int(batch_stat_before[key]["ns"])
|
||||||
|
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
|
||||||
|
return stats_diff
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
@@ -209,7 +281,8 @@ def get_args():
|
|||||||
choices=[
|
choices=[
|
||||||
"f5_tts",
|
"f5_tts",
|
||||||
"spark_tts",
|
"spark_tts",
|
||||||
"cosyvoice2"],
|
"cosyvoice2",
|
||||||
|
"cosyvoice2_dit"],
|
||||||
help="triton model_repo module name to request",
|
help="triton model_repo module name to request",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -243,7 +316,6 @@ def get_args():
|
|||||||
help="log directory",
|
help="log directory",
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Added arguments ---
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mode",
|
"--mode",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -257,7 +329,13 @@ def get_args():
|
|||||||
default=0.1,
|
default=0.1,
|
||||||
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
||||||
)
|
)
|
||||||
# --- End Added arguments ---
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-spk2info-cache",
|
||||||
|
type=str,
|
||||||
|
default="False",
|
||||||
|
help="Use spk2info cache for reference audio.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@@ -278,38 +356,33 @@ def load_audio(wav_path, target_sample_rate=16000):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_request_input_output(
|
def prepare_request_input_output(
|
||||||
protocol_client, # Can be grpcclient_aio or grpcclient_sync
|
protocol_client,
|
||||||
waveform,
|
waveform,
|
||||||
reference_text,
|
reference_text,
|
||||||
target_text,
|
target_text,
|
||||||
sample_rate=16000,
|
sample_rate=16000,
|
||||||
padding_duration: int = None # Optional padding for offline mode
|
padding_duration: int = None,
|
||||||
|
use_spk2info_cache: bool = False
|
||||||
):
|
):
|
||||||
"""Prepares inputs for Triton inference (offline or streaming)."""
|
"""Prepares inputs for Triton inference (offline or streaming)."""
|
||||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||||
|
|
||||||
# Apply padding only if padding_duration is provided (for offline)
|
|
||||||
if padding_duration:
|
if padding_duration:
|
||||||
duration = len(waveform) / sample_rate
|
duration = len(waveform) / sample_rate
|
||||||
# Estimate target duration based on text length ratio (crude estimation)
|
|
||||||
# Avoid division by zero if reference_text is empty
|
|
||||||
if reference_text:
|
if reference_text:
|
||||||
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
||||||
else:
|
else:
|
||||||
estimated_target_duration = duration # Assume target duration similar to reference if no text
|
estimated_target_duration = duration
|
||||||
|
|
||||||
# Calculate required samples based on estimated total duration
|
|
||||||
required_total_samples = padding_duration * sample_rate * (
|
required_total_samples = padding_duration * sample_rate * (
|
||||||
(int(estimated_target_duration + duration) // padding_duration) + 1
|
(int(estimated_target_duration + duration) // padding_duration) + 1
|
||||||
)
|
)
|
||||||
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
||||||
samples[0, : len(waveform)] = waveform
|
samples[0, : len(waveform)] = waveform
|
||||||
else:
|
else:
|
||||||
# No padding for streaming or if padding_duration is None
|
|
||||||
samples = waveform.reshape(1, -1).astype(np.float32)
|
samples = waveform.reshape(1, -1).astype(np.float32)
|
||||||
|
|
||||||
# Common input creation logic
|
|
||||||
inputs = [
|
inputs = [
|
||||||
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
||||||
protocol_client.InferInput(
|
protocol_client.InferInput(
|
||||||
@@ -330,7 +403,8 @@ def prepare_request_input_output(
|
|||||||
inputs[3].set_data_from_numpy(input_data_numpy)
|
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||||
|
|
||||||
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||||
|
if use_spk2info_cache:
|
||||||
|
inputs = inputs[-1:]
|
||||||
return inputs, outputs
|
return inputs, outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -347,12 +421,8 @@ def run_sync_streaming_inference(
|
|||||||
):
|
):
|
||||||
"""Helper function to run the blocking sync streaming call."""
|
"""Helper function to run the blocking sync streaming call."""
|
||||||
start_time_total = time.time()
|
start_time_total = time.time()
|
||||||
user_data.record_start_time() # Record start time for first chunk latency calculation
|
user_data.record_start_time()
|
||||||
|
|
||||||
# Establish stream
|
|
||||||
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
|
|
||||||
|
|
||||||
# Send request
|
|
||||||
sync_triton_client.async_stream_infer(
|
sync_triton_client.async_stream_infer(
|
||||||
model_name,
|
model_name,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -361,84 +431,76 @@ def run_sync_streaming_inference(
|
|||||||
enable_empty_final_response=True,
|
enable_empty_final_response=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process results
|
|
||||||
audios = []
|
audios = []
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
result = user_data._completed_requests.get() # Add timeout
|
result = user_data._completed_requests.get(timeout=200)
|
||||||
if isinstance(result, InferenceServerException):
|
if isinstance(result, InferenceServerException):
|
||||||
print(f"Received InferenceServerException: {result}")
|
print(f"Received InferenceServerException: {result}")
|
||||||
sync_triton_client.stop_stream()
|
return None, None, None, None
|
||||||
return None, None, None # Indicate error
|
|
||||||
# Get response metadata
|
|
||||||
response = result.get_response()
|
response = result.get_response()
|
||||||
final = response.parameters["triton_final_response"].bool_param
|
final = response.parameters["triton_final_response"].bool_param
|
||||||
if final is True:
|
if final is True:
|
||||||
break
|
break
|
||||||
|
|
||||||
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
||||||
if audio_chunk.size > 0: # Only append non-empty chunks
|
if audio_chunk.size > 0:
|
||||||
audios.append(audio_chunk)
|
audios.append(audio_chunk)
|
||||||
else:
|
else:
|
||||||
print("Warning: received empty audio chunk.")
|
print("Warning: received empty audio chunk.")
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
print(f"Timeout waiting for response for request id {request_id}")
|
print(f"Timeout waiting for response for request id {request_id}")
|
||||||
sync_triton_client.stop_stream()
|
return None, None, None, None
|
||||||
return None, None, None # Indicate error
|
|
||||||
|
|
||||||
sync_triton_client.stop_stream()
|
|
||||||
end_time_total = time.time()
|
end_time_total = time.time()
|
||||||
total_request_latency = end_time_total - start_time_total
|
total_request_latency = end_time_total - start_time_total
|
||||||
first_chunk_latency = user_data.get_first_chunk_latency()
|
first_chunk_latency = user_data.get_first_chunk_latency()
|
||||||
|
second_chunk_latency = user_data.get_second_chunk_latency()
|
||||||
|
|
||||||
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
|
|
||||||
actual_duration = 0
|
|
||||||
if audios:
|
if audios:
|
||||||
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
if model_name == "spark_tts":
|
||||||
fade_out = np.linspace(1, 0, cross_fade_samples)
|
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
||||||
fade_in = np.linspace(0, 1, cross_fade_samples)
|
fade_out = np.linspace(1, 0, cross_fade_samples)
|
||||||
reconstructed_audio = None
|
fade_in = np.linspace(0, 1, cross_fade_samples)
|
||||||
|
reconstructed_audio = None
|
||||||
|
|
||||||
# Simplified reconstruction based on client_grpc_streaming.py
|
if not audios:
|
||||||
if not audios:
|
print("Warning: No audio chunks received.")
|
||||||
print("Warning: No audio chunks received.")
|
reconstructed_audio = np.array([], dtype=np.float32)
|
||||||
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
|
elif len(audios) == 1:
|
||||||
elif len(audios) == 1:
|
reconstructed_audio = audios[0]
|
||||||
reconstructed_audio = audios[0]
|
else:
|
||||||
|
reconstructed_audio = audios[0][:-cross_fade_samples]
|
||||||
|
for i in range(1, len(audios)):
|
||||||
|
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
||||||
|
audios[i - 1][-cross_fade_samples:] * fade_out)
|
||||||
|
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
||||||
|
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
||||||
|
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
||||||
|
|
||||||
|
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
||||||
|
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||||
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||||
|
else:
|
||||||
|
print("Warning: No audio chunks received or reconstructed.")
|
||||||
|
actual_duration = 0
|
||||||
else:
|
else:
|
||||||
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
|
reconstructed_audio = np.concatenate(audios)
|
||||||
for i in range(1, len(audios)):
|
|
||||||
# Cross-fade section
|
|
||||||
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
|
||||||
audios[i - 1][-cross_fade_samples:] * fade_out)
|
|
||||||
# Middle section of the current chunk
|
|
||||||
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
|
||||||
# Concatenate
|
|
||||||
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
|
||||||
# Add the last part of the final chunk
|
|
||||||
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
|
||||||
|
|
||||||
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
|
||||||
actual_duration = len(reconstructed_audio) / save_sample_rate
|
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||||
# Save reconstructed audio
|
|
||||||
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
|
|
||||||
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||||
else:
|
|
||||||
print("Warning: No audio chunks received or reconstructed.")
|
|
||||||
actual_duration = 0 # Set duration to 0 if no audio
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("Warning: No audio chunks received.")
|
print("Warning: No audio chunks received.")
|
||||||
actual_duration = 0
|
actual_duration = 0
|
||||||
|
|
||||||
return total_request_latency, first_chunk_latency, actual_duration
|
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
|
||||||
|
|
||||||
|
|
||||||
async def send_streaming(
|
async def send_streaming(
|
||||||
manifest_item_list: list,
|
manifest_item_list: list,
|
||||||
name: str,
|
name: str,
|
||||||
server_url: str, # Changed from sync_triton_client
|
server_url: str,
|
||||||
protocol_client: types.ModuleType,
|
protocol_client: types.ModuleType,
|
||||||
log_interval: int,
|
log_interval: int,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@@ -446,15 +508,18 @@ async def send_streaming(
|
|||||||
save_sample_rate: int = 16000,
|
save_sample_rate: int = 16000,
|
||||||
chunk_overlap_duration: float = 0.1,
|
chunk_overlap_duration: float = 0.1,
|
||||||
padding_duration: int = None,
|
padding_duration: int = None,
|
||||||
|
use_spk2info_cache: bool = False,
|
||||||
):
|
):
|
||||||
total_duration = 0.0
|
total_duration = 0.0
|
||||||
latency_data = []
|
latency_data = []
|
||||||
task_id = int(name[5:])
|
task_id = int(name[5:])
|
||||||
sync_triton_client = None # Initialize client variable
|
sync_triton_client = None
|
||||||
|
user_data_map = {}
|
||||||
|
|
||||||
try: # Wrap in try...finally to ensure client closing
|
try:
|
||||||
print(f"{name}: Initializing sync client for streaming...")
|
print(f"{name}: Initializing sync client for streaming...")
|
||||||
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
|
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
|
||||||
|
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
|
||||||
|
|
||||||
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
||||||
for i, item in enumerate(manifest_item_list):
|
for i, item in enumerate(manifest_item_list):
|
||||||
@@ -471,14 +536,16 @@ async def send_streaming(
|
|||||||
reference_text,
|
reference_text,
|
||||||
target_text,
|
target_text,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
padding_duration=padding_duration
|
padding_duration=padding_duration,
|
||||||
|
use_spk2info_cache=use_spk2info_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
user_data = UserData()
|
user_data = UserData()
|
||||||
|
user_data_map[request_id] = user_data
|
||||||
|
|
||||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||||
|
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
|
||||||
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
|
|
||||||
run_sync_streaming_inference,
|
run_sync_streaming_inference,
|
||||||
sync_triton_client,
|
sync_triton_client,
|
||||||
model_name,
|
model_name,
|
||||||
@@ -492,12 +559,18 @@ async def send_streaming(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if total_request_latency is not None:
|
if total_request_latency is not None:
|
||||||
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
|
print(
|
||||||
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
|
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
|
||||||
|
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
|
||||||
|
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
|
||||||
|
)
|
||||||
|
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
|
||||||
total_duration += actual_duration
|
total_duration += actual_duration
|
||||||
else:
|
else:
|
||||||
print(f"{name}: Item {i} failed.")
|
print(f"{name}: Item {i} failed.")
|
||||||
|
|
||||||
|
del user_data_map[request_id]
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -505,10 +578,11 @@ async def send_streaming(
|
|||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
finally: # Ensure client is closed
|
finally:
|
||||||
if sync_triton_client:
|
if sync_triton_client:
|
||||||
try:
|
try:
|
||||||
print(f"{name}: Closing sync client...")
|
print(f"{name}: Closing stream and sync client...")
|
||||||
|
sync_triton_client.stop_stream()
|
||||||
sync_triton_client.close()
|
sync_triton_client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{name}: Error closing sync client: {e}")
|
print(f"{name}: Error closing sync client: {e}")
|
||||||
@@ -527,12 +601,12 @@ async def send(
|
|||||||
padding_duration: int = None,
|
padding_duration: int = None,
|
||||||
audio_save_dir: str = "./",
|
audio_save_dir: str = "./",
|
||||||
save_sample_rate: int = 16000,
|
save_sample_rate: int = 16000,
|
||||||
|
use_spk2info_cache: bool = False,
|
||||||
):
|
):
|
||||||
total_duration = 0.0
|
total_duration = 0.0
|
||||||
latency_data = []
|
latency_data = []
|
||||||
task_id = int(name[5:])
|
task_id = int(name[5:])
|
||||||
|
|
||||||
print(f"manifest_item_list: {manifest_item_list}")
|
|
||||||
for i, item in enumerate(manifest_item_list):
|
for i, item in enumerate(manifest_item_list):
|
||||||
if i % log_interval == 0:
|
if i % log_interval == 0:
|
||||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||||
@@ -545,7 +619,8 @@ async def send(
|
|||||||
reference_text,
|
reference_text,
|
||||||
target_text,
|
target_text,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
padding_duration=padding_duration
|
padding_duration=padding_duration,
|
||||||
|
use_spk2info_cache=use_spk2info_cache
|
||||||
)
|
)
|
||||||
sequence_id = 100000000 + i + task_id * 10
|
sequence_id = 100000000 + i + task_id * 10
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -572,7 +647,6 @@ def load_manifests(manifest_path):
|
|||||||
assert len(line.strip().split("|")) == 4
|
assert len(line.strip().split("|")) == 4
|
||||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||||
utt = Path(utt).stem
|
utt = Path(utt).stem
|
||||||
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
|
||||||
if not os.path.isabs(prompt_wav):
|
if not os.path.isabs(prompt_wav):
|
||||||
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
||||||
manifest_list.append(
|
manifest_list.append(
|
||||||
@@ -613,23 +687,17 @@ async def main():
|
|||||||
args = get_args()
|
args = get_args()
|
||||||
url = f"{args.server_addr}:{args.server_port}"
|
url = f"{args.server_addr}:{args.server_port}"
|
||||||
|
|
||||||
# --- Client Initialization based on mode ---
|
|
||||||
triton_client = None
|
triton_client = None
|
||||||
protocol_client = None
|
protocol_client = None
|
||||||
if args.mode == "offline":
|
if args.mode == "offline":
|
||||||
print("Initializing gRPC client for offline mode...")
|
print("Initializing gRPC client for offline mode...")
|
||||||
# Use the async client for offline tasks
|
|
||||||
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||||
protocol_client = grpcclient_aio
|
protocol_client = grpcclient_aio
|
||||||
elif args.mode == "streaming":
|
elif args.mode == "streaming":
|
||||||
print("Initializing gRPC client for streaming mode...")
|
print("Initializing gRPC client for streaming mode...")
|
||||||
# Use the sync client for streaming tasks, handled via asyncio.to_thread
|
protocol_client = grpcclient_sync
|
||||||
# We will create one sync client instance PER TASK inside send_streaming.
|
|
||||||
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
|
|
||||||
protocol_client = grpcclient_sync # protocol client for input prep
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid mode: {args.mode}")
|
raise ValueError(f"Invalid mode: {args.mode}")
|
||||||
# --- End Client Initialization ---
|
|
||||||
|
|
||||||
if args.reference_audio:
|
if args.reference_audio:
|
||||||
args.num_tasks = 1
|
args.num_tasks = 1
|
||||||
@@ -663,14 +731,24 @@ async def main():
|
|||||||
else:
|
else:
|
||||||
manifest_item_list = load_manifests(args.manifest_path)
|
manifest_item_list = load_manifests(args.manifest_path)
|
||||||
|
|
||||||
|
stats_client = None
|
||||||
|
stats_before = None
|
||||||
|
try:
|
||||||
|
print("Initializing temporary async client for fetching stats...")
|
||||||
|
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||||
|
print("Fetching inference statistics before running tasks...")
|
||||||
|
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not retrieve statistics before running tasks: {e}")
|
||||||
|
|
||||||
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
||||||
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
||||||
|
|
||||||
os.makedirs(args.log_dir, exist_ok=True)
|
os.makedirs(args.log_dir, exist_ok=True)
|
||||||
|
args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
|
||||||
tasks = []
|
tasks = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i in range(num_tasks):
|
for i in range(num_tasks):
|
||||||
# --- Task Creation based on mode ---
|
|
||||||
if args.mode == "offline":
|
if args.mode == "offline":
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
send(
|
send(
|
||||||
@@ -683,6 +761,7 @@ async def main():
|
|||||||
audio_save_dir=args.log_dir,
|
audio_save_dir=args.log_dir,
|
||||||
padding_duration=1,
|
padding_duration=1,
|
||||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||||
|
use_spk2info_cache=args.use_spk2info_cache,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif args.mode == "streaming":
|
elif args.mode == "streaming":
|
||||||
@@ -690,7 +769,7 @@ async def main():
|
|||||||
send_streaming(
|
send_streaming(
|
||||||
manifest_item_list[i],
|
manifest_item_list[i],
|
||||||
name=f"task-{i}",
|
name=f"task-{i}",
|
||||||
server_url=url, # Pass URL instead of client
|
server_url=url,
|
||||||
protocol_client=protocol_client,
|
protocol_client=protocol_client,
|
||||||
log_interval=args.log_interval,
|
log_interval=args.log_interval,
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
@@ -698,9 +777,9 @@ async def main():
|
|||||||
padding_duration=10,
|
padding_duration=10,
|
||||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||||
chunk_overlap_duration=args.chunk_overlap_duration,
|
chunk_overlap_duration=args.chunk_overlap_duration,
|
||||||
|
use_spk2info_cache=args.use_spk2info_cache,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# --- End Task Creation ---
|
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
ans_list = await asyncio.gather(*tasks)
|
ans_list = await asyncio.gather(*tasks)
|
||||||
@@ -713,7 +792,7 @@ async def main():
|
|||||||
for ans in ans_list:
|
for ans in ans_list:
|
||||||
if ans:
|
if ans:
|
||||||
total_duration += ans[0]
|
total_duration += ans[0]
|
||||||
latency_data.extend(ans[1]) # Use extend for list of lists
|
latency_data.extend(ans[1])
|
||||||
else:
|
else:
|
||||||
print("Warning: A task returned None, possibly due to an error.")
|
print("Warning: A task returned None, possibly due to an error.")
|
||||||
|
|
||||||
@@ -729,10 +808,8 @@ async def main():
|
|||||||
s += f"({total_duration / 3600:.2f} hours)\n"
|
s += f"({total_duration / 3600:.2f} hours)\n"
|
||||||
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
||||||
|
|
||||||
# --- Statistics Reporting based on mode ---
|
|
||||||
if latency_data:
|
if latency_data:
|
||||||
if args.mode == "offline":
|
if args.mode == "offline":
|
||||||
# Original offline latency calculation
|
|
||||||
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
||||||
if latency_list:
|
if latency_list:
|
||||||
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
||||||
@@ -747,9 +824,9 @@ async def main():
|
|||||||
s += "No latency data collected for offline mode.\n"
|
s += "No latency data collected for offline mode.\n"
|
||||||
|
|
||||||
elif args.mode == "streaming":
|
elif args.mode == "streaming":
|
||||||
# Calculate stats for total request latency and first chunk latency
|
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
|
||||||
total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
|
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
|
||||||
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
|
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
|
||||||
|
|
||||||
s += "\n--- Total Request Latency ---\n"
|
s += "\n--- Total Request Latency ---\n"
|
||||||
if total_latency_list:
|
if total_latency_list:
|
||||||
@@ -776,9 +853,21 @@ async def main():
|
|||||||
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
||||||
else:
|
else:
|
||||||
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
||||||
|
|
||||||
|
s += "\n--- Second Chunk Latency ---\n"
|
||||||
|
if second_chunk_latency_list:
|
||||||
|
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
|
||||||
|
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
|
||||||
|
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
||||||
|
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
|
||||||
|
else:
|
||||||
|
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
|
||||||
else:
|
else:
|
||||||
s += "No latency data collected.\n"
|
s += "No latency data collected.\n"
|
||||||
# --- End Statistics Reporting ---
|
|
||||||
|
|
||||||
print(s)
|
print(s)
|
||||||
if args.manifest_path:
|
if args.manifest_path:
|
||||||
@@ -788,26 +877,27 @@ async def main():
|
|||||||
elif args.reference_audio:
|
elif args.reference_audio:
|
||||||
name = Path(args.reference_audio).stem
|
name = Path(args.reference_audio).stem
|
||||||
else:
|
else:
|
||||||
name = "results" # Default name if no manifest/split/audio provided
|
name = "results"
|
||||||
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
||||||
f.write(s)
|
f.write(s)
|
||||||
|
|
||||||
# --- Statistics Fetching using temporary Async Client ---
|
|
||||||
# Use a separate async client for fetching stats regardless of mode
|
|
||||||
stats_client = None
|
|
||||||
try:
|
try:
|
||||||
print("Initializing temporary async client for fetching stats...")
|
if stats_client and stats_before:
|
||||||
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
print("Fetching inference statistics after running tasks...")
|
||||||
print("Fetching inference statistics...")
|
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||||
# Fetching for all models, filtering might be needed depending on server setup
|
|
||||||
stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
|
||||||
print("Fetching model config...")
|
|
||||||
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
|
||||||
|
|
||||||
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
print("Calculating statistics difference...")
|
||||||
|
stats = subtract_stats(stats_after, stats_before)
|
||||||
|
|
||||||
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
print("Fetching model config...")
|
||||||
json.dump(metadata, f, indent=4)
|
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
||||||
|
|
||||||
|
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
||||||
|
|
||||||
|
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
||||||
|
json.dump(metadata, f, indent=4)
|
||||||
|
else:
|
||||||
|
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not retrieve statistics or config: {e}")
|
print(f"Could not retrieve statistics or config: {e}")
|
||||||
@@ -818,11 +908,9 @@ async def main():
|
|||||||
await stats_client.close()
|
await stats_client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error closing async stats client: {e}")
|
print(f"Error closing async stats client: {e}")
|
||||||
# --- End Statistics Fetching ---
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# asyncio.run(main()) # Use TaskGroup for better exception handling if needed
|
|
||||||
async def run_main():
|
async def run_main():
|
||||||
try:
|
try:
|
||||||
await main()
|
await main()
|
||||||
|
|||||||
@@ -25,7 +25,6 @@
|
|||||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
import requests
|
import requests
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import json
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|||||||
20
runtime/triton_trtllm/docker-compose.dit.yml
Normal file
20
runtime/triton_trtllm/docker-compose.dit.yml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
services:
|
||||||
|
tts:
|
||||||
|
image: soar97/triton-cosyvoice:25.06
|
||||||
|
shm_size: '1gb'
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
- "8001:8001"
|
||||||
|
- "8002:8002"
|
||||||
|
environment:
|
||||||
|
- PYTHONIOENCODING=utf-8
|
||||||
|
- MODEL_ID=${MODEL_ID}
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
device_ids: ['0']
|
||||||
|
capabilities: [gpu]
|
||||||
|
command: >
|
||||||
|
/bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"
|
||||||
@@ -32,7 +32,7 @@ import triton_python_backend_utils as pb_utils
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import s3tokenizer
|
import s3tokenizer
|
||||||
|
torch.set_num_threads(1)
|
||||||
ORIGINAL_VOCAB_SIZE = 151663
|
ORIGINAL_VOCAB_SIZE = 151663
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ dynamic_batching {
|
|||||||
}
|
}
|
||||||
parameters [
|
parameters [
|
||||||
{
|
{
|
||||||
key: "model_dir",
|
key: "model_dir",
|
||||||
value: {string_value:"${model_dir}"}
|
value: {string_value:"${model_dir}"}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -25,23 +25,24 @@
|
|||||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import re
|
import threading
|
||||||
from typing import Dict, List, Tuple, Optional, Union
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import onnxruntime
|
|
||||||
|
|
||||||
|
|
||||||
from matcha.utils.audio import mel_spectrogram
|
from matcha.utils.audio import mel_spectrogram
|
||||||
|
|
||||||
|
ORIGINAL_VOCAB_SIZE = 151663
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
"""Triton Python model for Spark TTS.
|
"""Triton Python model for Spark TTS.
|
||||||
@@ -62,6 +63,8 @@ class TritonPythonModel:
|
|||||||
parameters = self.model_config['parameters']
|
parameters = self.model_config['parameters']
|
||||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||||
self.logger.log_info(f"model_params:{model_params}")
|
self.logger.log_info(f"model_params:{model_params}")
|
||||||
|
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
||||||
|
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
||||||
|
|
||||||
# Initialize tokenizer
|
# Initialize tokenizer
|
||||||
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
||||||
@@ -72,11 +75,15 @@ class TritonPythonModel:
|
|||||||
self.device = torch.device("cuda")
|
self.device = torch.device("cuda")
|
||||||
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
||||||
|
|
||||||
campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
|
self.token_frame_rate = 25
|
||||||
option = onnxruntime.SessionOptions()
|
self.flow_pre_lookahead_len = 3
|
||||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
self.token_hop_len = 15
|
||||||
option.intra_op_num_threads = 1
|
|
||||||
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
|
||||||
|
if not os.path.exists(spk_info_path):
|
||||||
|
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
||||||
|
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||||
|
self.default_spk_info = spk_info["001"]
|
||||||
|
|
||||||
def forward_llm(self, input_ids):
|
def forward_llm(self, input_ids):
|
||||||
"""
|
"""
|
||||||
@@ -105,7 +112,7 @@ class TritonPythonModel:
|
|||||||
"""
|
"""
|
||||||
# convert input_ids to numpy, with shape [1, sequence_length]
|
# convert input_ids to numpy, with shape [1, sequence_length]
|
||||||
input_ids = input_ids.cpu().numpy()
|
input_ids = input_ids.cpu().numpy()
|
||||||
max_tokens = 1024
|
max_tokens = 750
|
||||||
input_dict = {
|
input_dict = {
|
||||||
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
||||||
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
||||||
@@ -114,6 +121,8 @@ class TritonPythonModel:
|
|||||||
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
||||||
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
||||||
"temperature": np.array([[0.8]], dtype=np.float32),
|
"temperature": np.array([[0.8]], dtype=np.float32),
|
||||||
|
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
|
||||||
|
"random_seed": np.array([[42]], dtype=np.uint64),
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
||||||
}
|
}
|
||||||
@@ -188,12 +197,40 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
return prompt_speech_tokens
|
return prompt_speech_tokens
|
||||||
|
|
||||||
|
def forward_speaker_embedding(self, wav):
|
||||||
|
"""Forward pass through the speaker embedding component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wav: Input waveform tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt speaker embedding tensor
|
||||||
|
"""
|
||||||
|
inference_request = pb_utils.InferenceRequest(
|
||||||
|
model_name='speaker_embedding',
|
||||||
|
requested_output_names=['prompt_spk_embedding'],
|
||||||
|
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_response = inference_request.exec()
|
||||||
|
if inference_response.has_error():
|
||||||
|
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||||
|
|
||||||
|
# Extract and convert output tensors
|
||||||
|
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
|
||||||
|
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
|
||||||
|
|
||||||
|
return prompt_spk_embedding
|
||||||
|
|
||||||
def forward_token2wav(
|
def forward_token2wav(
|
||||||
self,
|
self,
|
||||||
prompt_speech_tokens: torch.Tensor,
|
target_speech_tokens: torch.Tensor,
|
||||||
prompt_speech_feat: torch.Tensor,
|
request_id: str,
|
||||||
prompt_spk_embedding: torch.Tensor,
|
prompt_speech_tokens: torch.Tensor = None,
|
||||||
target_speech_tokens: torch.Tensor) -> torch.Tensor:
|
prompt_speech_feat: torch.Tensor = None,
|
||||||
|
prompt_spk_embedding: torch.Tensor = None,
|
||||||
|
token_offset: int = None,
|
||||||
|
finalize: bool = None) -> torch.Tensor:
|
||||||
"""Forward pass through the vocoder component.
|
"""Forward pass through the vocoder component.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -205,16 +242,30 @@ class TritonPythonModel:
|
|||||||
Returns:
|
Returns:
|
||||||
Generated waveform tensor
|
Generated waveform tensor
|
||||||
"""
|
"""
|
||||||
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
|
||||||
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
|
||||||
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
|
||||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||||
|
|
||||||
|
inputs_tensor = [target_speech_tokens_tensor]
|
||||||
|
|
||||||
|
if token_offset is not None:
|
||||||
|
assert finalize is not None
|
||||||
|
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
|
||||||
|
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
||||||
|
inputs_tensor.append(token_offset_tensor)
|
||||||
|
inputs_tensor.append(finalize_tensor)
|
||||||
|
|
||||||
|
if prompt_spk_embedding is not None:
|
||||||
|
assert prompt_speech_feat is not None
|
||||||
|
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
||||||
|
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
||||||
|
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
||||||
|
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
|
||||||
|
|
||||||
# Create and execute inference request
|
# Create and execute inference request
|
||||||
inference_request = pb_utils.InferenceRequest(
|
inference_request = pb_utils.InferenceRequest(
|
||||||
model_name='token2wav',
|
model_name='token2wav',
|
||||||
requested_output_names=['waveform'],
|
requested_output_names=['waveform'],
|
||||||
inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
|
inputs=inputs_tensor,
|
||||||
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
inference_response = inference_request.exec()
|
inference_response = inference_request.exec()
|
||||||
@@ -235,17 +286,6 @@ class TritonPythonModel:
|
|||||||
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def _extract_spk_embedding(self, speech):
|
|
||||||
feat = kaldi.fbank(speech,
|
|
||||||
num_mel_bins=80,
|
|
||||||
dither=0,
|
|
||||||
sample_frequency=16000)
|
|
||||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
|
||||||
embedding = self.campplus_session.run(None,
|
|
||||||
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
|
||||||
embedding = torch.tensor([embedding]).to(self.device).half()
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
def _extract_speech_feat(self, speech):
|
def _extract_speech_feat(self, speech):
|
||||||
speech_feat = mel_spectrogram(
|
speech_feat = mel_spectrogram(
|
||||||
speech,
|
speech,
|
||||||
@@ -263,6 +303,14 @@ class TritonPythonModel:
|
|||||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||||
return speech_feat
|
return speech_feat
|
||||||
|
|
||||||
|
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
|
||||||
|
for generated_ids in generated_ids_iter:
|
||||||
|
generated_ids = generated_ids.tolist()
|
||||||
|
if len(generated_ids) == 0:
|
||||||
|
break
|
||||||
|
semantic_token_ids_arr.extend(generated_ids)
|
||||||
|
llm_is_done_flag[0] = True
|
||||||
|
|
||||||
def execute(self, requests):
|
def execute(self, requests):
|
||||||
"""Execute inference on the batched requests.
|
"""Execute inference on the batched requests.
|
||||||
|
|
||||||
@@ -275,25 +323,33 @@ class TritonPythonModel:
|
|||||||
responses = []
|
responses = []
|
||||||
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
|
request_id = request.request_id()
|
||||||
# Extract input tensors
|
# Extract input tensors
|
||||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
|
||||||
|
|
||||||
# Process reference audio through audio tokenizer
|
# Process reference audio through audio tokenizer
|
||||||
|
if wav is not None:
|
||||||
|
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||||
|
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||||
|
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||||
|
|
||||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
wav_tensor = wav.as_numpy()
|
||||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||||
|
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||||
|
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||||
|
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||||
|
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||||
|
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||||
|
|
||||||
wav_tensor = wav.as_numpy()
|
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
reference_text = reference_text[0][0].decode('utf-8')
|
||||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
||||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
else:
|
||||||
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
# using pre-cached reference text
|
||||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
reference_text = self.default_spk_info["prompt_text"]
|
||||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||||
|
prompt_speech_feat = None
|
||||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
prompt_spk_embedding = None
|
||||||
reference_text = reference_text[0][0].decode('utf-8')
|
|
||||||
|
|
||||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||||
target_text = target_text[0][0].decode('utf-8')
|
target_text = target_text[0][0].decode('utf-8')
|
||||||
@@ -310,22 +366,73 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
if self.decoupled:
|
if self.decoupled:
|
||||||
response_sender = request.get_response_sender()
|
response_sender = request.get_response_sender()
|
||||||
request_id = request.request_id()
|
|
||||||
generated_ids = []
|
|
||||||
for generated_id in generated_ids_iter:
|
|
||||||
# convert the numpy array into a int32 tensor
|
|
||||||
generated_id = generated_id.tolist()
|
|
||||||
if len(generated_id) > 0:
|
|
||||||
assert len(generated_id) == 1, "Generated ID is not a single integer"
|
|
||||||
generated_ids.append(generated_id[0])
|
|
||||||
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
|
|
||||||
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
|
|
||||||
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
|
|
||||||
|
|
||||||
# Prepare response
|
semantic_token_ids_arr = []
|
||||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
llm_is_done_flag = [False]
|
||||||
|
|
||||||
|
llm_thread = threading.Thread(
|
||||||
|
target=self._llm_gen_thread,
|
||||||
|
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_thread.start()
|
||||||
|
|
||||||
|
token_offset, chunk_index = 0, 0
|
||||||
|
start_time = time.time()
|
||||||
|
this_token_hop_len = self.token_hop_len
|
||||||
|
|
||||||
|
while True:
|
||||||
|
pending_num = len(semantic_token_ids_arr) - token_offset
|
||||||
|
|
||||||
|
if llm_is_done_flag[0]:
|
||||||
|
break
|
||||||
|
|
||||||
|
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
||||||
|
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
||||||
|
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||||
|
|
||||||
|
sub_tts_speech = self.forward_token2wav(
|
||||||
|
this_tts_speech_token, request_id, prompt_speech_tokens,
|
||||||
|
prompt_speech_feat, prompt_spk_embedding, token_offset, False
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||||
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||||
|
response_sender.send(inference_response)
|
||||||
|
|
||||||
|
token_offset += this_token_hop_len
|
||||||
|
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
|
||||||
|
|
||||||
|
if self.dynamic_chunk_strategy == "exponential":
|
||||||
|
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
||||||
|
elif self.dynamic_chunk_strategy == "time_based":
|
||||||
|
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
||||||
|
cost_time = time.time() - start_time
|
||||||
|
duration = token_offset / self.token_frame_rate
|
||||||
|
if chunk_index > 0 and cost_time > 0:
|
||||||
|
avg_chunk_processing_time = cost_time / (chunk_index + 1)
|
||||||
|
if avg_chunk_processing_time > 0:
|
||||||
|
multiples = (duration - cost_time) / avg_chunk_processing_time
|
||||||
|
self.logger.log_info(f"multiples: {multiples}")
|
||||||
|
next_pending_num = len(semantic_token_ids_arr) - token_offset
|
||||||
|
if multiples > 4:
|
||||||
|
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
|
||||||
|
elif multiples > 2:
|
||||||
|
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
|
||||||
|
else:
|
||||||
|
this_token_hop_len = self.token_hop_len
|
||||||
|
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
||||||
|
chunk_index += 1
|
||||||
|
else:
|
||||||
|
time.sleep(0.02)
|
||||||
|
|
||||||
|
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||||
|
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
|
||||||
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||||
response_sender.send(inference_response)
|
response_sender.send(inference_response)
|
||||||
|
|
||||||
|
llm_thread.join()
|
||||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||||
self.logger.log_info("send tritonserver_response_complete_final to end")
|
self.logger.log_info("send tritonserver_response_complete_final to end")
|
||||||
else:
|
else:
|
||||||
@@ -334,8 +441,7 @@ class TritonPythonModel:
|
|||||||
if generated_ids is None or len(generated_ids) == 0:
|
if generated_ids is None or len(generated_ids) == 0:
|
||||||
raise pb_utils.TritonModelException("Generated IDs is None or empty")
|
raise pb_utils.TritonModelException("Generated IDs is None or empty")
|
||||||
|
|
||||||
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
|
audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
|
||||||
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
|
|
||||||
|
|
||||||
# Prepare response
|
# Prepare response
|
||||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ model_transaction_policy {
|
|||||||
}
|
}
|
||||||
parameters [
|
parameters [
|
||||||
{
|
{
|
||||||
key: "llm_tokenizer_dir",
|
key: "llm_tokenizer_dir",
|
||||||
value: {string_value:"${llm_tokenizer_dir}"}
|
value: {string_value:"${llm_tokenizer_dir}"}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
key: "model_dir",
|
key: "model_dir",
|
||||||
value: {string_value:"${model_dir}"}
|
value: {string_value:"${model_dir}"}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -37,16 +37,19 @@ input [
|
|||||||
name: "reference_wav"
|
name: "reference_wav"
|
||||||
data_type: TYPE_FP32
|
data_type: TYPE_FP32
|
||||||
dims: [-1]
|
dims: [-1]
|
||||||
|
optional: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "reference_wav_len"
|
name: "reference_wav_len"
|
||||||
data_type: TYPE_INT32
|
data_type: TYPE_INT32
|
||||||
dims: [1]
|
dims: [1]
|
||||||
|
optional: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "reference_text"
|
name: "reference_text"
|
||||||
data_type: TYPE_STRING
|
data_type: TYPE_STRING
|
||||||
dims: [1]
|
dims: [1]
|
||||||
|
optional: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "target_text"
|
name: "target_text"
|
||||||
|
|||||||
394
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
Normal file
394
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions
|
||||||
|
# are met:
|
||||||
|
# * Redistributions of source code must retain the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
|
# documentation and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived
|
||||||
|
# from this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Tuple, Optional, Union
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
from matcha.utils.audio import mel_spectrogram
|
||||||
|
|
||||||
|
|
||||||
|
ORIGINAL_VOCAB_SIZE = 151663
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_speech_token_string(response_text: str) -> List[int]:
|
||||||
|
"""
|
||||||
|
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
|
||||||
|
"""
|
||||||
|
speech_tokens = response_text.strip().split('><')
|
||||||
|
if len(speech_tokens) > 1:
|
||||||
|
# Add back the missing '<' and '>' for proper parsing
|
||||||
|
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
|
||||||
|
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
|
||||||
|
|
||||||
|
speech_ids = []
|
||||||
|
for token_str in speech_tokens:
|
||||||
|
match = re.match(r'<\|s_(\d+)\|>', token_str)
|
||||||
|
if match:
|
||||||
|
speech_ids.append(int(match.group(1)))
|
||||||
|
return speech_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Triton Python model for Spark TTS.
|
||||||
|
|
||||||
|
This model orchestrates the end-to-end TTS pipeline by coordinating
|
||||||
|
between audio tokenizer, LLM, and vocoder components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""Initialize the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Dictionary containing model configuration
|
||||||
|
"""
|
||||||
|
self.logger = pb_utils.Logger
|
||||||
|
# Parse model parameters
|
||||||
|
self.model_config = json.loads(args['model_config'])
|
||||||
|
parameters = self.model_config['parameters']
|
||||||
|
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||||
|
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
||||||
|
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
||||||
|
|
||||||
|
# Initialize tokenizer
|
||||||
|
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
||||||
|
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
|
||||||
|
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
|
||||||
|
|
||||||
|
self.device = torch.device("cuda")
|
||||||
|
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
||||||
|
|
||||||
|
self.token_frame_rate = 25
|
||||||
|
self.flow_pre_lookahead_len = 3
|
||||||
|
self.token_hop_len = 15
|
||||||
|
|
||||||
|
self.http_client = httpx.AsyncClient()
|
||||||
|
self.api_base = "http://localhost:8000/v1/chat/completions"
|
||||||
|
self.speaker_cache = {}
|
||||||
|
|
||||||
|
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
||||||
|
"""Converts a tensor or list of speech token IDs to a string representation."""
|
||||||
|
if isinstance(speech_tokens, torch.Tensor):
|
||||||
|
# Ensure tensor is on CPU and flattened
|
||||||
|
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
|
||||||
|
|
||||||
|
speech_id_str = ""
|
||||||
|
for token_id in speech_tokens:
|
||||||
|
# Convert token ID back to the speech number N
|
||||||
|
token_num = token_id - ORIGINAL_VOCAB_SIZE
|
||||||
|
speech_id_str += f"<|s_{token_num}|>"
|
||||||
|
return speech_id_str
|
||||||
|
|
||||||
|
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
|
||||||
|
"""
|
||||||
|
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
|
||||||
|
"""
|
||||||
|
full_text = f"{reference_text}{target_text}"
|
||||||
|
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
{"role": "user", "content": full_text},
|
||||||
|
{"role": "assistant", "content": prompt_speech_tokens_str}
|
||||||
|
]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": "trt_engines_bfloat16",
|
||||||
|
"messages": chat,
|
||||||
|
"max_tokens": 750,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"top_k": 50,
|
||||||
|
"repetition_penalty": 1.1,
|
||||||
|
"stop": ["<|eos1|>", "<|eos|>"],
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line_data = line[len("data: "):].strip()
|
||||||
|
if line_data == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
json_data = json.loads(line_data)
|
||||||
|
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||||
|
if content:
|
||||||
|
buffer += content
|
||||||
|
while True:
|
||||||
|
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
|
||||||
|
token_num = int(match.group(1))
|
||||||
|
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||||
|
yield final_id
|
||||||
|
buffer = buffer[match.end():]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process any remaining complete tokens in the buffer after the stream ends
|
||||||
|
while True:
|
||||||
|
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||||
|
if not match:
|
||||||
|
break
|
||||||
|
token_num = int(match.group(1))
|
||||||
|
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||||
|
yield final_id
|
||||||
|
buffer = buffer[match.end():]
|
||||||
|
|
||||||
|
def forward_audio_tokenizer(self, wav, wav_len):
|
||||||
|
"""Forward pass through the audio tokenizer component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wav: Input waveform tensor
|
||||||
|
wav_len: Waveform length tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of global and semantic tokens
|
||||||
|
"""
|
||||||
|
inference_request = pb_utils.InferenceRequest(
|
||||||
|
model_name='audio_tokenizer',
|
||||||
|
requested_output_names=['prompt_speech_tokens'],
|
||||||
|
inputs=[wav, wav_len]
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_response = inference_request.exec()
|
||||||
|
if inference_response.has_error():
|
||||||
|
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||||
|
|
||||||
|
# Extract and convert output tensors
|
||||||
|
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
|
||||||
|
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
|
||||||
|
|
||||||
|
return prompt_speech_tokens
|
||||||
|
|
||||||
|
def forward_speaker_embedding(self, wav):
|
||||||
|
"""Forward pass through the speaker embedding component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wav: Input waveform tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt speaker embedding tensor
|
||||||
|
"""
|
||||||
|
inference_request = pb_utils.InferenceRequest(
|
||||||
|
model_name='speaker_embedding',
|
||||||
|
requested_output_names=['prompt_spk_embedding'],
|
||||||
|
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_response = inference_request.exec()
|
||||||
|
if inference_response.has_error():
|
||||||
|
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||||
|
|
||||||
|
# Extract and convert output tensors
|
||||||
|
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
|
||||||
|
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
|
||||||
|
|
||||||
|
return prompt_spk_embedding
|
||||||
|
|
||||||
|
async def forward_token2wav(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
target_speech_tokens: torch.Tensor,
|
||||||
|
request_id: str,
|
||||||
|
reference_wav: object,
|
||||||
|
reference_wav_len: object,
|
||||||
|
finalize: bool = None) -> torch.Tensor:
|
||||||
|
"""Forward pass through the vocoder component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: Index of the request
|
||||||
|
target_speech_tokens: Target speech tokens tensor
|
||||||
|
request_id: Request ID
|
||||||
|
reference_wav: Reference waveform tensor
|
||||||
|
reference_wav_len: Reference waveform length tensor
|
||||||
|
finalize: Whether to finalize the request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated waveform tensor
|
||||||
|
"""
|
||||||
|
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||||
|
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
||||||
|
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
||||||
|
|
||||||
|
# Create and execute inference request
|
||||||
|
inference_request = pb_utils.InferenceRequest(
|
||||||
|
model_name='token2wav_dit',
|
||||||
|
requested_output_names=[
|
||||||
|
"waveform",
|
||||||
|
],
|
||||||
|
inputs=inputs_tensor,
|
||||||
|
request_id=request_id,
|
||||||
|
parameters={"priority": index + 1},
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_response = await inference_request.async_exec()
|
||||||
|
if inference_response.has_error():
|
||||||
|
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||||
|
|
||||||
|
# Extract and convert output waveform
|
||||||
|
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
||||||
|
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
||||||
|
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
def _extract_speech_feat(self, speech):
|
||||||
|
speech_feat = mel_spectrogram(
|
||||||
|
speech,
|
||||||
|
n_fft=1920,
|
||||||
|
num_mels=80,
|
||||||
|
sampling_rate=24000,
|
||||||
|
hop_size=480,
|
||||||
|
win_size=1920,
|
||||||
|
fmin=0,
|
||||||
|
fmax=8000).squeeze(
|
||||||
|
dim=0).transpose(
|
||||||
|
0,
|
||||||
|
1).to(
|
||||||
|
self.device)
|
||||||
|
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||||
|
return speech_feat
|
||||||
|
|
||||||
|
async def _process_request(self, request):
|
||||||
|
request_id = request.request_id()
|
||||||
|
|
||||||
|
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||||
|
reference_text = reference_text[0][0].decode('utf-8')
|
||||||
|
|
||||||
|
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||||
|
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||||
|
|
||||||
|
if reference_text not in self.speaker_cache:
|
||||||
|
self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
|
||||||
|
prompt_speech_tokens = self.speaker_cache[reference_text]
|
||||||
|
|
||||||
|
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||||
|
target_text = target_text[0][0].decode('utf-8')
|
||||||
|
|
||||||
|
if self.decoupled:
|
||||||
|
response_sender = request.get_response_sender()
|
||||||
|
|
||||||
|
semantic_token_ids_arr = []
|
||||||
|
token_offset, chunk_index = 0, 0
|
||||||
|
start_time = time.time()
|
||||||
|
this_token_hop_len = self.token_hop_len
|
||||||
|
async for generated_ids in self.forward_llm_async(
|
||||||
|
target_text=target_text,
|
||||||
|
reference_text=reference_text,
|
||||||
|
prompt_speech_tokens=prompt_speech_tokens,
|
||||||
|
):
|
||||||
|
if not generated_ids:
|
||||||
|
break
|
||||||
|
semantic_token_ids_arr.append(generated_ids)
|
||||||
|
while True:
|
||||||
|
pending_num = len(semantic_token_ids_arr) - token_offset
|
||||||
|
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
||||||
|
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
||||||
|
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||||
|
sub_tts_speech = await self.forward_token2wav(
|
||||||
|
chunk_index,
|
||||||
|
this_tts_speech_token, request_id, wav, wav_len, False
|
||||||
|
)
|
||||||
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||||
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||||
|
response_sender.send(inference_response)
|
||||||
|
|
||||||
|
token_offset += this_token_hop_len
|
||||||
|
|
||||||
|
if self.dynamic_chunk_strategy == "exponential":
|
||||||
|
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
||||||
|
elif self.dynamic_chunk_strategy == "equal":
|
||||||
|
this_token_hop_len = self.token_hop_len
|
||||||
|
elif self.dynamic_chunk_strategy == "time_based":
|
||||||
|
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
||||||
|
cost_time = time.time() - start_time
|
||||||
|
duration = token_offset / self.token_frame_rate
|
||||||
|
if chunk_index > 0 and cost_time > 0:
|
||||||
|
avg_chunk_processing_time = cost_time / (chunk_index + 1)
|
||||||
|
if avg_chunk_processing_time > 0:
|
||||||
|
multiples = (duration - cost_time) / avg_chunk_processing_time
|
||||||
|
next_pending_num = len(semantic_token_ids_arr) - token_offset
|
||||||
|
if multiples > 4:
|
||||||
|
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
|
||||||
|
elif multiples > 2:
|
||||||
|
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
|
||||||
|
else:
|
||||||
|
this_token_hop_len = self.token_hop_len
|
||||||
|
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
||||||
|
chunk_index += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||||
|
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
||||||
|
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||||
|
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||||
|
response_sender.send(inference_response)
|
||||||
|
|
||||||
|
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Offline TTS mode is not supported")
|
||||||
|
|
||||||
|
async def execute(self, requests):
|
||||||
|
"""Execute inference on the batched requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requests: List of inference requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of inference responses containing generated audio
|
||||||
|
"""
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(self._process_request(request))
|
||||||
|
for request in requests
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
self.logger.log_info("Finalizing CosyVoice DIT model")
|
||||||
|
if hasattr(self, "http_client"):
|
||||||
|
asyncio.run(self.http_client.aclose())
|
||||||
73
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt
Normal file
73
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: "cosyvoice2_dit"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: ${triton_max_batch_size}
|
||||||
|
dynamic_batching {
|
||||||
|
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||||
|
}
|
||||||
|
model_transaction_policy {
|
||||||
|
decoupled: ${decoupled_mode}
|
||||||
|
}
|
||||||
|
parameters [
|
||||||
|
{
|
||||||
|
key: "llm_tokenizer_dir",
|
||||||
|
value: {string_value:"${llm_tokenizer_dir}"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "model_dir",
|
||||||
|
value: {string_value:"${model_dir}"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "reference_wav"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [-1]
|
||||||
|
optional: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reference_wav_len"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [1]
|
||||||
|
optional: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reference_text"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [1]
|
||||||
|
optional: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_text"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [1]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "waveform"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: ${bls_instance_num}
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
||||||
153
runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py
Normal file
153
runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions
|
||||||
|
# are met:
|
||||||
|
# * Redistributions of source code must retain the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
|
# documentation and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived
|
||||||
|
# from this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from torch.utils.dlpack import to_dlpack
|
||||||
|
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
||||||
|
from cosyvoice.utils.common import TrtContextWrapper
|
||||||
|
import onnxruntime
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Triton Python model for audio tokenization.
|
||||||
|
|
||||||
|
This model takes reference audio input and extracts semantic tokens
|
||||||
|
using s3tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""Initialize the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Dictionary containing model configuration
|
||||||
|
"""
|
||||||
|
# Parse model parameters
|
||||||
|
parameters = json.loads(args['model_config'])['parameters']
|
||||||
|
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||||
|
|
||||||
|
self.device = torch.device("cuda")
|
||||||
|
|
||||||
|
model_dir = model_params["model_dir"]
|
||||||
|
gpu = "l20"
|
||||||
|
enable_trt = True
|
||||||
|
if enable_trt:
|
||||||
|
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||||
|
f'{model_dir}/campplus.onnx',
|
||||||
|
1,
|
||||||
|
False)
|
||||||
|
else:
|
||||||
|
campplus_model = f'{model_dir}/campplus.onnx'
|
||||||
|
option = onnxruntime.SessionOptions()
|
||||||
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
option.intra_op_num_threads = 1
|
||||||
|
self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||||
|
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||||
|
trt_kwargs = self.get_spk_trt_kwargs()
|
||||||
|
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
|
||||||
|
import tensorrt as trt
|
||||||
|
with open(spk_model, 'rb') as f:
|
||||||
|
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
|
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||||
|
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||||
|
|
||||||
|
def get_spk_trt_kwargs(self):
|
||||||
|
min_shape = [(1, 4, 80)]
|
||||||
|
opt_shape = [(1, 500, 80)]
|
||||||
|
max_shape = [(1, 3000, 80)]
|
||||||
|
input_names = ["input"]
|
||||||
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
|
|
||||||
|
def _extract_spk_embedding(self, speech):
|
||||||
|
feat = kaldi.fbank(speech,
|
||||||
|
num_mel_bins=80,
|
||||||
|
dither=0,
|
||||||
|
sample_frequency=16000)
|
||||||
|
spk_feat = feat - feat.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||||
|
embedding = self.spk_model.run(
|
||||||
|
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||||
|
)[0].flatten().tolist()
|
||||||
|
embedding = torch.tensor([embedding]).to(self.device)
|
||||||
|
else:
|
||||||
|
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||||
|
# NOTE need to synchronize when switching stream
|
||||||
|
with torch.cuda.device(self.device):
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||||
|
batch_size = spk_feat.size(0)
|
||||||
|
|
||||||
|
with stream:
|
||||||
|
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||||
|
embedding = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||||
|
|
||||||
|
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||||
|
embedding.contiguous().data_ptr()]
|
||||||
|
for i, j in enumerate(data_ptrs):
|
||||||
|
|
||||||
|
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||||
|
# run trt engine
|
||||||
|
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.spk_model.release_estimator(spk_model, stream)
|
||||||
|
|
||||||
|
return embedding.half()
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""Execute inference on the batched requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requests: List of inference requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of inference responses containing tokenized outputs
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
# Process each request in batch
|
||||||
|
for request in requests:
|
||||||
|
# Extract input tensors
|
||||||
|
wav_array = pb_utils.get_input_tensor_by_name(
|
||||||
|
request, "reference_wav").as_numpy()
|
||||||
|
wav_array = torch.from_numpy(wav_array).to(self.device)
|
||||||
|
|
||||||
|
embedding = self._extract_spk_embedding(wav_array)
|
||||||
|
|
||||||
|
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
|
||||||
|
"prompt_spk_embedding", to_dlpack(embedding))
|
||||||
|
inference_response = pb_utils.InferenceResponse(
|
||||||
|
output_tensors=[prompt_spk_embedding_tensor])
|
||||||
|
|
||||||
|
responses.append(inference_response)
|
||||||
|
|
||||||
|
return responses
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: "speaker_embedding"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: ${triton_max_batch_size}
|
||||||
|
dynamic_batching {
|
||||||
|
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||||
|
}
|
||||||
|
parameters [
|
||||||
|
{
|
||||||
|
key: "model_dir",
|
||||||
|
value: {string_value:"${model_dir}"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "reference_wav"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [-1]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "prompt_spk_embedding"
|
||||||
|
data_type: TYPE_FP16
|
||||||
|
dims: [-1]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
||||||
@@ -28,26 +28,30 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.dlpack import to_dlpack
|
from torch.utils.dlpack import to_dlpack
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
from hyperpyyaml import load_hyperpyyaml
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
from cosyvoice.utils.common import fade_in_out
|
||||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||||
from cosyvoice.utils.common import TrtContextWrapper
|
from cosyvoice.utils.common import TrtContextWrapper
|
||||||
|
from collections import defaultdict
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ORIGINAL_VOCAB_SIZE = 151663
|
ORIGINAL_VOCAB_SIZE = 151663
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
|
||||||
class CosyVoice2:
|
class CosyVoice2:
|
||||||
|
|
||||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
|
||||||
|
|
||||||
self.model_dir = model_dir
|
self.model_dir = model_dir
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
@@ -57,7 +61,7 @@ class CosyVoice2:
|
|||||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||||
with open(hyper_yaml_path, 'r') as f:
|
with open(hyper_yaml_path, 'r') as f:
|
||||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||||
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
|
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
|
||||||
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
|
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
|
||||||
if load_jit:
|
if load_jit:
|
||||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
@@ -73,14 +77,22 @@ class CosyVoice2Model:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
flow: torch.nn.Module,
|
flow: torch.nn.Module,
|
||||||
hift: torch.nn.Module,
|
hift: torch.nn.Module,
|
||||||
fp16: bool = False):
|
fp16: bool = False,
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device: str = 'cuda'):
|
||||||
|
self.device = device
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.hift = hift
|
self.hift = hift
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
if self.fp16 is True:
|
if self.fp16 is True:
|
||||||
self.flow.half()
|
self.flow.half()
|
||||||
|
|
||||||
|
# streaming tts config
|
||||||
|
self.token_hop_len = 25
|
||||||
|
self.mel_cache_len = 8
|
||||||
|
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||||
|
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||||
|
self.hift_cache_dict = defaultdict(lambda: None)
|
||||||
|
|
||||||
def load_jit(self, flow_encoder_model):
|
def load_jit(self, flow_encoder_model):
|
||||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||||
self.flow.encoder = flow_encoder
|
self.flow.encoder = flow_encoder
|
||||||
@@ -111,6 +123,42 @@ class CosyVoice2Model:
|
|||||||
input_names = ["x", "mask", "mu", "cond"]
|
input_names = ["x", "mask", "mu", "cond"]
|
||||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
|
|
||||||
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||||
|
with torch.cuda.amp.autocast(self.fp16):
|
||||||
|
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||||
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_token=prompt_token.to(self.device),
|
||||||
|
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_feat=prompt_feat.to(self.device),
|
||||||
|
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=embedding.to(self.device),
|
||||||
|
streaming=stream,
|
||||||
|
finalize=finalize)
|
||||||
|
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||||
|
# append hift cache
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||||
|
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||||
|
else:
|
||||||
|
hift_cache_source = torch.zeros(1, 1, 0)
|
||||||
|
# keep overlap mel and hift cache
|
||||||
|
if finalize is False:
|
||||||
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
|
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||||
|
'source': tts_source[:, :, -self.source_cache_len:],
|
||||||
|
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||||
|
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||||
|
else:
|
||||||
|
if speed != 1.0:
|
||||||
|
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||||
|
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||||
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
|
return tts_speech
|
||||||
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
"""Triton Python model for vocoder.
|
"""Triton Python model for vocoder.
|
||||||
@@ -131,13 +179,19 @@ class TritonPythonModel:
|
|||||||
model_dir = model_params["model_dir"]
|
model_dir = model_params["model_dir"]
|
||||||
|
|
||||||
# Initialize device and vocoder
|
# Initialize device and vocoder
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||||
|
|
||||||
self.token2wav_model = CosyVoice2(
|
self.token2wav_model = CosyVoice2(
|
||||||
model_dir, load_jit=True, load_trt=True, fp16=True
|
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
spk_info_path = os.path.join(model_dir, "spk2info.pt")
|
||||||
|
if not os.path.exists(spk_info_path):
|
||||||
|
raise ValueError(f"spk2info.pt not found in {model_dir}")
|
||||||
|
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||||
|
self.default_spk_info = spk_info["001"]
|
||||||
|
|
||||||
logger.info("Token2Wav initialized successfully")
|
logger.info("Token2Wav initialized successfully")
|
||||||
|
|
||||||
def execute(self, requests):
|
def execute(self, requests):
|
||||||
@@ -153,38 +207,66 @@ class TritonPythonModel:
|
|||||||
# Process each request in batch
|
# Process each request in batch
|
||||||
for request in requests:
|
for request in requests:
|
||||||
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||||
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
|
|
||||||
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
|
||||||
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
|
||||||
|
|
||||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
|
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
|
||||||
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
|
||||||
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
|
||||||
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
if prompt_speech_tokens_tensor is not None:
|
||||||
|
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
|
||||||
|
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
||||||
|
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
||||||
|
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
||||||
|
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
||||||
|
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
||||||
|
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||||
|
else:
|
||||||
|
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
|
||||||
|
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
|
||||||
|
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
|
||||||
|
|
||||||
# shift the speech tokens according to the original vocab size
|
# shift the speech tokens according to the original vocab size
|
||||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
|
||||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||||
|
|
||||||
tts_mel, _ = self.token2wav_model.model.flow.inference(
|
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
|
||||||
token=target_speech_tokens,
|
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
|
||||||
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
|
if token_offset is not None:
|
||||||
self.device
|
token_offset = token_offset.as_numpy().item()
|
||||||
),
|
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||||
prompt_token=prompt_speech_tokens,
|
if not finalize:
|
||||||
prompt_token_len=torch.tensor(
|
stream = True
|
||||||
[prompt_speech_tokens.shape[1]], dtype=torch.int32
|
else:
|
||||||
).to(self.device),
|
stream = False
|
||||||
prompt_feat=prompt_speech_feat,
|
request_id = request.request_id()
|
||||||
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
|
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
|
||||||
embedding=prompt_spk_embedding,
|
prompt_token=prompt_speech_tokens,
|
||||||
streaming=False,
|
prompt_feat=prompt_speech_feat,
|
||||||
finalize=True,
|
embedding=prompt_spk_embedding,
|
||||||
)
|
token_offset=token_offset,
|
||||||
|
uuid=request_id,
|
||||||
|
stream=stream,
|
||||||
|
finalize=finalize)
|
||||||
|
if finalize:
|
||||||
|
self.token2wav_model.model.hift_cache_dict.pop(request_id)
|
||||||
|
|
||||||
audio_hat, _ = self.token2wav_model.model.hift.inference(
|
else:
|
||||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
tts_mel, _ = self.token2wav_model.model.flow.inference(
|
||||||
)
|
token=target_speech_tokens,
|
||||||
|
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
|
||||||
|
self.device
|
||||||
|
),
|
||||||
|
prompt_token=prompt_speech_tokens,
|
||||||
|
prompt_token_len=torch.tensor(
|
||||||
|
[prompt_speech_tokens.shape[1]], dtype=torch.int32
|
||||||
|
).to(self.device),
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=prompt_spk_embedding,
|
||||||
|
streaming=False,
|
||||||
|
finalize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hat, _ = self.token2wav_model.model.hift.inference(
|
||||||
|
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ dynamic_batching {
|
|||||||
}
|
}
|
||||||
parameters [
|
parameters [
|
||||||
{
|
{
|
||||||
key: "model_dir",
|
key: "model_dir",
|
||||||
value: {string_value:"${model_dir}"}
|
value: {string_value:"${model_dir}"}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -35,16 +35,33 @@ input [
|
|||||||
name: "prompt_speech_tokens"
|
name: "prompt_speech_tokens"
|
||||||
data_type: TYPE_INT32
|
data_type: TYPE_INT32
|
||||||
dims: [-1]
|
dims: [-1]
|
||||||
|
optional: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "prompt_speech_feat"
|
name: "prompt_speech_feat"
|
||||||
data_type: TYPE_FP16
|
data_type: TYPE_FP16
|
||||||
dims: [-1, 80]
|
dims: [-1, 80]
|
||||||
|
optional: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "prompt_spk_embedding"
|
name: "prompt_spk_embedding"
|
||||||
data_type: TYPE_FP16
|
data_type: TYPE_FP16
|
||||||
dims: [-1]
|
dims: [-1]
|
||||||
|
optional: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token_offset"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [ 1 ]
|
||||||
|
reshape: { shape: [ ] }
|
||||||
|
optional: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "finalize"
|
||||||
|
data_type: TYPE_BOOL
|
||||||
|
dims: [ 1 ]
|
||||||
|
reshape: { shape: [ ] }
|
||||||
|
optional: true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
output [
|
output [
|
||||||
|
|||||||
142
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
Normal file
142
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions
|
||||||
|
# are met:
|
||||||
|
# * Redistributions of source code must retain the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer.
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright
|
||||||
|
# notice, this list of conditions and the following disclaimer in the
|
||||||
|
# documentation and/or other materials provided with the distribution.
|
||||||
|
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived
|
||||||
|
# from this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||||
|
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.dlpack import to_dlpack
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
from cosyvoice.utils.common import fade_in_out
|
||||||
|
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||||
|
from cosyvoice.utils.common import TrtContextWrapper
|
||||||
|
from collections import defaultdict
|
||||||
|
import numpy as np
|
||||||
|
from .token2wav_dit import CosyVoice2_Token2Wav
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ORIGINAL_VOCAB_SIZE = 151663
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
||||||
|
"""
|
||||||
|
Generates a unique ID for a torch.Tensor.
|
||||||
|
Tensors with the same elements and properties will have the same ID.
|
||||||
|
"""
|
||||||
|
# Convert tensor to a byte string
|
||||||
|
tensor_bytes = tensor.numpy().tobytes()
|
||||||
|
|
||||||
|
# Create a SHA-256 hash of the byte string
|
||||||
|
hasher = hashlib.sha256()
|
||||||
|
hasher.update(tensor_bytes)
|
||||||
|
|
||||||
|
return hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Triton Python model for vocoder.
|
||||||
|
|
||||||
|
This model takes global and semantic tokens as input and generates audio waveforms
|
||||||
|
using the BiCodec vocoder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""Initialize the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Dictionary containing model configuration
|
||||||
|
"""
|
||||||
|
# Parse model parameters
|
||||||
|
parameters = json.loads(args['model_config'])['parameters']
|
||||||
|
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
||||||
|
model_dir = model_params["model_dir"]
|
||||||
|
|
||||||
|
# Initialize device and vocoder
|
||||||
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||||
|
|
||||||
|
# FIXME: device id settings
|
||||||
|
self.token2wav_model = CosyVoice2_Token2Wav(
|
||||||
|
model_dir, enable_trt=True, streaming=True
|
||||||
|
)
|
||||||
|
logger.info("Token2Wav initialized successfully")
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""Execute inference on the batched requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requests: List of inference requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of inference responses containing generated waveforms
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
# Process each request in batch
|
||||||
|
for request in requests:
|
||||||
|
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||||
|
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
|
||||||
|
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||||
|
target_speech_tokens = target_speech_tokens.squeeze().tolist()
|
||||||
|
|
||||||
|
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||||
|
|
||||||
|
request_id = request.request_id()
|
||||||
|
|
||||||
|
wav_array = pb_utils.get_input_tensor_by_name(
|
||||||
|
request, "reference_wav").as_numpy()
|
||||||
|
wav_len = pb_utils.get_input_tensor_by_name(
|
||||||
|
request, "reference_wav_len").as_numpy().item()
|
||||||
|
|
||||||
|
wav_array = torch.from_numpy(wav_array)
|
||||||
|
wav = wav_array[:, :wav_len].squeeze(0)
|
||||||
|
|
||||||
|
spk_id = get_spk_id_from_prompt_audio(wav)
|
||||||
|
|
||||||
|
audio_hat = self.token2wav_model.forward_streaming(
|
||||||
|
target_speech_tokens, finalize, request_id=request_id,
|
||||||
|
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
|
||||||
|
outputs.append(wav_tensor)
|
||||||
|
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
|
||||||
|
responses.append(inference_response)
|
||||||
|
|
||||||
|
return responses
|
||||||
@@ -0,0 +1,510 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Example Usage
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python3 token2wav.py --enable-trt || exit 1
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
|
||||||
|
from flashcosyvoice.modules.hifigan import HiFTGenerator
|
||||||
|
from flashcosyvoice.utils.audio import mel_spectrogram
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
import onnxruntime
|
||||||
|
import s3tokenizer
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from datasets import load_dataset
|
||||||
|
import torchaudio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
|
||||||
|
|
||||||
|
def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
|
||||||
|
"""perform fade_in_out in tensor style
|
||||||
|
"""
|
||||||
|
mel_overlap_len = int(window.shape[0] / 2)
|
||||||
|
fade_in_mel = fade_in_mel.clone()
|
||||||
|
fade_in_mel[..., :mel_overlap_len] = \
|
||||||
|
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||||
|
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||||
|
return fade_in_mel
|
||||||
|
|
||||||
|
|
||||||
|
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
|
||||||
|
import tensorrt as trt
|
||||||
|
logging.info("Converting onnx to trt...")
|
||||||
|
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||||
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
|
builder = trt.Builder(logger)
|
||||||
|
network = builder.create_network(network_flags)
|
||||||
|
parser = trt.OnnxParser(network, logger)
|
||||||
|
config = builder.create_builder_config()
|
||||||
|
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||||
|
if dtype == torch.float16:
|
||||||
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
|
|
||||||
|
profile = builder.create_optimization_profile()
|
||||||
|
# load onnx model
|
||||||
|
with open(onnx_model, "rb") as f:
|
||||||
|
if not parser.parse(f.read()):
|
||||||
|
for error in range(parser.num_errors):
|
||||||
|
print(parser.get_error(error))
|
||||||
|
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||||
|
# set input shapes
|
||||||
|
for i in range(len(trt_kwargs['input_names'])):
|
||||||
|
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||||
|
if dtype == torch.float16:
|
||||||
|
tensor_dtype = trt.DataType.HALF
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
tensor_dtype = trt.DataType.BF16
|
||||||
|
elif dtype == torch.float32:
|
||||||
|
tensor_dtype = trt.DataType.FLOAT
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dtype {}'.format(dtype))
|
||||||
|
# set input and output data type
|
||||||
|
for i in range(network.num_inputs):
|
||||||
|
input_tensor = network.get_input(i)
|
||||||
|
input_tensor.dtype = tensor_dtype
|
||||||
|
for i in range(network.num_outputs):
|
||||||
|
output_tensor = network.get_output(i)
|
||||||
|
output_tensor.dtype = tensor_dtype
|
||||||
|
config.add_optimization_profile(profile)
|
||||||
|
engine_bytes = builder.build_serialized_network(network, config)
|
||||||
|
# save trt engine
|
||||||
|
with open(trt_model, "wb") as f:
|
||||||
|
f.write(engine_bytes)
|
||||||
|
logging.info("Succesfully convert onnx to trt...")
|
||||||
|
|
||||||
|
|
||||||
|
class TrtContextWrapper:
|
||||||
|
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||||
|
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||||
|
self.trt_engine = trt_engine
|
||||||
|
self.device = device
|
||||||
|
for _ in range(trt_concurrent):
|
||||||
|
trt_context = trt_engine.create_execution_context()
|
||||||
|
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
|
||||||
|
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||||
|
self.trt_context_pool.put([trt_context, trt_stream])
|
||||||
|
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||||
|
|
||||||
|
def acquire_estimator(self):
|
||||||
|
return self.trt_context_pool.get(), self.trt_engine
|
||||||
|
|
||||||
|
def release_estimator(self, context, stream):
|
||||||
|
self.trt_context_pool.put([context, stream])
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||||
|
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
|
||||||
|
super().__init__()
|
||||||
|
self.device_id = device_id
|
||||||
|
self.device = f"cuda:{device_id}"
|
||||||
|
with open(f"{model_dir}/flow.yaml", "r") as f:
|
||||||
|
configs = load_hyperpyyaml(f)
|
||||||
|
self.flow = configs['flow']
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.flow.to(self.dtype)
|
||||||
|
|
||||||
|
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
|
||||||
|
self.flow.to(self.device).eval()
|
||||||
|
|
||||||
|
self.hift = HiFTGenerator()
|
||||||
|
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
|
||||||
|
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||||
|
self.hift.to(self.device).eval()
|
||||||
|
|
||||||
|
option = onnxruntime.SessionOptions()
|
||||||
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
option.intra_op_num_threads = 1
|
||||||
|
self.spk_model = onnxruntime.InferenceSession(
|
||||||
|
f"{model_dir}/campplus.onnx", sess_options=option,
|
||||||
|
providers=["CPUExecutionProvider"])
|
||||||
|
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
|
||||||
|
|
||||||
|
gpu = "l20"
|
||||||
|
if enable_trt:
|
||||||
|
if streaming:
|
||||||
|
self.load_trt(
|
||||||
|
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
|
||||||
|
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
|
||||||
|
1,
|
||||||
|
self.dtype, streaming
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.load_trt(
|
||||||
|
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
|
||||||
|
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
|
||||||
|
1,
|
||||||
|
self.dtype
|
||||||
|
)
|
||||||
|
self.load_spk_trt(
|
||||||
|
f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||||
|
f'{model_dir}/campplus.onnx',
|
||||||
|
1,
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.streaming_flow_cache = {}
|
||||||
|
self.speaker_cache = {}
|
||||||
|
|
||||||
|
self.mel_cache_len = 8 # hard-coded, 160ms
|
||||||
|
self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave
|
||||||
|
self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
|
||||||
|
|
||||||
|
# hifigan cache for streaming tts
|
||||||
|
self.hift_cache_dict = {}
|
||||||
|
|
||||||
|
def forward_spk_embedding(self, spk_feat):
|
||||||
|
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||||
|
return self.spk_model.run(
|
||||||
|
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||||
|
)[0].flatten().tolist()
|
||||||
|
else:
|
||||||
|
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||||
|
# NOTE need to synchronize when switching stream
|
||||||
|
with torch.cuda.device(self.device_id):
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||||
|
batch_size = spk_feat.size(0)
|
||||||
|
|
||||||
|
with stream:
|
||||||
|
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||||
|
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||||
|
|
||||||
|
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||||
|
output_tensor.contiguous().data_ptr()]
|
||||||
|
for i, j in enumerate(data_ptrs):
|
||||||
|
|
||||||
|
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||||
|
# run trt engine
|
||||||
|
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.spk_model.release_estimator(spk_model, stream)
|
||||||
|
|
||||||
|
return output_tensor.cpu().numpy().flatten().tolist()
|
||||||
|
|
||||||
|
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||||
|
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||||
|
trt_kwargs = self.get_spk_trt_kwargs()
|
||||||
|
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
|
||||||
|
import tensorrt as trt
|
||||||
|
with open(spk_model, 'rb') as f:
|
||||||
|
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
|
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||||
|
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||||
|
|
||||||
|
def get_spk_trt_kwargs(self):
|
||||||
|
min_shape = [(1, 4, 80)]
|
||||||
|
opt_shape = [(1, 500, 80)]
|
||||||
|
max_shape = [(1, 3000, 80)]
|
||||||
|
input_names = ["input"]
|
||||||
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
|
|
||||||
|
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
|
||||||
|
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||||
|
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||||
|
opt_batch_size = 2
|
||||||
|
max_batch_size = 16
|
||||||
|
if streaming:
|
||||||
|
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
|
||||||
|
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
|
||||||
|
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
|
||||||
|
del self.flow.decoder.estimator
|
||||||
|
import tensorrt as trt
|
||||||
|
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||||
|
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
|
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||||
|
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||||
|
|
||||||
|
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
|
||||||
|
if streaming:
|
||||||
|
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
|
||||||
|
opt_shape = [
|
||||||
|
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
|
||||||
|
(opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
|
||||||
|
(16, opt_batch_size * 2, 8, 100, 128)
|
||||||
|
]
|
||||||
|
max_shape = [
|
||||||
|
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
|
||||||
|
(max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
|
||||||
|
(16, max_batch_size * 2, 8, 1000, 128)
|
||||||
|
]
|
||||||
|
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
|
||||||
|
else:
|
||||||
|
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
|
||||||
|
opt_shape = [
|
||||||
|
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
|
||||||
|
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
|
||||||
|
]
|
||||||
|
max_shape = [
|
||||||
|
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
|
||||||
|
(max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
|
||||||
|
]
|
||||||
|
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
|
||||||
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
|
|
||||||
|
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
|
||||||
|
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
|
||||||
|
for audio in prompt_audios_list:
|
||||||
|
assert len(audio.shape) == 1
|
||||||
|
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
|
||||||
|
prompt_speech_mels_list.append(log_mel)
|
||||||
|
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
|
||||||
|
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
|
||||||
|
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
|
||||||
|
)
|
||||||
|
for i in range(len(prompt_speech_tokens)):
|
||||||
|
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
|
||||||
|
prompt_speech_tokens_list.append(speech_tokens_i)
|
||||||
|
return prompt_speech_tokens_list
|
||||||
|
|
||||||
|
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
|
||||||
|
spk_emb_for_flow = []
|
||||||
|
for audio in prompt_audios_list:
|
||||||
|
assert len(audio.shape) == 1
|
||||||
|
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
|
||||||
|
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
|
||||||
|
spk_emb = self.forward_spk_embedding(spk_feat)
|
||||||
|
|
||||||
|
spk_emb_for_flow.append(spk_emb)
|
||||||
|
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
||||||
|
if self.dtype != torch.float32:
|
||||||
|
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
|
||||||
|
return spk_emb_for_flow
|
||||||
|
|
||||||
|
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
|
||||||
|
prompt_mels_for_flow = []
|
||||||
|
prompt_mels_lens_for_flow = []
|
||||||
|
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
|
||||||
|
assert len(audio.shape) == 1
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
if sample_rate != 24000:
|
||||||
|
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
|
||||||
|
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
|
||||||
|
mel_len = mel.shape[0]
|
||||||
|
prompt_mels_for_flow.append(mel)
|
||||||
|
prompt_mels_lens_for_flow.append(mel_len)
|
||||||
|
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
prompt_mels_for_flow, batch_first=True, padding_value=0
|
||||||
|
) # [B, T', num_mels=80]
|
||||||
|
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
|
||||||
|
return prompt_mels_for_flow, prompt_mels_lens_for_flow
|
||||||
|
|
||||||
|
def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
|
||||||
|
generated_speech_tokens_list: list[list[int]],
|
||||||
|
prompt_mels_for_flow: torch.Tensor,
|
||||||
|
prompt_mels_lens_for_flow: torch.Tensor,
|
||||||
|
spk_emb_for_flow: torch.Tensor):
|
||||||
|
batch_size = prompt_mels_for_flow.shape[0]
|
||||||
|
flow_inputs = []
|
||||||
|
flow_inputs_lens = []
|
||||||
|
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
|
||||||
|
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
|
||||||
|
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
|
||||||
|
|
||||||
|
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
|
||||||
|
flow_inputs_lens = torch.tensor(flow_inputs_lens)
|
||||||
|
|
||||||
|
with torch.amp.autocast(self.device, dtype=torch.float16):
|
||||||
|
generated_mels, generated_mels_lens = self.flow.inference(
|
||||||
|
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
|
||||||
|
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
|
||||||
|
)
|
||||||
|
|
||||||
|
return generated_mels, generated_mels_lens
|
||||||
|
|
||||||
|
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
|
||||||
|
batch_size = generated_mels.shape[0]
|
||||||
|
generated_wavs = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
|
||||||
|
wav, _ = self.hift(speech_feat=mel)
|
||||||
|
generated_wavs.append(wav)
|
||||||
|
return generated_wavs
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(
|
||||||
|
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||||
|
):
|
||||||
|
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
generated_mels, generated_mels_lens = self.forward_flow(
|
||||||
|
prompt_speech_tokens_list, generated_speech_tokens_list,
|
||||||
|
prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
|
||||||
|
return generated_wavs
|
||||||
|
|
||||||
|
def prepare_prompt_audio(
|
||||||
|
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||||
|
):
|
||||||
|
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
|
||||||
|
|
||||||
|
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
|
||||||
|
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||||
|
|
||||||
|
def get_prompt_audio_cache_for_streaming_tts(
|
||||||
|
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||||
|
):
|
||||||
|
assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
|
||||||
|
for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
|
||||||
|
prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
|
||||||
|
prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
|
||||||
|
|
||||||
|
cache = self.flow.setup_cache(
|
||||||
|
prompt_speech_tokens_tensor.to(self.device),
|
||||||
|
prompt_mels_for_flow.to(self.device),
|
||||||
|
spk_emb_for_flow.to(self.device),
|
||||||
|
n_timesteps=10
|
||||||
|
)
|
||||||
|
new_cache = {k: v.clone() for k, v in cache.items()}
|
||||||
|
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
|
||||||
|
return new_cache
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward_streaming(
|
||||||
|
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
|
||||||
|
):
|
||||||
|
if speaker_id not in self.speaker_cache:
|
||||||
|
assert prompt_audio is not None, "prompt_audio is required for new speaker"
|
||||||
|
assert prompt_audio_sample_rate == 16000
|
||||||
|
|
||||||
|
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
|
||||||
|
|
||||||
|
token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
|
||||||
|
prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
|
||||||
|
prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
|
||||||
|
|
||||||
|
prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
|
||||||
|
|
||||||
|
cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
|
||||||
|
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
|
||||||
|
|
||||||
|
if request_id not in self.streaming_flow_cache:
|
||||||
|
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
|
||||||
|
self.hift_cache_dict[request_id] = dict(
|
||||||
|
mel=torch.zeros(1, 80, 0, device='cuda'),
|
||||||
|
source=torch.zeros(1, 1, 0, device='cuda'),
|
||||||
|
speech=torch.zeros(1, 0, device='cuda'),
|
||||||
|
)
|
||||||
|
|
||||||
|
current_request_cache = self.streaming_flow_cache[request_id]
|
||||||
|
|
||||||
|
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
|
||||||
|
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
|
||||||
|
|
||||||
|
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
|
||||||
|
token=generated_speech_tokens,
|
||||||
|
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
|
||||||
|
cache=current_request_cache,
|
||||||
|
last_chunk=last_chunk,
|
||||||
|
n_timesteps=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
|
||||||
|
|
||||||
|
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
|
||||||
|
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
|
||||||
|
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
|
||||||
|
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
|
||||||
|
], dim=4)
|
||||||
|
|
||||||
|
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
|
||||||
|
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
|
||||||
|
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
|
||||||
|
mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
|
||||||
|
|
||||||
|
speech, source = self.hift(mel, hift_cache_source)
|
||||||
|
|
||||||
|
# overlap speech smooth
|
||||||
|
if hift_cache_speech.shape[-1] > 0:
|
||||||
|
speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
|
||||||
|
|
||||||
|
# update vocoder cache
|
||||||
|
self.hift_cache_dict[request_id] = dict(
|
||||||
|
mel=mel[..., -self.mel_cache_len:].clone().detach(),
|
||||||
|
source=source[:, :, -self.source_cache_len:].clone().detach(),
|
||||||
|
speech=speech[:, -self.source_cache_len:].clone().detach(),
|
||||||
|
)
|
||||||
|
if not last_chunk:
|
||||||
|
speech = speech[:, :-self.source_cache_len]
|
||||||
|
|
||||||
|
if last_chunk:
|
||||||
|
assert request_id in self.streaming_flow_cache
|
||||||
|
self.streaming_flow_cache.pop(request_id)
|
||||||
|
self.hift_cache_dict.pop(request_id)
|
||||||
|
|
||||||
|
return speech
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
|
||||||
|
for item in batch:
|
||||||
|
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
|
||||||
|
audio = torch.from_numpy(item['prompt_audio']['array']).float()
|
||||||
|
prompt_audios_list.append(audio)
|
||||||
|
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
|
||||||
|
ids.append(item['id'])
|
||||||
|
|
||||||
|
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--enable-trt", action="store_true")
|
||||||
|
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=1)
|
||||||
|
parser.add_argument("--output-dir", type=str, default="generated_wavs")
|
||||||
|
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
|
||||||
|
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
|
||||||
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
dataset_name = "yuekai/seed_tts_cosy2"
|
||||||
|
|
||||||
|
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
|
||||||
|
|
||||||
|
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
|
||||||
|
|
||||||
|
for _ in range(args.warmup):
|
||||||
|
start_time = time.time()
|
||||||
|
for batch in data_loader:
|
||||||
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
|
||||||
|
|
||||||
|
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
for id, wav in zip(ids, generated_wavs):
|
||||||
|
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
|
||||||
|
end_time = time.time()
|
||||||
|
epoch_time = end_time - start_time
|
||||||
|
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
|
||||||
69
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt
Normal file
69
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: "token2wav_dit"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: ${triton_max_batch_size}
|
||||||
|
|
||||||
|
dynamic_batching {
|
||||||
|
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||||
|
priority_levels: 10
|
||||||
|
default_priority_level: 10
|
||||||
|
}
|
||||||
|
|
||||||
|
parameters [
|
||||||
|
{
|
||||||
|
key: "model_dir",
|
||||||
|
value: {string_value:"${model_dir}"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "target_speech_tokens"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [-1]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reference_wav"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [-1]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reference_wav_len"
|
||||||
|
data_type: TYPE_INT32
|
||||||
|
dims: [1]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "finalize"
|
||||||
|
data_type: TYPE_BOOL
|
||||||
|
dims: [ 1 ]
|
||||||
|
reshape: { shape: [ ] }
|
||||||
|
optional: true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "waveform"
|
||||||
|
data_type: TYPE_FP32
|
||||||
|
dims: [ -1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
kind: KIND_CPU
|
||||||
|
}
|
||||||
|
]
|
||||||
652
runtime/triton_trtllm/offline_inference.py
Normal file
652
runtime/triton_trtllm/offline_inference.py
Normal file
@@ -0,0 +1,652 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Example Usage
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python3 offline_inference.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||||
|
--token2wav-path $model_scope_model_local_dir \
|
||||||
|
--backend $backend \
|
||||||
|
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||||
|
--engine-dir $trt_engines_dir \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from cosyvoice.utils.file_utils import load_wav
|
||||||
|
from datasets import load_dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
import soundfile as sf
|
||||||
|
import s3tokenizer
|
||||||
|
from functools import partial
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
try:
|
||||||
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request_async(client, url, payload):
|
||||||
|
response = await client.post(url, json=payload, timeout=None)
|
||||||
|
response.raise_for_status()
|
||||||
|
response_json = response.json()
|
||||||
|
return response_json['choices'][0]['message']['content']
|
||||||
|
|
||||||
|
|
||||||
|
async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
tasks = []
|
||||||
|
for chat in chats:
|
||||||
|
payload = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": chat,
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
"top_k": top_k,
|
||||||
|
"repetition_penalty": 1.1,
|
||||||
|
"stop": ["<|eos1|>", "<|eos|>"],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
tasks.append(send_request_async(client, api_base, payload))
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_speech_ids(speech_tokens_str):
|
||||||
|
"""Extract speech IDs from token strings like <|s_23456|>"""
|
||||||
|
speech_ids = []
|
||||||
|
for token_str in speech_tokens_str:
|
||||||
|
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
||||||
|
num_str = token_str[4:-2]
|
||||||
|
num = int(num_str)
|
||||||
|
speech_ids.append(num)
|
||||||
|
else:
|
||||||
|
print(f"Unexpected token: {token_str}")
|
||||||
|
return speech_ids
|
||||||
|
|
||||||
|
|
||||||
|
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
|
||||||
|
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
|
||||||
|
speech_id_str = ""
|
||||||
|
for token in cosy2_tokens:
|
||||||
|
speech_id_str += f"<|s_{token}|>"
|
||||||
|
return speech_id_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir", required=True, type=str, help="dir to save result"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="batch size (per-device) for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token2wav-batch-size",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="batch size (per-device) for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers", type=int, default=0, help="workers for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefetch", type=int, default=None, help="prefetch for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model-name-or-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="LLM model path (includes both model and tokenizer)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token2wav-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="CosyVoice2 token2wav model path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-text",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The prompt text for CosyVoice2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-speech-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The path to the prompt speech for CosyVoice2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="top p for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="temperature for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="top k for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="hf",
|
||||||
|
choices=["hf", "trtllm", "vllm", "trtllm-serve"],
|
||||||
|
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--engine-dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-free-gpu-memory-fraction",
|
||||||
|
type=float,
|
||||||
|
default=0.6,
|
||||||
|
help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--openai-api-base",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8000/v1/chat/completions",
|
||||||
|
help="OpenAI API base URL (for trtllm-serve backend)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--openai-model-name",
|
||||||
|
type=str,
|
||||||
|
default="trt_engines_bfloat16",
|
||||||
|
help="Model name to use with OpenAI API (for trtllm-serve backend)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||||
|
"""Simplified data collator for batch_size=1 processing"""
|
||||||
|
collator_start_time = time.time()
|
||||||
|
total_audio_processing_time = 0
|
||||||
|
total_speech_tokenization_time = 0
|
||||||
|
total_text_tokenization_time = 0
|
||||||
|
|
||||||
|
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||||
|
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||||
|
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||||
|
prompt_text_after_apply_template_list = []
|
||||||
|
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
|
||||||
|
chat_list = []
|
||||||
|
for _, item in enumerate(batch):
|
||||||
|
audio_processing_start_time = time.time()
|
||||||
|
prompt_text, target_text = (
|
||||||
|
item["prompt_text"],
|
||||||
|
item["target_text"],
|
||||||
|
)
|
||||||
|
prompt_text_list.append(prompt_text)
|
||||||
|
full_text = prompt_text + target_text
|
||||||
|
full_text_list.append(full_text)
|
||||||
|
# remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
|
||||||
|
puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
|
||||||
|
for p in puncts:
|
||||||
|
if p in full_text:
|
||||||
|
full_text = full_text.replace(p, '')
|
||||||
|
print(f"removed {p} from {full_text}")
|
||||||
|
|
||||||
|
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||||
|
ref_audio_org, ref_sr = (
|
||||||
|
item["prompt_audio"]["array"],
|
||||||
|
item["prompt_audio"]["sampling_rate"],
|
||||||
|
)
|
||||||
|
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||||
|
print(ref_audio_org.shape)
|
||||||
|
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio = ref_audio_org
|
||||||
|
|
||||||
|
prompt_audio_list.append(ref_audio)
|
||||||
|
audio_processing_end_time = time.time()
|
||||||
|
total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
|
||||||
|
|
||||||
|
speech_tokenization_start_time = time.time()
|
||||||
|
if "prompt_audio_cosy2_tokens" in item:
|
||||||
|
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||||
|
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||||
|
else:
|
||||||
|
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||||
|
|
||||||
|
if len(mels) > 0:
|
||||||
|
mels, mels_lens = s3tokenizer.padding(mels)
|
||||||
|
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||||
|
for i in range(len(codes)):
|
||||||
|
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||||
|
speech_tokenization_end_time = time.time()
|
||||||
|
total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
|
||||||
|
|
||||||
|
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
|
||||||
|
text_tokenization_start_time = time.time()
|
||||||
|
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||||
|
# Create chat template for LLM generation
|
||||||
|
chat = [
|
||||||
|
{"role": "user", "content": full_text_list[i]},
|
||||||
|
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||||
|
]
|
||||||
|
chat_list.append(chat)
|
||||||
|
|
||||||
|
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
|
||||||
|
|
||||||
|
input_ids = tokenizer.apply_chat_template(
|
||||||
|
chat,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors='pt',
|
||||||
|
continue_final_message=True
|
||||||
|
)
|
||||||
|
input_ids_list.append(input_ids.squeeze(0))
|
||||||
|
|
||||||
|
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
|
||||||
|
|
||||||
|
prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
|
||||||
|
text_tokenization_end_time = time.time()
|
||||||
|
total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
|
||||||
|
|
||||||
|
ids = [item["id"] for item in batch]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids_list,
|
||||||
|
"ids": ids,
|
||||||
|
"prompt_text": prompt_text_list,
|
||||||
|
"prompt_audio_list": prompt_audio_list,
|
||||||
|
"prompt_text_after_apply_template": prompt_text_after_apply_template_list,
|
||||||
|
"audio_processing_time": total_audio_processing_time,
|
||||||
|
"speech_tokenization_time": total_speech_tokenization_time,
|
||||||
|
"text_tokenization_time": total_text_tokenization_time,
|
||||||
|
"chat_list": chat_list
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed():
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
rank = int(os.environ.get("RANK", 0))
|
||||||
|
print(
|
||||||
|
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||||
|
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
return world_size, local_rank, rank
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
local_rank, world_size, rank = 0, 1, 0
|
||||||
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||||
|
|
||||||
|
if args.backend == "hf":
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
runner = None
|
||||||
|
elif args.backend == "trtllm":
|
||||||
|
if args.engine_dir is None:
|
||||||
|
raise ValueError("--engine-dir is required when backend is 'trtllm'")
|
||||||
|
|
||||||
|
runtime_rank = tensorrt_llm.mpi_rank()
|
||||||
|
model = None
|
||||||
|
|
||||||
|
runner_kwargs = dict(
|
||||||
|
engine_dir=args.engine_dir,
|
||||||
|
rank=runtime_rank,
|
||||||
|
max_output_len=2048,
|
||||||
|
enable_context_fmha_fp32_acc=False,
|
||||||
|
max_batch_size=args.batch_size,
|
||||||
|
max_input_len=512,
|
||||||
|
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
|
||||||
|
cuda_graph_mode=False,
|
||||||
|
gather_generation_logits=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
|
||||||
|
elif args.backend == "vllm":
|
||||||
|
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
|
||||||
|
runner = None
|
||||||
|
elif args.backend == "trtllm-serve":
|
||||||
|
model = None
|
||||||
|
runner = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||||
|
if 'Step-Audio-2-mini' in args.token2wav_path:
|
||||||
|
from token2wav_dit import CosyVoice2_Token2Wav
|
||||||
|
else:
|
||||||
|
assert 'CosyVoice2-0.5B' in args.token2wav_path
|
||||||
|
from token2wav import CosyVoice2_Token2Wav
|
||||||
|
token2wav_model = CosyVoice2_Token2Wav(
|
||||||
|
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
|
||||||
|
)
|
||||||
|
if args.prompt_speech_path:
|
||||||
|
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||||
|
else:
|
||||||
|
prompt_speech_16k = None
|
||||||
|
s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
|
||||||
|
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||||
|
dataset = load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampler = None
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
prefetch_factor=args.prefetch,
|
||||||
|
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||||
|
)
|
||||||
|
for _ in range(3):
|
||||||
|
print(f"Running {_} times")
|
||||||
|
total_llm_time = 0
|
||||||
|
total_token2wav_time = 0
|
||||||
|
total_data_load_time = 0
|
||||||
|
total_llm_post_processing_time = 0
|
||||||
|
total_audio_save_time = 0
|
||||||
|
total_audio_processing_time_in_collator = 0
|
||||||
|
total_speech_tokenization_time_in_collator = 0
|
||||||
|
total_text_tokenization_time_in_collator = 0
|
||||||
|
total_audio_samples = 0
|
||||||
|
start_time = time.time()
|
||||||
|
total_steps = len(dataset)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||||
|
|
||||||
|
last_batch_end_time = time.time()
|
||||||
|
for batch in dataloader:
|
||||||
|
data_loaded_time = time.time()
|
||||||
|
total_data_load_time += data_loaded_time - last_batch_end_time
|
||||||
|
total_audio_processing_time_in_collator += batch["audio_processing_time"]
|
||||||
|
total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
|
||||||
|
total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
|
||||||
|
with torch.no_grad():
|
||||||
|
llm_start_time = time.time()
|
||||||
|
if args.backend == "hf":
|
||||||
|
input_ids_list = batch["input_ids"]
|
||||||
|
if len(input_ids_list) == 1:
|
||||||
|
input_ids = input_ids_list[0].unsqueeze(0)
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
else:
|
||||||
|
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||||
|
input_ids_list_new = [
|
||||||
|
torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
|
||||||
|
for input_ids in input_ids_list
|
||||||
|
]
|
||||||
|
input_ids = torch.stack(input_ids_list_new)
|
||||||
|
attention_mask = torch.zeros_like(input_ids)
|
||||||
|
for i in range(len(input_ids_list)):
|
||||||
|
attention_mask[i, :len(input_ids_list[i])] = 1
|
||||||
|
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids=input_ids.to(device),
|
||||||
|
attention_mask=attention_mask.to(device),
|
||||||
|
max_new_tokens=2048,
|
||||||
|
do_sample=True,
|
||||||
|
top_p=args.top_p,
|
||||||
|
temperature=args.temperature,
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elif args.backend == "trtllm":
|
||||||
|
batch_input_ids = list(batch["input_ids"])
|
||||||
|
input_lengths = [x.size(0) for x in batch_input_ids]
|
||||||
|
|
||||||
|
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
|
||||||
|
print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
|
||||||
|
outputs = runner.generate(
|
||||||
|
batch_input_ids=batch_input_ids,
|
||||||
|
max_new_tokens=2048,
|
||||||
|
end_id=end_id,
|
||||||
|
pad_id=end_id,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_k=args.top_k,
|
||||||
|
top_p=args.top_p,
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
num_return_sequences=1,
|
||||||
|
streaming=False,
|
||||||
|
output_sequence_lengths=True,
|
||||||
|
output_generation_logits=False,
|
||||||
|
return_dict=True,
|
||||||
|
return_all_generated_tokens=False
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
|
||||||
|
num_output_sents, num_beams, _ = output_ids.size()
|
||||||
|
assert num_beams == 1
|
||||||
|
beam = 0
|
||||||
|
batch_size = len(batch["input_ids"])
|
||||||
|
num_return_sequences = num_output_sents // batch_size
|
||||||
|
assert num_return_sequences == 1
|
||||||
|
outputs = []
|
||||||
|
for i in range(batch_size * num_return_sequences):
|
||||||
|
batch_idx = i // num_return_sequences
|
||||||
|
seq_idx = i % num_return_sequences
|
||||||
|
output_begin = input_lengths[batch_idx]
|
||||||
|
output_end = sequence_lengths[i][beam]
|
||||||
|
outputs_i = output_ids[i][beam][:output_end].tolist()
|
||||||
|
outputs.append(outputs_i)
|
||||||
|
elif args.backend == "vllm":
|
||||||
|
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
top_k=args.top_k,
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
|
||||||
|
print(outputs)
|
||||||
|
for j, output in enumerate(outputs):
|
||||||
|
outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
|
||||||
|
elif args.backend == "trtllm-serve":
|
||||||
|
if args.batch_size > 1:
|
||||||
|
outputs = asyncio.run(send_batch_requests_async(
|
||||||
|
args.openai_api_base,
|
||||||
|
args.openai_model_name,
|
||||||
|
batch["chat_list"],
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.top_k,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
outputs = []
|
||||||
|
for chat in batch["chat_list"]:
|
||||||
|
payload = {
|
||||||
|
"model": args.openai_model_name,
|
||||||
|
"messages": chat,
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_p": args.top_p,
|
||||||
|
"top_k": args.top_k,
|
||||||
|
"repetition_penalty": 1.1,
|
||||||
|
"stop": ["<|eos1|>", "<|eos|>"],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
response = requests.post(args.openai_api_base, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
response_json = response.json()
|
||||||
|
generated_content = response_json['choices'][0]['message']['content']
|
||||||
|
outputs.append(generated_content)
|
||||||
|
|
||||||
|
llm_end_time = time.time()
|
||||||
|
total_llm_time += (llm_end_time - llm_start_time)
|
||||||
|
|
||||||
|
items_for_token_2wav = []
|
||||||
|
for i in range(len(batch["ids"])):
|
||||||
|
llm_post_processing_start_time = time.time()
|
||||||
|
if args.backend == "trtllm-serve":
|
||||||
|
speech_tokens_str = outputs[i].strip().split('><')
|
||||||
|
if len(speech_tokens_str) > 1:
|
||||||
|
speech_tokens_str = [
|
||||||
|
t if t.startswith('<') else '<' + t for t in speech_tokens_str
|
||||||
|
]
|
||||||
|
speech_tokens_str = [
|
||||||
|
t if t.endswith('>') else t + '>' for t in speech_tokens_str
|
||||||
|
]
|
||||||
|
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||||
|
else:
|
||||||
|
input_length = len(batch["input_ids"][i])
|
||||||
|
generated_ids = outputs[i][input_length:]
|
||||||
|
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||||
|
print(i, speech_ids)
|
||||||
|
if len(speech_ids) == 0:
|
||||||
|
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if args.prompt_text is not None:
|
||||||
|
current_prompt_text = args.prompt_text
|
||||||
|
current_prompt_audio = prompt_speech_16k
|
||||||
|
else:
|
||||||
|
current_prompt_text = batch["prompt_text"][i]
|
||||||
|
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||||
|
|
||||||
|
llm_post_processing_end_time = time.time()
|
||||||
|
total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
|
||||||
|
if current_prompt_audio is not None:
|
||||||
|
items_for_token_2wav.append({
|
||||||
|
"speech_ids": speech_ids,
|
||||||
|
"prompt_audio": current_prompt_audio.squeeze(0),
|
||||||
|
"id": batch["ids"][i]
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||||
|
|
||||||
|
for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
|
||||||
|
t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
|
||||||
|
if not t2w_batch:
|
||||||
|
continue
|
||||||
|
|
||||||
|
t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
|
||||||
|
t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
|
||||||
|
t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
|
||||||
|
t2w_ids = [item["id"] for item in t2w_batch]
|
||||||
|
|
||||||
|
token2wav_start_time = time.time()
|
||||||
|
generated_wavs = token2wav_model(
|
||||||
|
t2w_generated_speech_tokens_list,
|
||||||
|
t2w_prompt_audios_list,
|
||||||
|
t2w_prompt_audios_sample_rate,
|
||||||
|
)
|
||||||
|
token2wav_end_time = time.time()
|
||||||
|
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
|
||||||
|
|
||||||
|
audio_save_start_time = time.time()
|
||||||
|
for j, audio_hat in enumerate(generated_wavs):
|
||||||
|
generated_wave = audio_hat.squeeze().cpu().numpy()
|
||||||
|
total_audio_samples += len(generated_wave)
|
||||||
|
target_sample_rate = 24000
|
||||||
|
|
||||||
|
utt = t2w_ids[j]
|
||||||
|
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||||
|
print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
|
||||||
|
audio_save_end_time = time.time()
|
||||||
|
total_audio_save_time += audio_save_end_time - audio_save_start_time
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.update(world_size * len(batch["ids"]))
|
||||||
|
|
||||||
|
last_batch_end_time = time.time()
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.close()
|
||||||
|
end_time = time.time()
|
||||||
|
target_sample_rate = 24000
|
||||||
|
total_audio_duration_seconds = total_audio_samples / target_sample_rate
|
||||||
|
|
||||||
|
log_file_path = os.path.join(args.output_dir, "log.txt")
|
||||||
|
with open(log_file_path, 'w') as f:
|
||||||
|
args_dict = vars(args)
|
||||||
|
log_data = {
|
||||||
|
"args": args_dict,
|
||||||
|
"data_load_time_seconds": total_data_load_time,
|
||||||
|
"audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
|
||||||
|
"speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
|
||||||
|
"text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
|
||||||
|
"llm_time_seconds": total_llm_time,
|
||||||
|
"llm_post_processing_time_seconds": total_llm_post_processing_time,
|
||||||
|
"token2wav_time_seconds": total_token2wav_time,
|
||||||
|
"audio_save_time_seconds": total_audio_save_time,
|
||||||
|
"total_audio_duration_seconds": total_audio_duration_seconds,
|
||||||
|
"pipeline_time_seconds": end_time - start_time,
|
||||||
|
}
|
||||||
|
print(log_data)
|
||||||
|
f.write(json.dumps(log_data, indent=4))
|
||||||
|
print(f"Metrics logged to {log_file_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
if args.backend == "vllm":
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
elif args.backend == "trtllm":
|
||||||
|
import tensorrt_llm
|
||||||
|
from tensorrt_llm.runtime import ModelRunnerCpp
|
||||||
|
elif args.backend == "hf":
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
elif args.backend == "trtllm-serve":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||||
|
main(args)
|
||||||
@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}
|
|||||||
|
|
||||||
model_repo=./model_repo_cosyvoice2
|
model_repo=./model_repo_cosyvoice2
|
||||||
|
|
||||||
|
use_spk2info_cache=False
|
||||||
|
|
||||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
echo "Cloning CosyVoice"
|
echo "Cloning CosyVoice"
|
||||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||||
@@ -25,8 +27,11 @@ fi
|
|||||||
|
|
||||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
echo "Downloading CosyVoice2-0.5B"
|
echo "Downloading CosyVoice2-0.5B"
|
||||||
|
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
||||||
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||||
|
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
|
||||||
|
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
@@ -57,9 +62,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
cosyvoice2_dir="cosyvoice2"
|
cosyvoice2_dir="cosyvoice2"
|
||||||
|
|
||||||
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
|
||||||
cp -r ./model_repo/tensorrt_llm $model_repo
|
cp -r ./model_repo/tensorrt_llm $model_repo
|
||||||
cp -r ./model_repo/token2wav $model_repo
|
cp -r ./model_repo/token2wav $model_repo
|
||||||
|
if [ $use_spk2info_cache == "False" ]; then
|
||||||
|
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||||
|
cp -r ./model_repo/speaker_embedding $model_repo
|
||||||
|
fi
|
||||||
|
|
||||||
ENGINE_PATH=$trt_engines_dir
|
ENGINE_PATH=$trt_engines_dir
|
||||||
MAX_QUEUE_DELAY_MICROSECONDS=0
|
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||||
@@ -67,13 +75,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
||||||
BLS_INSTANCE_NUM=4
|
BLS_INSTANCE_NUM=4
|
||||||
TRITON_MAX_BATCH_SIZE=16
|
TRITON_MAX_BATCH_SIZE=16
|
||||||
DECOUPLED_MODE=False
|
DECOUPLED_MODE=True # True for streaming, False for offline
|
||||||
|
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
||||||
|
if [ $use_spk2info_cache == "False" ]; then
|
||||||
|
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
|
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
@@ -82,7 +92,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
echo "Single request test http"
|
echo "Single request test http, only work for offline TTS mode"
|
||||||
python3 client_http.py \
|
python3 client_http.py \
|
||||||
--reference-audio ./assets/prompt_audio.wav \
|
--reference-audio ./assets/prompt_audio.wav \
|
||||||
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||||
@@ -93,14 +103,40 @@ fi
|
|||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
echo "Running benchmark client grpc"
|
echo "Running benchmark client grpc"
|
||||||
num_task=4
|
num_task=4
|
||||||
# set mode=streaming, when decoupled=True
|
|
||||||
# set mode=offline, when decoupled=False
|
mode=streaming
|
||||||
mode=offline
|
BLS_INSTANCE_NUM=4
|
||||||
|
|
||||||
python3 client_grpc.py \
|
python3 client_grpc.py \
|
||||||
--server-addr localhost \
|
--server-addr localhost \
|
||||||
--model-name cosyvoice2 \
|
--model-name cosyvoice2 \
|
||||||
--num-tasks $num_task \
|
--num-tasks $num_task \
|
||||||
--mode $mode \
|
--mode $mode \
|
||||||
|
--use-spk2info-cache $use_spk2info_cache \
|
||||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||||
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
|
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||||
|
echo "stage 6: Offline inference benchmark"
|
||||||
|
n_gpus=1
|
||||||
|
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
||||||
|
backend=trtllm # hf, trtllm, vllm
|
||||||
|
|
||||||
|
batch_sizes=(16 8 4 2 1)
|
||||||
|
token2wav_batch_size=1
|
||||||
|
for batch_size in ${batch_sizes[@]}; do
|
||||||
|
for dataset in ${datasets[@]}; do
|
||||||
|
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python3 offline_inference.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||||
|
--token2wav-path $model_scope_model_local_dir \
|
||||||
|
--backend $backend \
|
||||||
|
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||||
|
--engine-dir $trt_engines_dir \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|||||||
225
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
Normal file
225
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
cosyvoice_path=/workspace/CosyVoice
|
||||||
|
stepaudio2_path=/workspace/Step-Audio2
|
||||||
|
|
||||||
|
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
|
||||||
|
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
|
||||||
|
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
|
||||||
|
|
||||||
|
stage=$1
|
||||||
|
stop_stage=$2
|
||||||
|
|
||||||
|
huggingface_model_local_dir=./cosyvoice2_llm
|
||||||
|
model_scope_model_local_dir=./CosyVoice2-0.5B
|
||||||
|
step_audio_model_dir=./Step-Audio-2-mini
|
||||||
|
|
||||||
|
trt_dtype=bfloat16
|
||||||
|
trt_weights_dir=./trt_weights_${trt_dtype}
|
||||||
|
trt_engines_dir=./trt_engines_${trt_dtype}
|
||||||
|
|
||||||
|
model_repo=./model_repo_cosyvoice2_dit
|
||||||
|
bls_instance_num=10
|
||||||
|
|
||||||
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
|
|
||||||
|
echo "Cloning Step-Audio2-mini"
|
||||||
|
git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
|
||||||
|
|
||||||
|
echo "Cloning CosyVoice"
|
||||||
|
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||||
|
cd $cosyvoice_path
|
||||||
|
git submodule update --init --recursive
|
||||||
|
cd runtime/triton_trtllm
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
|
echo "Downloading CosyVoice2-0.5B"
|
||||||
|
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
||||||
|
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||||
|
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||||
|
|
||||||
|
echo "Step-Audio2-mini"
|
||||||
|
huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
|
||||||
|
cd $step_audio_model_dir/token2wav
|
||||||
|
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
|
||||||
|
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
|
||||||
|
cd -
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
|
echo "Converting checkpoint to TensorRT weights"
|
||||||
|
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
|
||||||
|
--output_dir $trt_weights_dir \
|
||||||
|
--dtype $trt_dtype || exit 1
|
||||||
|
|
||||||
|
echo "Building TensorRT engines"
|
||||||
|
trtllm-build --checkpoint_dir $trt_weights_dir \
|
||||||
|
--output_dir $trt_engines_dir \
|
||||||
|
--max_batch_size 64 \
|
||||||
|
--max_num_tokens 32768 \
|
||||||
|
--gemm_plugin $trt_dtype || exit 1
|
||||||
|
|
||||||
|
echo "Testing TensorRT engines"
|
||||||
|
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
|
||||||
|
--tokenizer_dir $huggingface_model_local_dir \
|
||||||
|
--top_k 50 --top_p 0.95 --temperature 0.8 \
|
||||||
|
--engine_dir=$trt_engines_dir || exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
|
echo "Creating model repository async mode"
|
||||||
|
rm -rf $model_repo
|
||||||
|
mkdir -p $model_repo
|
||||||
|
cosyvoice2_dir="cosyvoice2_dit"
|
||||||
|
token2wav_dir="token2wav_dit"
|
||||||
|
|
||||||
|
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||||
|
cp -r ./model_repo/${token2wav_dir} $model_repo
|
||||||
|
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||||
|
cp -r ./model_repo/speaker_embedding $model_repo
|
||||||
|
|
||||||
|
|
||||||
|
ENGINE_PATH=$trt_engines_dir
|
||||||
|
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||||
|
MODEL_DIR=$model_scope_model_local_dir
|
||||||
|
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
||||||
|
BLS_INSTANCE_NUM=$bls_instance_num
|
||||||
|
TRITON_MAX_BATCH_SIZE=1
|
||||||
|
DECOUPLED_MODE=True # Only streaming TTS mode is supported using Nvidia Triton for now
|
||||||
|
STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
|
||||||
|
|
||||||
|
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
|
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
|
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
|
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
|
echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
|
||||||
|
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
|
||||||
|
tritonserver --model-repository $model_repo --http-port 18000 &
|
||||||
|
wait
|
||||||
|
# Test using curl
|
||||||
|
# curl http://localhost:8000/v1/chat/completions \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -d '{
|
||||||
|
# "model": "",
|
||||||
|
# "messages":[{"role": "user", "content": "Where is New York?"},
|
||||||
|
# {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
|
||||||
|
# "max_tokens": 512,
|
||||||
|
# "temperature": 0.8,
|
||||||
|
# "top_p": 0.95,
|
||||||
|
# "top_k": 50,
|
||||||
|
# "stop": ["<|eos1|>"],
|
||||||
|
# "repetition_penalty": 1.2,
|
||||||
|
# "stream": false
|
||||||
|
# }'
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
echo "Running benchmark client"
|
||||||
|
num_task=4
|
||||||
|
mode=streaming
|
||||||
|
BLS_INSTANCE_NUM=$bls_instance_num
|
||||||
|
|
||||||
|
python3 client_grpc.py \
|
||||||
|
--server-addr localhost \
|
||||||
|
--server-port 8001 \
|
||||||
|
--model-name cosyvoice2_dit \
|
||||||
|
--num-tasks $num_task \
|
||||||
|
--mode $mode \
|
||||||
|
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||||
|
--log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
|
||||||
|
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
|
echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
|
||||||
|
|
||||||
|
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
||||||
|
backend=trtllm # hf, trtllm, vllm, trtllm-serve
|
||||||
|
|
||||||
|
batch_sizes=(16)
|
||||||
|
token2wav_batch_size=1
|
||||||
|
|
||||||
|
for batch_size in ${batch_sizes[@]}; do
|
||||||
|
for dataset in ${datasets[@]}; do
|
||||||
|
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
||||||
|
CUDA_VISIBLE_DEVICES=1 \
|
||||||
|
python3 offline_inference.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||||
|
--token2wav-path $step_audio_model_dir/token2wav \
|
||||||
|
--backend $backend \
|
||||||
|
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||||
|
--engine-dir $trt_engines_dir \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||||
|
echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
|
||||||
|
export CUDA_VISIBLE_DEVICES=1
|
||||||
|
# Note: Using pre-computed cosyvoice2 tokens
|
||||||
|
python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
|
||||||
|
# Offline Token2wav inference
|
||||||
|
python3 token2wav_dit.py --enable-trt
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||||
|
echo "Disaggregated Server: LLM and Token2wav on different GPUs"
|
||||||
|
echo "Starting LLM server on GPU 0"
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
|
||||||
|
echo "Starting Token2wav server on GPUs 1-3"
|
||||||
|
Token2wav_num_gpus=3
|
||||||
|
http_port=17000
|
||||||
|
grpc_port=18000
|
||||||
|
metrics_port=16000
|
||||||
|
for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
|
||||||
|
echo "Starting server on GPU $i"
|
||||||
|
http_port=$((http_port + 1))
|
||||||
|
grpc_port=$((grpc_port + 1))
|
||||||
|
metrics_port=$((metrics_port + 1))
|
||||||
|
# Two instances of Token2wav server on the same GPU
|
||||||
|
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
|
||||||
|
http_port=$((http_port + 1))
|
||||||
|
grpc_port=$((grpc_port + 1))
|
||||||
|
metrics_port=$((metrics_port + 1))
|
||||||
|
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
|
||||||
|
done
|
||||||
|
wait
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||||
|
echo "Running benchmark client for Disaggregated Server"
|
||||||
|
per_gpu_instances=2
|
||||||
|
mode=streaming
|
||||||
|
BLS_INSTANCE_NUM=$bls_instance_num
|
||||||
|
Token2wav_num_gpus=(1 2 3)
|
||||||
|
concurrent_tasks=(1 2 3 4 5 6)
|
||||||
|
for n_gpu in ${Token2wav_num_gpus[@]}; do
|
||||||
|
echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
|
||||||
|
for concurrent_task in ${concurrent_tasks[@]}; do
|
||||||
|
num_instances=$((per_gpu_instances * n_gpu))
|
||||||
|
for i in $(seq 1 $num_instances); do
|
||||||
|
port=$(($i + 18000))
|
||||||
|
python3 client_grpc.py \
|
||||||
|
--server-addr localhost \
|
||||||
|
--server-port $port \
|
||||||
|
--model-name cosyvoice2_dit \
|
||||||
|
--num-tasks $concurrent_task \
|
||||||
|
--mode $mode \
|
||||||
|
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||||
|
--log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
|
||||||
|
done
|
||||||
|
wait
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
@@ -15,11 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
|
||||||
import csv
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
122
runtime/triton_trtllm/streaming_inference.py
Normal file
122
runtime/triton_trtllm/streaming_inference.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
import torchaudio
|
||||||
|
import time
|
||||||
|
from token2wav_dit import CosyVoice2_Token2Wav
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
|
||||||
|
prompt_speech_tokens_list, prompt_text_list = [], []
|
||||||
|
for item in batch:
|
||||||
|
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
|
||||||
|
audio = torch.from_numpy(item['prompt_audio']['array']).float()
|
||||||
|
prompt_audios_list.append(audio)
|
||||||
|
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
|
||||||
|
ids.append(item['id'])
|
||||||
|
prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
|
||||||
|
prompt_text_list.append(item['prompt_text'])
|
||||||
|
|
||||||
|
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--enable-trt", action="store_true")
|
||||||
|
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=1)
|
||||||
|
parser.add_argument("--output-dir", type=str, default="generated_wavs")
|
||||||
|
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
|
||||||
|
parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
|
||||||
|
parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
|
dataset_name = args.dataset_name
|
||||||
|
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
|
||||||
|
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
|
||||||
|
|
||||||
|
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
|
||||||
|
|
||||||
|
CHUNK_SIZE = 25
|
||||||
|
token_frame_rate = 25
|
||||||
|
OVERLAP_SIZE = 0
|
||||||
|
|
||||||
|
warmup_times = 3
|
||||||
|
for _ in range(warmup_times):
|
||||||
|
start_time = time.time()
|
||||||
|
total_forward_count = 0
|
||||||
|
for batch in data_loader:
|
||||||
|
tts_speech_list = []
|
||||||
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
|
||||||
|
|
||||||
|
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
|
||||||
|
|
||||||
|
assert prompt_audio_sample_rate == 16000
|
||||||
|
|
||||||
|
prompt_text = prompt_text_list[0]
|
||||||
|
prompt_speech_tokens = prompt_speech_tokens_list[0]
|
||||||
|
|
||||||
|
semantic_token_ids_arr, token_offset = [], 0
|
||||||
|
flow_prompt_speech_token_len = len(prompt_speech_tokens)
|
||||||
|
|
||||||
|
buffer = generated_speech_tokens
|
||||||
|
output_wavs = []
|
||||||
|
chunk_index = 0
|
||||||
|
while True:
|
||||||
|
if args.strategy == "equal":
|
||||||
|
this_chunk_size = CHUNK_SIZE
|
||||||
|
elif args.strategy == "exponential":
|
||||||
|
this_chunk_size = token_frame_rate * (2 ** chunk_index)
|
||||||
|
|
||||||
|
if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
|
||||||
|
wavs = token2wav_model.forward_streaming(
|
||||||
|
buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
|
||||||
|
False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
|
||||||
|
prompt_audio_sample_rate=prompt_audio_sample_rate
|
||||||
|
)
|
||||||
|
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
|
||||||
|
|
||||||
|
output_wavs.append(wavs)
|
||||||
|
total_forward_count += 1
|
||||||
|
chunk_index += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
wavs = token2wav_model.forward_streaming(
|
||||||
|
buffer, True, request_id=id, speaker_id=f"{id}",
|
||||||
|
prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
|
||||||
|
)
|
||||||
|
output_wavs.append(wavs)
|
||||||
|
total_forward_count += 1
|
||||||
|
# chunk_index += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
for i, wav in enumerate(output_wavs):
|
||||||
|
output_wavs[i] = wav.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
audios = output_wavs
|
||||||
|
reconstructed_audio = np.concatenate(audios)
|
||||||
|
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
if _ == 0:
|
||||||
|
token2wav_model.speaker_cache = {}
|
||||||
|
print(f"Warmup time: {end_time - start_time} seconds")
|
||||||
|
print("clear speaker cache")
|
||||||
|
elif _ == 1:
|
||||||
|
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
|
||||||
|
else:
|
||||||
|
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
|
||||||
|
print(f"Total flow matching forward calls: {total_forward_count}")
|
||||||
335
runtime/triton_trtllm/token2wav.py
Normal file
335
runtime/triton_trtllm/token2wav.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Example Usage
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python3 token2wav.py --enable-trt || exit 1
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
|
||||||
|
from flashcosyvoice.modules.hifigan import HiFTGenerator
|
||||||
|
from flashcosyvoice.utils.audio import mel_spectrogram
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
import onnxruntime
|
||||||
|
import s3tokenizer
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from datasets import load_dataset
|
||||||
|
import torchaudio
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||||
|
import tensorrt as trt
|
||||||
|
logging.info("Converting onnx to trt...")
|
||||||
|
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||||
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
|
builder = trt.Builder(logger)
|
||||||
|
network = builder.create_network(network_flags)
|
||||||
|
parser = trt.OnnxParser(network, logger)
|
||||||
|
config = builder.create_builder_config()
|
||||||
|
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||||
|
if fp16:
|
||||||
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
|
profile = builder.create_optimization_profile()
|
||||||
|
# load onnx model
|
||||||
|
with open(onnx_model, "rb") as f:
|
||||||
|
if not parser.parse(f.read()):
|
||||||
|
for error in range(parser.num_errors):
|
||||||
|
print(parser.get_error(error))
|
||||||
|
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||||
|
# set input shapes
|
||||||
|
for i in range(len(trt_kwargs['input_names'])):
|
||||||
|
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||||
|
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
||||||
|
# set input and output data type
|
||||||
|
for i in range(network.num_inputs):
|
||||||
|
input_tensor = network.get_input(i)
|
||||||
|
input_tensor.dtype = tensor_dtype
|
||||||
|
for i in range(network.num_outputs):
|
||||||
|
output_tensor = network.get_output(i)
|
||||||
|
output_tensor.dtype = tensor_dtype
|
||||||
|
config.add_optimization_profile(profile)
|
||||||
|
engine_bytes = builder.build_serialized_network(network, config)
|
||||||
|
# save trt engine
|
||||||
|
with open(trt_model, "wb") as f:
|
||||||
|
f.write(engine_bytes)
|
||||||
|
logging.info("Succesfully convert onnx to trt...")
|
||||||
|
|
||||||
|
|
||||||
|
class TrtContextWrapper:
|
||||||
|
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||||
|
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||||
|
self.trt_engine = trt_engine
|
||||||
|
self.device = device
|
||||||
|
for _ in range(trt_concurrent):
|
||||||
|
trt_context = trt_engine.create_execution_context()
|
||||||
|
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
|
||||||
|
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||||
|
self.trt_context_pool.put([trt_context, trt_stream])
|
||||||
|
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||||
|
|
||||||
|
def acquire_estimator(self):
|
||||||
|
return self.trt_context_pool.get(), self.trt_engine
|
||||||
|
|
||||||
|
def release_estimator(self, context, stream):
|
||||||
|
self.trt_context_pool.put([context, stream])
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||||
|
def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
|
||||||
|
super().__init__()
|
||||||
|
self.device_id = device_id
|
||||||
|
self.device = f"cuda:{device_id}"
|
||||||
|
|
||||||
|
self.flow = CausalMaskedDiffWithXvec()
|
||||||
|
self.flow.half()
|
||||||
|
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
|
||||||
|
self.flow.to(self.device).eval()
|
||||||
|
|
||||||
|
self.hift = HiFTGenerator()
|
||||||
|
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
|
||||||
|
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||||
|
self.hift.to(self.device).eval()
|
||||||
|
|
||||||
|
option = onnxruntime.SessionOptions()
|
||||||
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
option.intra_op_num_threads = 1
|
||||||
|
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
|
||||||
|
|
||||||
|
gpu = "l20"
|
||||||
|
if enable_trt:
|
||||||
|
self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
|
||||||
|
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
|
||||||
|
1,
|
||||||
|
True)
|
||||||
|
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||||
|
f'{model_dir}/campplus.onnx',
|
||||||
|
1,
|
||||||
|
False)
|
||||||
|
|
||||||
|
def forward_spk_embedding(self, spk_feat):
|
||||||
|
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||||
|
return self.spk_model.run(
|
||||||
|
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||||
|
)[0].flatten().tolist()
|
||||||
|
else:
|
||||||
|
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||||
|
# NOTE need to synchronize when switching stream
|
||||||
|
with torch.cuda.device(self.device_id):
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||||
|
batch_size = spk_feat.size(0)
|
||||||
|
|
||||||
|
with stream:
|
||||||
|
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||||
|
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||||
|
|
||||||
|
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||||
|
output_tensor.contiguous().data_ptr()]
|
||||||
|
for i, j in enumerate(data_ptrs):
|
||||||
|
|
||||||
|
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||||
|
# run trt engine
|
||||||
|
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.spk_model.release_estimator(spk_model, stream)
|
||||||
|
|
||||||
|
return output_tensor.cpu().numpy().flatten().tolist()
|
||||||
|
|
||||||
|
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||||
|
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||||
|
trt_kwargs = self.get_spk_trt_kwargs()
|
||||||
|
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
|
||||||
|
import tensorrt as trt
|
||||||
|
with open(spk_model, 'rb') as f:
|
||||||
|
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
|
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||||
|
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||||
|
|
||||||
|
def get_spk_trt_kwargs(self):
|
||||||
|
min_shape = [(1, 4, 80)]
|
||||||
|
opt_shape = [(1, 500, 80)]
|
||||||
|
max_shape = [(1, 3000, 80)]
|
||||||
|
input_names = ["input"]
|
||||||
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
|
|
||||||
|
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
|
||||||
|
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||||
|
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||||
|
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
|
||||||
|
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
|
||||||
|
del self.flow.decoder.estimator
|
||||||
|
import tensorrt as trt
|
||||||
|
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||||
|
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
|
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||||
|
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||||
|
|
||||||
|
def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
|
||||||
|
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
|
||||||
|
opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
|
||||||
|
max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
|
||||||
|
(max_batch_size * 2, 80)]
|
||||||
|
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
|
||||||
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
|
|
||||||
|
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
|
||||||
|
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
|
||||||
|
for audio in prompt_audios_list:
|
||||||
|
assert len(audio.shape) == 1
|
||||||
|
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
|
||||||
|
prompt_speech_mels_list.append(log_mel)
|
||||||
|
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
|
||||||
|
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
|
||||||
|
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
|
||||||
|
)
|
||||||
|
for i in range(len(prompt_speech_tokens)):
|
||||||
|
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
|
||||||
|
prompt_speech_tokens_list.append(speech_tokens_i)
|
||||||
|
return prompt_speech_tokens_list
|
||||||
|
|
||||||
|
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
|
||||||
|
spk_emb_for_flow = []
|
||||||
|
for audio in prompt_audios_list:
|
||||||
|
assert len(audio.shape) == 1
|
||||||
|
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
|
||||||
|
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
|
||||||
|
spk_emb = self.forward_spk_embedding(spk_feat)
|
||||||
|
|
||||||
|
spk_emb_for_flow.append(spk_emb)
|
||||||
|
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
||||||
|
return spk_emb_for_flow
|
||||||
|
|
||||||
|
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
|
||||||
|
prompt_mels_for_flow = []
|
||||||
|
prompt_mels_lens_for_flow = []
|
||||||
|
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
|
||||||
|
assert len(audio.shape) == 1
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
if sample_rate != 24000:
|
||||||
|
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
|
||||||
|
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
|
||||||
|
mel_len = mel.shape[0]
|
||||||
|
prompt_mels_for_flow.append(mel)
|
||||||
|
prompt_mels_lens_for_flow.append(mel_len)
|
||||||
|
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
|
||||||
|
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
|
||||||
|
return prompt_mels_for_flow, prompt_mels_lens_for_flow
|
||||||
|
|
||||||
|
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
|
||||||
|
prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
|
||||||
|
batch_size = prompt_mels_for_flow.shape[0]
|
||||||
|
flow_inputs = []
|
||||||
|
flow_inputs_lens = []
|
||||||
|
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
|
||||||
|
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
|
||||||
|
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
|
||||||
|
|
||||||
|
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
|
||||||
|
flow_inputs_lens = torch.tensor(flow_inputs_lens)
|
||||||
|
|
||||||
|
with torch.amp.autocast(self.device, dtype=torch.float16):
|
||||||
|
generated_mels, generated_mels_lens = self.flow(
|
||||||
|
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
|
||||||
|
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
|
||||||
|
streaming=False, finalize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return generated_mels, generated_mels_lens
|
||||||
|
|
||||||
|
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
|
||||||
|
batch_size = generated_mels.shape[0]
|
||||||
|
generated_wavs = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
|
||||||
|
wav, _ = self.hift(speech_feat=mel)
|
||||||
|
generated_wavs.append(wav)
|
||||||
|
return generated_wavs
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(
|
||||||
|
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||||
|
):
|
||||||
|
# assert all item in prompt_audios_sample_rate is 16000
|
||||||
|
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
|
||||||
|
|
||||||
|
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
|
||||||
|
|
||||||
|
generated_mels, generated_mels_lens = self.forward_flow(
|
||||||
|
prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
|
||||||
|
|
||||||
|
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
|
||||||
|
|
||||||
|
return generated_wavs
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
|
||||||
|
for _, item in enumerate(batch):
|
||||||
|
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
|
||||||
|
audio = torch.from_numpy(item['prompt_audio']['array']).float()
|
||||||
|
prompt_audios_list.append(audio)
|
||||||
|
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
|
||||||
|
ids.append(item['id'])
|
||||||
|
|
||||||
|
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--enable-trt", action="store_true")
|
||||||
|
parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=4)
|
||||||
|
parser.add_argument("--output-dir", type=str, default="generated_wavs")
|
||||||
|
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
|
||||||
|
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
|
||||||
|
# mkdir output_dir if not exists
|
||||||
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
dataset_name = "yuekai/seed_tts_cosy2"
|
||||||
|
|
||||||
|
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
|
||||||
|
|
||||||
|
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
|
||||||
|
|
||||||
|
for _ in range(args.warmup):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for batch in data_loader:
|
||||||
|
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
|
||||||
|
|
||||||
|
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
|
||||||
|
|
||||||
|
for id, wav in zip(ids, generated_wavs):
|
||||||
|
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
epoch_time = end_time - start_time
|
||||||
|
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
|
||||||
1
runtime/triton_trtllm/token2wav_dit.py
Symbolic link
1
runtime/triton_trtllm/token2wav_dit.py
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
model_repo/token2wav_dit/1/token2wav_dit.py
|
||||||
Reference in New Issue
Block a user