Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)

Compare commits (142 commits)
f08872a82f f2ddcbe7f9 0d990d6074 c93d3dda01 84e41729ea f26cde56df 66b80dbccb 1822c5c908
1dcc59676f 7fdd80dc64 f97d50d559 652132ebaa 1ceb2b7b1e 55e3e370a0 46788c7379 881177287c
f88a14e41d dac6566fc3 cc91e40db8 ab7f1f4a86 e15222b17c cfa1c115b2 dd5cdb6ebf 2d7ef0b719
ba5db602a9 5b94675f62 4c19646b9a 63a06227d1 3b44913782 055f64d002 4d7295a9a7 8524c81acd
a14e063ead 2db78e7058 7538c6a73d 823ae2c60d 59cb2bf16c 80bebb1978 bc34459bb8 9f27b42cd9
a7d6e2251a 7baefaf0f2 ff0d05c380 f5816b4e51 8b54619760 2abd42220e 2d6bb9bd80 0b80c0746a
e98b828f33 4d4c787be0 781a49acb4 9476a063b3 3426ceb70f a460960ade f51f5c5c6a f11ba4024c
089343ab0a 0c50894d49 95d56cba64 095f7bad55 a6eb2c56da ca3b054a52 b02d7e61f7 6b6a5a7bd1
5640545406 5bc4b23f02 ebef63066f 3298d6f3e3 f21c4764ec 927addadd8 a051a09ba4 0c65d3c7ab
56d9876037 b35ece675b 59f02cb85d b4dd67a8af bfa835a74b 622a3a19b0 d985100326 6816fc6a6f
e8bf717333 fa2781405f cd26dd1932 6e01309e01 1fc8435146 a224be6117 33aee03ed5 8811e9f33a
807bb6ee0b aceede59ba 7cbd490253 a019a2504e f186ec3338 988d395162 4d60ff6abc be005c825f
79116ac32e 31a0adc73d 482464ea27 444b7ff5df b207c60885 0b357ba25d 0867ebcb8c 52556a6de9
66ef5a097b cc1991870b 8ded65e611 6971536358 86e7c2d731 8a4309d89c ad257b06e3 633b991290
e04699c6da 73d261dd48 b7ec6c4678 f76f5abcc1 6b5eef62cc dc96e4c984 70991d7327 8c96081f94
dd2d926147 da41f6175b e3c2400abb a976519ada cf615011ce 9ddb9e4a83 0a496c18f7 05bdf4c769
1850e2a56e 47e4137651 0bc48c1180 62d082634e 07cbc51cd1 d1c354eac7 1b8d194b67 b44f121102
dc196df940 178da09993 11515d0d5a 5427c274e3 3387f07266 b048a2d6db
.github/workflows/lint.yml (vendored, 2 changed lines)

```
@@ -52,5 +52,5 @@ jobs:
set -eux
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
flake8 --version
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
if [ $? != 0 ]; then exit 1; fi
```
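For contributors, the same check can be reproduced locally before pushing. A minimal sketch that assumes only the pinned versions and the updated ignore list (the one adding F722) shown above:

```sh
# mirrors the CI lint step above; run from the repository root
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
```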
README.md (198 changed lines)

````
@@ -1,50 +1,52 @@
[](https://github.com/Akshay090/svg-banners)

## 👉🏻 CosyVoice 👈🏻

**CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)

**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)

**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)

## Highlight🔥

**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
### Multilingual
- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
- **Crosslingual & Mixlingual**: Supports zero-shot voice cloning for cross-lingual and code-switching scenarios.
### Ultra-Low Latency
- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
### High Accuracy
- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
### Strong Stability
- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
### Natural Experience
- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
### Key Features
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian), 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.) and meanwhile supports both multi-lingual/cross-lingual zero-shot voice cloning.
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
- **Bi-Streaming**: Supports both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.

## Roadmap

- [x] 2025/12

    - [x] release Fun-CosyVoice3-0.5B-2512 base model, rl model and its training/inference script
    - [x] release Fun-CosyVoice3-0.5B modelscope gradio space

- [x] 2025/08

    - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support

- [x] 2025/07

    - [x] release cosyvoice 3.0 eval set
    - [x] release Fun-CosyVoice 3.0 eval set

- [x] 2025/05

    - [x] add cosyvoice 2.0 vllm support
    - [x] add CosyVoice2-0.5B vllm support

- [x] 2024/12

    - [x] 25hz cosyvoice 2.0 released
    - [x] 25hz CosyVoice2-0.5B released

- [x] 2024/09

    - [x] 25hz cosyvoice base model
    - [x] 25hz cosyvoice voice conversion model
    - [x] 25hz CosyVoice-300M base model
    - [x] 25hz CosyVoice-300M voice conversion function

- [x] 2024/08
````
````
@@ -57,6 +59,27 @@
- [x] WeTextProcessing support when ttsfrd is not available
- [x] Fastapi server and client

## Evaluation

| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |

## Install
````
````
@@ -87,26 +110,26 @@

### Model download

We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.

``` python
# SDK模型下载
# modelscope SDK model download
from modelscope import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```

``` sh
# git模型下载,请确保已安装git lfs
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
# for oversea users, huggingface SDK model download
from huggingface_hub import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```

Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
````
````
@@ -122,94 +145,28 @@ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl

### Basic Usage

We strongly recommend using `CosyVoice2-0.5B` for better performance.
Follow the code below for detailed usage of each model.

``` python
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio
We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
Follow the code in `example.py` for detailed usage of each model.
```sh
python example.py
```
````
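For quick orientation, here is a minimal zero-shot sketch assembled from the snippets in this change set. It assumes the `AutoModel` dispatcher and the file-path style `prompt_wav` argument shown in `cosyvoice/cli/cosyvoice.py` further down, so treat it as an illustration rather than a copy of the shipped `example.py`:

```python
# Hedged sketch only; example.py in the repository is the reference.
import sys
sys.path.append('third_party/Matcha-TTS')   # as in the previous README example
import torchaudio
from cosyvoice.cli.cosyvoice import AutoModel

# AutoModel picks CosyVoice/CosyVoice2/CosyVoice3 based on which yaml is present in model_dir.
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')

# prompt_wav is now a file path; the frontend loads and resamples it itself.
# CosyVoice3 logs a warning if '<|endofprompt|>' is missing from the prompt/tts text.
for i, j in enumerate(cosyvoice.inference_zero_shot(
        '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
        '希望你以后能够做的比我还好呦。',
        './asset/zero_shot_prompt.wav',
        stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```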
````

#### CosyVoice2 Usage
```python
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=False, fp16=False)
#### vLLM Usage
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
Older vLLM versions (<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.

# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# save zero_shot spk for future usage
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice.save_spkinfo()

# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
    torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# bistream usage, you can use generator as input, this is useful when using text llm model as input
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
def text_generator():
    yield '收到好友从远方寄来的生日礼物,'
    yield '那份意外的惊喜与深深的祝福'
    yield '让我心中充满了甜蜜的快乐,'
    yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```

#### CosyVoice2 vllm Usage
If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vLLM versions do not support CosyVoice2 inference.

Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torch==2.7.0`. You can create a new env in case your hardware does not support vllm and your old env gets corrupted.
Notice that `vllm` has a lot of specific requirements. You can create a new env in case your hardware does not support vllm and your old env gets corrupted.

``` sh
conda create -n cosyvoice_vllm --clone cosyvoice
conda activate cosyvoice_vllm
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm==0.9.0
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm>=0.11.0
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
python vllm_example.py
```
````
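As a companion to `vllm_example.py`, here is a hedged sketch of what enabling vLLM looks like from Python. It is based only on the `load_vllm` flag and the `load_vllm('{model_dir}/vllm')` call visible in `cosyvoice/cli/cosyvoice.py` below; the real entry point in `vllm_example.py` may differ:

```python
# Sketch under stated assumptions: a CUDA device is available (the constructor
# silently turns load_vllm off on CPU-only hosts) and model_dir contains a vllm export.
import sys
sys.path.append('third_party/Matcha-TTS')
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, load_vllm=True, fp16=False)
for i, j in enumerate(cosyvoice.inference_zero_shot(
        '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
        '希望你以后能够做的比我还好呦。',
        './asset/zero_shot_prompt.wav',
        stream=False)):
    torchaudio.save('vllm_zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```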
````

#### CosyVoice Usage
```python
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
# sft usage
print(cosyvoice.list_available_spks())
# change stream=True for chunk stream inference
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# cross_lingual usage
prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# vc usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
    torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```

#### Start web demo

You can use our web demo page to get familiar with CosyVoice quickly.
````

````
@@ -223,7 +180,7 @@ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M

#### Advanced Usage

For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
For advanced users, we have provided training and inference scripts in `examples/libritts`.

#### Build for deployment
````
````
@@ -242,6 +199,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```

#### Using Nvidia TensorRT-LLM for deployment

Using TensorRT-LLM to accelerate the cosyvoice2 llm can give about 4x acceleration compared with the huggingface transformers implementation.
To quickly get started:

``` sh
cd runtime/triton_trtllm
docker compose up -d
```
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)

## Discussion & Communication

You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
````

Binary file not shown. (Before: 94 KiB → After: 120 KiB)
Binary file not shown.
```
@@ -23,7 +23,7 @@ import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging


@@ -57,15 +57,9 @@ def main():
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')
    model = AutoModel(model_dir=args.model_dir)

    if not isinstance(model, CosyVoice2):
    if model.__class__.__name__ == 'CosyVoice':
        # 1. export llm text_encoder
        llm_text_encoder = model.model.llm.text_encoder
        script = get_optimized_script(llm_text_encoder)

@@ -89,14 +83,16 @@ def main():
        script = get_optimized_script(flow_encoder.half())
        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
        logging.info('successfully export flow_encoder')
    else:
        # 3. export flow encoder
    elif model.__class__.__name__ == 'CosyVoice2':
        # 1. export flow encoder
        flow_encoder = model.model.flow.encoder
        script = get_optimized_script(flow_encoder)
        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
        script = get_optimized_script(flow_encoder.half())
        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
        logging.info('successfully export flow_encoder')
    else:
        raise ValueError('unsupported model type')


if __name__ == '__main__':


@@ -27,7 +27,7 @@ from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging


@@ -58,13 +58,7 @@ def main():
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    try:
        model = CosyVoice(args.model_dir)
    except Exception:
        try:
            model = CosyVoice2(args.model_dir)
        except Exception:
            raise TypeError('no valid model_type!')
    model = AutoModel(model_dir=args.model_dir)

    # 1. export flow decoder estimator
    estimator = model.model.flow.decoder.estimator
```
```
@@ -1,126 +0,0 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.dataset.dataset import Dataset


def get_args():
    parser = argparse.ArgumentParser(description='inference with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
    parser.add_argument('--tts_text', required=True, help='tts input file')
    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
    parser.add_argument('--llm_model', required=True, help='llm model file')
    parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')
    parser.add_argument('--result_dir', required=True, help='asr result file')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Init cosyvoice models from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    try:
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
    except Exception:
        try:
            with open(args.config, 'r') as f:
                configs = load_hyperpyyaml(f)
            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
        except Exception:
            raise TypeError('no valid model_type!')

    model.load(args.llm_model, args.flow_model, args.hifigan_model)

    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    sample_rate = configs['sample_rate']
    del configs
    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    f = open(fn, 'w')
    with torch.no_grad():
        for _, batch in tqdm(enumerate(test_data_loader)):
            utts = batch["utts"]
            assert len(utts) == 1, "inference mode only support batchsize 1"
            text_token = batch["text_token"].to(device)
            text_token_len = batch["text_token_len"].to(device)
            tts_index = batch["tts_index"]
            tts_text_token = batch["tts_text_token"].to(device)
            tts_text_token_len = batch["tts_text_token_len"].to(device)
            speech_token = batch["speech_token"].to(device)
            speech_token_len = batch["speech_token_len"].to(device)
            speech_feat = batch["speech_feat"].to(device)
            speech_feat_len = batch["speech_feat_len"].to(device)
            utt_embedding = batch["utt_embedding"].to(device)
            spk_embedding = batch["spk_embedding"].to(device)
            if args.mode == 'sft':
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
            else:
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
            tts_speeches = []
            for model_output in model.tts(**model_input):
                tts_speeches.append(model_output['tts_speech'])
            tts_speeches = torch.concat(tts_speeches, dim=1)
            tts_key = '{}_{}'.format(utts[0], tts_index[0])
            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
            f.write('{} {}\n'.format(tts_key, tts_fn))
            f.flush()
    f.close()
    logging.info('Result wav.scp saved in {}'.format(fn))


if __name__ == '__main__':
    logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
    main()
```
```
@@ -49,6 +49,7 @@ def get_args():
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
    parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--tensorboard_dir',

@@ -96,6 +97,7 @@ def get_args():
@record
def main():
    args = get_args()
    os.environ['onnx_path'] = args.onnx_path
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # gan train has some special initialization logic

@@ -104,12 +106,10 @@ def main():
    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
    if gan is True:
        override_dict.pop('hift')
    try:
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
    except Exception:
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides=override_dict)
    if args.qwen_pretrain_path is not None:
        override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f, overrides=override_dict)
    if gan is True:
        configs['train_conf'] = configs['train_conf_gan']
    configs['train_conf'].update(vars(args))
```
```
@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type

@@ -27,7 +27,6 @@ from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        self.instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):

@@ -37,7 +36,7 @@ class CosyVoice:
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f)
        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),

@@ -67,9 +66,9 @@ class CosyVoice:
        spks = list(self.frontend.spk2info.keys())
        return spks

    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
    def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
        assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
        del model_input['text']
        del model_input['text_len']
        self.frontend.spk2info[zero_shot_spk_id] = model_input

@@ -89,12 +88,14 @@ class CosyVoice:
            yield model_output
            start_time = time.time()

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
    def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
            logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

@@ -103,9 +104,9 @@ class CosyVoice:
                yield model_output
                start_time = time.time()

    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
    def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
            model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

@@ -115,9 +116,7 @@ class CosyVoice:
                start_time = time.time()

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
        if self.instruct is False:
            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
        assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)

@@ -129,8 +128,8 @@ class CosyVoice:
            yield model_output
            start_time = time.time()

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
    def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
        model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
        start_time = time.time()
        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate

@@ -142,7 +141,6 @@ class CosyVoice:
class CosyVoice2(CosyVoice):

    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
        self.instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):

@@ -160,9 +158,9 @@ class CosyVoice2(CosyVoice):
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
            load_jit, load_trt, load_vllm, fp16 = False, False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),

@@ -178,13 +176,9 @@ class CosyVoice2(CosyVoice):
                                self.fp16)
        del configs

    def inference_instruct(self, *args, **kwargs):
        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')

    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
    def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

@@ -192,3 +186,55 @@ class CosyVoice2(CosyVoice):
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()


class CosyVoice3(CosyVoice2):

    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v3.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
            load_trt, fp16 = False, False
            logging.warning('no cuda device, set load_trt/fp16 to False')
        self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_vllm:
            self.model.load_vllm('{}/vllm'.format(model_dir))
        if load_trt:
            if self.fp16 is True:
                logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
        del configs


def AutoModel(**kwargs):
    if not os.path.exists(kwargs['model_dir']):
        kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
    if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
        return CosyVoice(**kwargs)
    elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
        return CosyVoice2(**kwargs)
    elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
        return CosyVoice3(**kwargs)
    else:
        raise TypeError('No valid model type found!')
```
```
@@ -20,19 +20,10 @@ import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    print("failed to import ttsfrd, use wetext instead")
    from wetext import Normalizer as ZhNormalizer
    from wetext import Normalizer as EnNormalizer
    use_ttsfrd = False
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


@@ -56,21 +47,33 @@ class CosyVoiceFrontEnd:
                                                            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                       "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
            self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
        else:
            self.spk2info = {}
        self.allowed_special = allowed_special
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
        self.inflect_parser = inflect.engine()
        # NOTE compatible when no text frontend tool is avaliable
        try:
            import ttsfrd
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyinvg')
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False)
            self.en_tn_model = EnNormalizer()
            self.inflect_parser = inflect.engine()
            self.text_frontend = 'ttsfrd'
            logging.info('use ttsfrd frontend')
        except:
            try:
                from wetext import Normalizer as ZhNormalizer
                from wetext import Normalizer as EnNormalizer
                self.zh_tn_model = ZhNormalizer(remove_erhua=False)
                self.en_tn_model = EnNormalizer()
                self.text_frontend = 'wetext'
                logging.info('use wetext frontend')
            except:
                self.text_frontend = ''
                logging.info('no frontend is avaliable')

    def _extract_text_token(self, text):
        if isinstance(text, Generator):

@@ -89,7 +92,8 @@ class CosyVoiceFrontEnd:
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

    def _extract_speech_token(self, speech):
    def _extract_speech_token(self, prompt_wav):
        speech = load_wav(prompt_wav, 16000)
        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,

@@ -101,7 +105,8 @@ class CosyVoiceFrontEnd:
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
    def _extract_spk_embedding(self, prompt_wav):
        speech = load_wav(prompt_wav, 16000)
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,

@@ -112,7 +117,8 @@ class CosyVoiceFrontEnd:
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
    def _extract_speech_feat(self, prompt_wav):
        speech = load_wav(prompt_wav, 24000)
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)

@@ -122,15 +128,19 @@ class CosyVoiceFrontEnd:
        if isinstance(text, Generator):
            logging.info('get tts_text generator, will skip text_normalize!')
            return [text]
        # NOTE skip text_frontend when ssml symbol in text
        if '<|' in text and '|>' in text:
            text_frontend = False
        if text_frontend is False or text == '':
            return [text] if split is True else text
        text = text.strip()
        if self.use_ttsfrd:
        if self.text_frontend == 'ttsfrd':
            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
            text = ''.join(texts)
        else:
            if contains_chinese(text):
                text = self.zh_tn_model.normalize(text)
                if self.text_frontend == 'wetext':
                    text = self.zh_tn_model.normalize(text)
                text = text.replace("\n", "")
                text = replace_blank(text)
                text = replace_corner_mark(text)

@@ -141,7 +151,8 @@ class CosyVoiceFrontEnd:
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))
            else:
                text = self.en_tn_model.normalize(text)
                if self.text_frontend == 'wetext':
                    text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                             token_min_n=60, merge_len=20, comma_split=False))

@@ -154,32 +165,31 @@ class CosyVoiceFrontEnd:
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
    def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        if zero_shot_spk_id == '':
            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
            speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
            if resample_rate == 24000:
                # cosyvoice2, force speech_feat % speech_token = 2
                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
            embedding = self._extract_spk_embedding(prompt_speech_16k)
            embedding = self._extract_spk_embedding(prompt_wav)
            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                           'llm_embedding': embedding, 'flow_embedding': embedding}
        else:
            model_input = self.spk2info[zero_shot_spk_id]
            model_input = {**self.spk2info[zero_shot_spk_id]}
        model_input['text'] = tts_text_token
        model_input['text_len'] = tts_text_token_len
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
    def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']

@@ -191,22 +201,21 @@ class CosyVoiceFrontEnd:
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
    def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
        model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
    def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_wav)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
```
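A practical consequence of the frontend changes above: the `prompt_speech_16k` tensor arguments become `prompt_wav`, and `_extract_speech_token` / `_extract_spk_embedding` / `_extract_speech_feat` now call `load_wav` themselves, so callers pass a wav file path instead of pre-loading and resampling audio. A hedged before/after sketch; the model dir, asset path, and texts follow the README examples and the surrounding setup is illustrative:

```python
# Illustrative sketch of the calling-convention change; not taken verbatim from example.py.
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import AutoModel

cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
tts_text = '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。'
prompt_text = '希望你以后能够做的比我还好呦。'

# Old style (before this change): the caller loaded a 16 kHz tensor via load_wav
# and passed it as prompt_speech_16k.
# New style: pass the wav path; the frontend loads it at 16 kHz for tokens/embeddings
# and at 24 kHz for mel features, so no manual resampling is needed.
for i, j in enumerate(cosyvoice.inference_zero_shot(tts_text, prompt_text,
                                                    './asset/zero_shot_prompt.wav', stream=False)):
    print(i, j['tts_speech'].shape)
```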
@@ -38,9 +38,6 @@ class CosyVoiceModel:
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
||||
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
||||
self.token_overlap_len = 20
|
||||
@@ -63,14 +60,15 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict = {}
|
||||
self.flow_cache_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.silent_tokens = []
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
|
||||
self.llm.to(self.device).eval()
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
@@ -101,26 +99,33 @@ class CosyVoiceModel:
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
cur_silent_token_num, max_silent_token_num = 0, 5
|
||||
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
|
||||
if isinstance(text, Generator):
|
||||
assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
|
||||
for i in self.llm.inference_bistream(text=text,
|
||||
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
|
||||
token_generator = self.llm.inference_bistream(text=text,
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device))
|
||||
else:
|
||||
token_generator = self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
else:
|
||||
for i in self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device),
|
||||
uuid=uuid):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
embedding=llm_embedding.to(self.device),
|
||||
uuid=uuid)
|
||||
for i in token_generator:
|
||||
if i in self.silent_tokens:
|
||||
cur_silent_token_num += 1
|
||||
if cur_silent_token_num > max_silent_token_num:
|
||||
continue
|
||||
else:
|
||||
cur_silent_token_num = 0
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
self.llm_end_dict[uuid] = True
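The loop above caps runs of silent/breath tokens: once more than max_silent_token_num consecutive silent tokens arrive from the LLM, the extras are dropped before being appended to the session's token list. The same filter as a standalone sketch (hypothetical helper, equivalent logic):
def suppress_silent_runs(tokens, silent_tokens, max_run=5):
    """Keep at most max_run consecutive silent tokens, dropping the rest of the run."""
    kept, run = [], 0
    for tok in tokens:
        if tok in silent_tokens:
            run += 1
            if run > max_run:
                continue  # drop excess silence
        else:
            run = 0
        kept.append(tok)
    return kept

# e.g. suppress_silent_runs([1, 1, 1, 1, 1, 1, 1, 7], silent_tokens={1}) -> [1, 1, 1, 1, 1, 7]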
|
||||
|
||||
def vc_job(self, source_speech_token, uuid):
|
||||
@@ -129,7 +134,7 @@ class CosyVoiceModel:
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
|
||||
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
@@ -249,11 +254,12 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
# NOTE must match training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
# NOTE increase token_hop_len incrementally to avoid duplicate inference
|
||||
self.token_max_hop_len = 4 * self.token_hop_len
|
||||
self.stream_scale_factor = 2
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
|
||||
# hift cache
|
||||
self.mel_cache_len = 8
|
||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||
@@ -266,6 +272,7 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.silent_tokens = []
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
@@ -284,7 +291,7 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
@@ -350,6 +357,7 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
stream=stream,
|
||||
finalize=False)
|
||||
token_offset += this_token_hop_len
|
||||
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
|
||||
break
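Since token_hop_len starts at 25, is multiplied by stream_scale_factor=2 after every emitted chunk, and is capped at token_max_hop_len=100, the streaming chunk sizes grow as 25, 50, 100, 100, ... A quick sanity check of that schedule (illustrative, using the defaults above):
hop, max_hop, scale = 25, 100, 2  # token_hop_len, token_max_hop_len, stream_scale_factor
schedule = []
for _ in range(5):
    schedule.append(hop)             # tokens consumed by this streaming chunk
    hop = min(max_hop, hop * scale)  # grow the hop to amortize later inference calls
print(schedule)  # [25, 50, 100, 100, 100]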
|
||||
@@ -384,3 +392,59 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
|
||||
class CosyVoice3Model(CosyVoice2Model):
|
||||
|
||||
def __init__(self,
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
# NOTE must match training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
# NOTE increase token_hop_len incrementally to avoid duplicate inference
|
||||
self.token_max_hop_len = 4 * self.token_hop_len
|
||||
self.stream_scale_factor = 2
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
|
||||
# rtf and decoding related
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
# FSQ silent and breath token
|
||||
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
streaming=stream,
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append mel cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
self.hift_cache_dict[uuid]['mel'] = tts_mel
|
||||
else:
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
|
||||
if speed != 1.0:
|
||||
assert token_offset == 0 and finalize is True, 'speed change only supports non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
|
||||
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
|
||||
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
|
||||
return tts_speech
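Unlike the CosyVoice2 path, which keeps a fixed-length mel/source cache, the CosyVoice3 token2wav above concatenates every generated mel into the session cache, re-runs hift on the full mel, and then returns only the samples past the stored speech_offset. A stripped-down sketch of that bookkeeping (hypothetical names, simplified):
class SpeechOffsetCache:
    """Return only newly generated samples when the full waveform is re-synthesized each call."""
    def __init__(self):
        self.speech_offset = 0

    def emit_new(self, full_speech):  # full_speech: (1, T) waveform for the whole cached mel
        new_speech = full_speech[:, self.speech_offset:]
        self.speech_offset += new_speech.shape[1]
        return new_speech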
|
||||
|
||||
@@ -145,7 +145,11 @@ def Dataset(data_list_file,
|
||||
shuffle=shuffle,
|
||||
partition=partition)
|
||||
# map partial arg to padding func
|
||||
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
|
||||
for i in range(1, len(data_pipeline)):
|
||||
if data_pipeline[i].func.__name__ == 'compute_fbank' and gan is True:
|
||||
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
|
||||
if data_pipeline[i].func.__name__ == 'padding':
|
||||
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
|
||||
for func in data_pipeline:
|
||||
dataset = Processor(dataset, func, mode=mode)
|
||||
return dataset
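The Dataset change above stops assuming that padding is the last pipeline stage; instead it walks the pipeline, inspects each functools.partial's wrapped function name, and re-binds extra keyword arguments where they apply. A tiny illustration of that pattern (hypothetical stage bodies):
from functools import partial

def compute_fbank(data, token_mel_ratio=2):
    return data

def padding(data, gan=False, dpo=False):
    return data

pipeline = [partial(compute_fbank), partial(padding)]
for i, stage in enumerate(pipeline):
    if stage.func.__name__ == 'compute_fbank':
        pipeline[i] = partial(stage, token_mel_ratio=0)  # e.g. disable token/mel trimming for gan training
    if stage.func.__name__ == 'padding':
        pipeline[i] = partial(stage, gan=True, dpo=False)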
|
||||
|
||||
@@ -16,17 +16,19 @@ import random
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import whisper
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
import pyworld as pw
|
||||
|
||||
from cosyvoice.utils.onnx import embedding_extractor, online_feature
|
||||
|
||||
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
||||
|
||||
|
||||
def parquet_opener(data, mode='train', tts_data={}):
|
||||
def parquet_opener(data, mode='train'):
|
||||
""" Give url or local file, return file descriptor
|
||||
Inplace operation.
|
||||
|
||||
@@ -44,12 +46,8 @@ def parquet_opener(data, mode='train', tts_data={}):
|
||||
df = df.to_pandas()
|
||||
for i in range(len(df)):
|
||||
sample.update(dict(df.loc[i]))
|
||||
if mode == 'train':
|
||||
# NOTE do not return sample directly, must initialize a new dict
|
||||
yield {**sample}
|
||||
else:
|
||||
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
||||
yield {**sample, 'tts_index': index, 'tts_text': text}
|
||||
# NOTE do not return sample directly, must initialize a new dict
|
||||
yield {**sample}
|
||||
except Exception as ex:
|
||||
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
||||
|
||||
@@ -96,9 +94,9 @@ def filter(data,
|
||||
continue
|
||||
if len(sample['text_token']) > token_max_length:
|
||||
continue
|
||||
if len(sample['speech_token']) == 0:
|
||||
if online_feature is False and len(sample['speech_token']) == 0:
|
||||
continue
|
||||
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||
continue
|
||||
if num_frames != 0:
|
||||
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
||||
@@ -159,7 +157,7 @@ def truncate(data, truncate_length=24576, mode='train'):
|
||||
|
||||
def compute_fbank(data,
|
||||
feat_extractor,
|
||||
token_mel_ratio=0,
|
||||
num_frames=-1,
|
||||
mode='train'):
|
||||
""" Extract fbank
|
||||
|
||||
@@ -174,14 +172,28 @@ def compute_fbank(data,
|
||||
assert 'speech' in sample
|
||||
assert 'utt' in sample
|
||||
assert 'text_token' in sample
|
||||
waveform = sample['speech']
|
||||
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||
if token_mel_ratio != 0:
|
||||
# trim to align speech_token and speech_feat
|
||||
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
|
||||
feat = feat[:token_mel_ratio * token_len]
|
||||
sample["speech_token"] = sample["speech_token"][:token_len]
|
||||
sample['speech_feat'] = feat
|
||||
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
|
||||
if num_frames != -1:
|
||||
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
|
||||
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
|
||||
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
|
||||
yield sample
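For online token extraction the waveform is right-padded with zeros so its length becomes an exact multiple of num_frames (the number of samples per 25 Hz token), keeping the token stream aligned with the mel features. The same step in isolation (hypothetical helper name, equivalent logic; the concrete num_frames value depends on the configured sample rate):
import math
import torch
import torch.nn.functional as F

def pad_to_token_multiple(speech: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Right-pad a (1, T) waveform with zeros so T is a multiple of num_frames,
    mirroring the alignment step in compute_fbank above."""
    n_tokens = math.ceil(speech.shape[1] / num_frames)
    return F.pad(speech, (0, n_tokens * num_frames - speech.shape[1]))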
|
||||
|
||||
|
||||
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
|
||||
""" Extract whisper fbank
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
if num_frames != -1:
|
||||
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
|
||||
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
|
||||
yield sample
|
||||
|
||||
|
||||
@@ -220,8 +232,13 @@ def parse_embedding(data, normalize, mode='train'):
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
||||
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
||||
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
|
||||
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||
embedding = embedding_extractor.inference(sample['speech_16k'])
|
||||
sample['spk_embedding'] = sample['utt_embedding'] = embedding
|
||||
else:
|
||||
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
||||
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
||||
if normalize:
|
||||
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
||||
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
||||
@@ -242,6 +259,8 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
||||
for sample in data:
|
||||
assert 'text' in sample
|
||||
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
||||
if 'instruct' in sample:
|
||||
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
|
||||
yield sample
|
||||
|
||||
|
||||
@@ -256,13 +275,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
buf = []
|
||||
yield_size = int(shuffle_size / 2)
|
||||
for sample in data:
|
||||
buf.append(sample)
|
||||
if len(buf) >= shuffle_size:
|
||||
random.shuffle(buf)
|
||||
for x in buf:
|
||||
for x in buf[:yield_size]:
|
||||
yield x
|
||||
buf = []
|
||||
buf = buf[yield_size:]
|
||||
# The sample left over
|
||||
random.shuffle(buf)
|
||||
for x in buf:
|
||||
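The updated shuffle no longer flushes the whole buffer when it fills; it yields only half and keeps the rest, so samples that arrive later can still be mixed with earlier ones. The same behaviour as a standalone generator (sketch, equivalent logic):
import random

def streaming_shuffle(data, shuffle_size=10000):
    buf, yield_size = [], shuffle_size // 2
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf[:yield_size]  # emit half of the buffer
            buf = buf[yield_size:]       # keep the other half for further mixing
    random.shuffle(buf)                  # flush whatever is left at the end
    yield from buf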
@@ -368,65 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||
"""
|
||||
for sample in data:
|
||||
assert isinstance(sample, list)
|
||||
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
||||
dtype=torch.int32)
|
||||
order = torch.argsort(speech_feat_len, descending=True)
|
||||
|
||||
utts = [sample[i]['utt'] for i in order]
|
||||
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
||||
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
||||
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
||||
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
||||
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
||||
speech_token = pad_sequence(speech_token,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
speech_feat = [sample[i]['speech_feat'] for i in order]
|
||||
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
||||
speech_feat = pad_sequence(speech_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
text = [sample[i]['text'] for i in order]
|
||||
order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
|
||||
batch = {}
|
||||
batch['utts'] = [sample[i]['utt'] for i in order]
|
||||
batch['text'] = [sample[i]['text'] for i in order]
|
||||
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
||||
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
||||
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
||||
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
||||
batch = {
|
||||
"utts": utts,
|
||||
"speech": speech,
|
||||
"speech_len": speech_len,
|
||||
"speech_token": speech_token,
|
||||
"speech_token_len": speech_token_len,
|
||||
"speech_feat": speech_feat,
|
||||
"speech_feat_len": speech_feat_len,
|
||||
"text": text,
|
||||
"text_token": text_token,
|
||||
"text_token_len": text_token_len,
|
||||
"utt_embedding": utt_embedding,
|
||||
"spk_embedding": spk_embedding,
|
||||
}
|
||||
batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||
batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
|
||||
speech_feat = [sample[i]['speech_feat'] for i in order]
|
||||
batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
||||
batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
|
||||
batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
||||
batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
||||
if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
|
||||
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
|
||||
batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
|
||||
batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
|
||||
if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
|
||||
whisper_feat = [sample[i]['whisper_feat'] for i in order]
|
||||
batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
|
||||
batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
|
||||
if torch.tensor(['speech_token' in sample[i] for i in order]).all():
|
||||
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
||||
batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
||||
batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
|
||||
if gan is True:
|
||||
# in gan train, we need pitch_feat
|
||||
# in gan training, we need speech/pitch_feat
|
||||
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
||||
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
||||
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
|
||||
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
pitch_feat = pad_sequence(pitch_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
batch["pitch_feat"] = pitch_feat
|
||||
batch["pitch_feat_len"] = pitch_feat_len
|
||||
else:
|
||||
# only gan train needs speech, delete it to save memory
|
||||
del batch["speech"]
|
||||
del batch["speech_len"]
|
||||
batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
|
||||
if dpo is True:
|
||||
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
|
||||
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
|
||||
reject_speech_token = pad_sequence(reject_speech_token,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
batch['reject_speech_token'] = reject_speech_token
|
||||
batch['reject_speech_token_len'] = reject_speech_token_len
|
||||
batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
|
||||
batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
|
||||
if use_spk_embedding is True:
|
||||
batch["embedding"] = batch["spk_embedding"]
|
||||
else:
|
||||
|
||||
cosyvoice/flow/DiT/dit.py (new file, 176 lines)
@@ -0,0 +1,176 @@
|
||||
|
||||
"""
|
||||
ein notation:
|
||||
b - batch
|
||||
n - sequence
|
||||
nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
from cosyvoice.flow.DiT.modules import (
|
||||
TimestepEmbedding,
|
||||
ConvNeXtV2Block,
|
||||
CausalConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
|
||||
|
||||
# Text embedding
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(
|
||||
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||
)
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||
text_pos_embed = self.freqs_cis[pos_idx]
|
||||
text = text + text_pos_embed
|
||||
|
||||
# convnextv2 blocks
|
||||
text = self.text_blocks(text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
# noised input audio and context mixing embedding
|
||||
|
||||
|
||||
class InputEmbedding(nn.Module):
|
||||
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
|
||||
super().__init__()
|
||||
spk_dim = 0 if spk_dim is None else spk_dim
|
||||
self.spk_dim = spk_dim
|
||||
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
|
||||
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"],
|
||||
cond: float["b n d"],
|
||||
text_embed: float["b n d"],
|
||||
spks: float["b d"],
|
||||
):
|
||||
to_cat = [x, cond, text_embed]
|
||||
if self.spk_dim > 0:
|
||||
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
|
||||
to_cat.append(spks)
|
||||
|
||||
x = self.proj(torch.cat(to_cat, dim=-1))
|
||||
x = self.conv_pos_embed(x) + x
|
||||
return x
|
||||
|
||||
|
||||
# Transformer backbone using DiT blocks
|
||||
|
||||
|
||||
class DiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth=8,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.1,
|
||||
ff_mult=4,
|
||||
mel_dim=80,
|
||||
mu_dim=None,
|
||||
long_skip_connection=False,
|
||||
spk_dim=None,
|
||||
out_channels=None,
|
||||
static_chunk_size=50,
|
||||
num_decoding_left_chunks=2
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
if mu_dim is None:
|
||||
mu_dim = mel_dim
|
||||
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
||||
)
|
||||
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
self.out_channels = out_channels
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||
x = x.transpose(1, 2)
|
||||
mu = mu.transpose(1, 2)
|
||||
cond = cond.transpose(1, 2)
|
||||
spks = spks.unsqueeze(dim=1)
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if t.ndim == 0:
|
||||
t = t.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(t)
|
||||
x = self.input_embed(x, cond, mu, spks.squeeze(1))
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
residual = x
|
||||
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, t, mask=attn_mask.bool(), rope=rope)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x).transpose(1, 2)
|
||||
return output
|
||||
cosyvoice/flow/DiT/modules.py (new file, 616 lines)
@@ -0,0 +1,616 @@
|
||||
|
||||
"""
|
||||
ein notation:
|
||||
b - batch
|
||||
n - sequence
|
||||
nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# raw wav to mel spec
|
||||
class MelSpec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
filter_length=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
target_sample_rate=24_000,
|
||||
normalize=False,
|
||||
power=1,
|
||||
norm=None,
|
||||
center=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_mel_channels = n_mel_channels
|
||||
|
||||
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=target_sample_rate,
|
||||
n_fft=filter_length,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mel_channels,
|
||||
power=power,
|
||||
center=center,
|
||||
normalized=normalize,
|
||||
norm=norm,
|
||||
)
|
||||
|
||||
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
||||
|
||||
def forward(self, inp):
|
||||
if len(inp.shape) == 3:
|
||||
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
|
||||
|
||||
assert len(inp.shape) == 2
|
||||
|
||||
if self.dummy.device != inp.device:
|
||||
self.to(inp.device)
|
||||
|
||||
mel = self.mel_stft(inp)
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel
|
||||
|
||||
|
||||
# sinusoidal position embedding
|
||||
|
||||
|
||||
class SinusPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
# convolutional position embedding
|
||||
|
||||
|
||||
class ConvPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim, kernel_size=31, groups=16):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 != 0
|
||||
self.conv1d = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.conv1d(x)
|
||||
out = x.permute(0, 2, 1)
|
||||
|
||||
if mask is not None:
|
||||
out = out.masked_fill(~mask, 0.0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CausalConvPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim, kernel_size=31, groups=16):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 != 0
|
||||
self.kernel_size = kernel_size
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
|
||||
nn.Mish(),
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
|
||||
x = self.conv1(x)
|
||||
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
|
||||
x = self.conv2(x)
|
||||
out = x.permute(0, 2, 1)
|
||||
|
||||
if mask is not None:
|
||||
out = out.masked_fill(~mask, 0.0)
|
||||
|
||||
return out
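Because CausalConvPositionEmbedding left-pads each convolution by kernel_size - 1, every output frame depends only on current and past frames, which is what chunked streaming needs. A hypothetical causality check using the class above:
import torch

emb = CausalConvPositionEmbedding(dim=64, kernel_size=31, groups=16).eval()
x = torch.randn(1, 100, 64)
with torch.no_grad():
    y_full = emb(x)          # embed the full sequence
    y_half = emb(x[:, :50])  # embed only the first 50 frames
# the first 50 outputs are unchanged when future frames are appended
print(torch.allclose(y_full[:, :50], y_half, atol=1e-5))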
|
||||
|
||||
|
||||
# rotary positional embedding related
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
freqs_cos = torch.cos(freqs) # real part
|
||||
freqs_sin = torch.sin(freqs) # imaginary part
|
||||
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||
|
||||
|
||||
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
||||
# length = length if isinstance(length, int) else length.max()
|
||||
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
||||
pos = (
|
||||
start.unsqueeze(1)
|
||||
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
||||
)
|
||||
# avoid extra long error.
|
||||
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
||||
return pos
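A quick shape walkthrough of the two rotary helpers above (illustrative sizes, assuming both functions are in scope): precompute_freqs_cis returns the cosine half concatenated with the sine half, and get_pos_embed_indices clamps positions below max_pos so over-long sequences reuse the last precomputed entry.
import torch

dim, max_pos = 64, 4096
freqs_cis = precompute_freqs_cis(dim, max_pos)                   # (4096, 64): [cos | sin]
start = torch.zeros(2, dtype=torch.long)                         # two sequences starting at position 0
pos = get_pos_embed_indices(start, length=100, max_pos=max_pos)  # (2, 100), clamped to < max_pos
pos_embed = freqs_cis[pos]                                       # (2, 100, 64)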
|
||||
|
||||
|
||||
# Global Response Normalization layer (Instance Normalization ?)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
||||
|
||||
|
||||
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
||||
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
||||
|
||||
|
||||
class ConvNeXtV2Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
padding = (dilation * (7 - 1)) // 2
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
||||
) # depthwise conv
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.grn = GRN(intermediate_dim)
|
||||
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = x.transpose(1, 2) # b n d -> b d n
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # b d n -> b n d
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.grn(x)
|
||||
x = self.pwconv2(x)
|
||||
return residual + x
|
||||
|
||||
|
||||
# AdaLayerNormZero
|
||||
# return with modulated x for attn input, and params for later mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(dim, dim * 6)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb=None):
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
||||
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
# AdaLayerNormZero for final layer
|
||||
# return only with modulated x for attn input, because there is no more mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero_Final(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
# FeedForward
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
activation = nn.GELU(approximate=approximate)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
||||
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.ff(x)
|
||||
|
||||
|
||||
# Attention with possible joint part
|
||||
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
processor: JointAttnProcessor | AttnProcessor,
|
||||
dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||
context_pre_only=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.processor = processor
|
||||
|
||||
self.dim = dim
|
||||
self.heads = heads
|
||||
self.inner_dim = dim_head * heads
|
||||
self.dropout = dropout
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.to_q = nn.Linear(dim, self.inner_dim)
|
||||
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||
|
||||
if self.context_dim is not None:
|
||||
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b n d"] = None, # context c # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.Tensor:
|
||||
if c is not None:
|
||||
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
else:
|
||||
return self.processor(self, x, mask=mask, rope=rope)
|
||||
|
||||
|
||||
# Attention processor
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# apply rotary position embedding
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = mask
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
x = attn.to_out[0](x)
|
||||
# dropout
|
||||
x = attn.to_out[1](x)
|
||||
|
||||
if mask is not None:
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(-1)
|
||||
else:
|
||||
mask = mask[:, 0, -1].unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Joint Attention processor for MM-DiT
|
||||
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||
|
||||
|
||||
class JointAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.FloatTensor:
|
||||
residual = x
|
||||
|
||||
batch_size = c.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# `context` projections.
|
||||
c_query = attn.to_q_c(c)
|
||||
c_key = attn.to_k_c(c)
|
||||
c_value = attn.to_v_c(c)
|
||||
|
||||
# apply rope for context and noised input independently
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
if c_rope is not None:
|
||||
freqs, xpos_scale = c_rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, c_query], dim=1)
|
||||
key = torch.cat([key, c_key], dim=1)
|
||||
value = torch.cat([value, c_value], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
x, c = (
|
||||
x[:, : residual.shape[1]],
|
||||
x[:, residual.shape[1]:],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
x = attn.to_out[0](x)
|
||||
# dropout
|
||||
x = attn.to_out[1](x)
|
||||
if not attn.context_pre_only:
|
||||
c = attn.to_out_c(c)
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||
|
||||
return x, c
|
||||
|
||||
|
||||
# DiT Block
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
|
||||
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||
|
||||
# attention
|
||||
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
||||
|
||||
# process attention output for input x
|
||||
x = x + gate_msa.unsqueeze(1) * attn_output
|
||||
|
||||
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
ff_output = self.ff(ff_norm)
|
||||
x = x + gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# MMDiT Block https://arxiv.org/abs/2403.03206
|
||||
|
||||
|
||||
class MMDiTBlock(nn.Module):
|
||||
r"""
|
||||
modified from diffusers/src/diffusers/models/attention.py
|
||||
|
||||
notes.
|
||||
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
||||
_x: noised input related. (right part)
|
||||
context_pre_only: last layer only does prenorm + modulation because there is no more ffn
|
||||
"""
|
||||
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
|
||||
super().__init__()
|
||||
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=JointAttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
context_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
)
|
||||
|
||||
if not context_pre_only:
|
||||
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
else:
|
||||
self.ff_norm_c = None
|
||||
self.ff_c = None
|
||||
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
|
||||
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
if self.context_pre_only:
|
||||
norm_c = self.attn_norm_c(c, t)
|
||||
else:
|
||||
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
|
||||
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
||||
|
||||
# attention
|
||||
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
|
||||
# process attention output for context c
|
||||
if self.context_pre_only:
|
||||
c = None
|
||||
else: # if not last layer
|
||||
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
||||
|
||||
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
c_ff_output = self.ff_c(norm_c)
|
||||
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
||||
|
||||
# process attention output for input x
|
||||
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
||||
|
||||
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
||||
x_ff_output = self.ff_x(norm_x)
|
||||
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
||||
|
||||
return c, x
|
||||
|
||||
|
||||
# time step conditioning embedding
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, dim, freq_embed_dim=256):
|
||||
super().__init__()
|
||||
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, timestep: float["b"]): # noqa: F821
|
||||
time_hidden = self.time_embed(timestep)
|
||||
time_hidden = time_hidden.to(timestep.dtype)
|
||||
time = self.time_mlp(time_hidden) # b d
|
||||
return time
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os, logging
|
||||
import random
|
||||
from typing import Dict, Optional
|
||||
import torch
|
||||
@@ -19,6 +19,7 @@ import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from omegaconf import DictConfig
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
|
||||
|
||||
|
||||
class MaskedDiffWithXvec(torch.nn.Module):
|
||||
@@ -37,14 +38,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
||||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
||||
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.mel_feat_conf = mel_feat_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
@@ -165,14 +163,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
||||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
||||
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.mel_feat_conf = mel_feat_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
@@ -185,14 +180,19 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
self.only_mask_loss = only_mask_loss
|
||||
self.token_mel_ratio = token_mel_ratio
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
if 'speech_token' not in batch:
|
||||
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
@@ -279,3 +279,165 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), None
|
||||
|
||||
|
||||
class CausalMaskedDiffWithDiT(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
token_mel_ratio: int = 2,
|
||||
pre_lookahead_len: int = 3,
|
||||
pre_lookahead_layer: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
self.pre_lookahead_layer = pre_lookahead_layer
|
||||
self.decoder = decoder
|
||||
self.only_mask_loss = only_mask_loss
|
||||
self.token_mel_ratio = token_mel_ratio
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
if 'speech_token' not in batch:
|
||||
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# NOTE unified training: randomly mix streaming (static_chunk_size > 0) and non-streaming (static_chunk_size = 0) batches
|
||||
streaming = True if random.random() < 0.5 else False
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h = self.pre_lookahead_layer(token)
|
||||
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
|
||||
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds,
|
||||
streaming=streaming,
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
token,
|
||||
token_len,
|
||||
prompt_token,
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
streaming,
|
||||
finalize):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
if finalize is True:
|
||||
h = self.pre_lookahead_layer(token)
|
||||
else:
|
||||
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
|
||||
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, _ = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
streaming=streaming
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
|
||||
model = configs['flow']
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_len = 10 * model.decoder.estimator.static_chunk_size
|
||||
chunk_size = model.decoder.estimator.static_chunk_size
|
||||
context_size = model.pre_lookahead_layer.pre_lookahead_len
|
||||
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
|
||||
token_len = torch.tensor([max_len]).to(device)
|
||||
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
|
||||
prompt_token_len = torch.tensor([chunk_size]).to(device)
|
||||
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
|
||||
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
|
||||
prompt_embedding = torch.rand(1, 192).to(device)
|
||||
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
|
||||
for i in range(0, max_len, chunk_size):
|
||||
finalize = True if i + chunk_size + context_size >= max_len else False
|
||||
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
|
||||
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
|
||||
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
|
||||
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())
|
||||
|
||||
@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
|
||||
sol = []
|
||||
|
||||
# Do not use concat, it may change the memory format and make TRT inference produce wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
# NOTE when flow runs in amp mode, x.dtype is float32, which causes NaN in TRT fp16 inference, so set dtype=spks.dtype
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
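
# Illustrative sketch of the buffer-reuse pattern above: rather than torch.cat (which can
# change the memory format and corrupt TensorRT results), a fixed [2, ...] batch is
# allocated once and the conditional / unconditional CFG branches are written into it in
# place on every ODE step. The estimator call signature and the guidance mixing below are
# assumptions for illustration, not the exact decoder code.
import torch

def cfg_step(estimator, x, mask, mu, t, spks, cond, cfg_rate=0.7):
    def double(a):
        return torch.zeros([2, *a.shape[1:]], device=a.device, dtype=spks.dtype)
    x_in, mask_in, mu_in, cond_in = double(x), double(mask), double(mu), double(cond)
    spks_in = double(spks)
    t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
    x_in[:], mask_in[:], t_in[:], spks_in[:] = x, mask, t, spks      # shared inputs
    mu_in[0:1], cond_in[0:1] = mu, cond                              # conditional branch only
    dphi_dt, cfg_dphi_dt = estimator(x_in, mask_in, mu_in, t_in, spks_in, cond_in).chunk(2, dim=0)
    return (1.0 + cfg_rate) * dphi_dt - cfg_rate * cfg_dphi_dt
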
|
||||
@@ -173,8 +174,7 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
# random timestep
|
||||
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
||||
|
||||
# sample noise p(x_0)
|
||||
z = torch.randn_like(x1)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ try:
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
except ImportError:
|
||||
from torch.nn.utils import weight_norm
|
||||
from cosyvoice.transformer.convolution import CausalConv1d
|
||||
|
||||
|
||||
class ConvRNNF0Predictor(nn.Module):
|
||||
@@ -56,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
|
||||
x = self.condnet(x)
|
||||
x = x.transpose(1, 2)
|
||||
return torch.abs(self.classifier(x).squeeze(-1))
|
||||
|
||||
|
||||
class CausalConvRNNF0Predictor(nn.Module):
|
||||
def __init__(self,
|
||||
num_class: int = 1,
|
||||
in_channels: int = 80,
|
||||
cond_channels: int = 512
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_class = num_class
|
||||
self.condnet = nn.Sequential(
|
||||
weight_norm(
|
||||
CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
)
|
||||
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
||||
|
||||
def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||
if finalize is True:
|
||||
x = self.condnet[0](x)
|
||||
else:
|
||||
x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
|
||||
for i in range(1, len(self.condnet)):
|
||||
x = self.condnet[i](x)
|
||||
x = x.transpose(1, 2)
|
||||
return torch.abs(self.classifier(x).squeeze(-1))
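
# Illustrative usage of the causal F0 predictor above (shapes are arbitrary): the first
# conv is right-causal with kernel size 4, i.e. 3 frames of lookahead, while the remaining
# convs are left-causal, so the module can run chunk by chunk. With finalize=True the
# missing lookahead is zero-padded; with finalize=False the last 3 frames are held back as context.
import torch

predictor = CausalConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.rand(1, 80, 100)                     # (batch, n_mels, frames)
f0_full = predictor(mel)                         # (1, 100)
f0_partial = predictor(mel, finalize=False)      # (1, 97): trailing frames wait for future context
print(f0_full.shape, f0_partial.shape)
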
|
||||
|
||||
@@ -28,7 +28,7 @@ try:
|
||||
except ImportError:
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.distributions.uniform import Uniform
|
||||
|
||||
from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
|
||||
from cosyvoice.transformer.activation import Snake
|
||||
from cosyvoice.utils.common import get_padding
|
||||
from cosyvoice.utils.common import init_weights
|
||||
@@ -50,8 +50,10 @@ class ResBlock(torch.nn.Module):
|
||||
channels: int = 512,
|
||||
kernel_size: int = 3,
|
||||
dilations: List[int] = [1, 3, 5],
|
||||
causal: bool = False,
|
||||
):
|
||||
super(ResBlock, self).__init__()
|
||||
self.causal = causal
|
||||
self.convs1 = nn.ModuleList()
|
||||
self.convs2 = nn.ModuleList()
|
||||
|
||||
@@ -64,7 +66,14 @@ class ResBlock(torch.nn.Module):
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation,
|
||||
padding=get_padding(kernel_size, dilation)
|
||||
padding=get_padding(kernel_size, dilation)) if causal is False else
|
||||
CausalConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation,
|
||||
causal_type='left'
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -76,7 +85,14 @@ class ResBlock(torch.nn.Module):
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)
|
||||
padding=get_padding(kernel_size, 1)) if causal is False else
|
||||
CausalConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
causal_type='left'
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -139,11 +155,13 @@ class SineGen(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, f0):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, dim=1, length)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
:param f0: [B, 1, sample_len], Hz
|
||||
:return: [B, 1, sample_len]
|
||||
"""
|
||||
|
||||
f0 = f0.transpose(1, 2)
|
||||
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
||||
for i in range(self.harmonic_num + 1):
|
||||
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
||||
@@ -168,59 +186,7 @@ class SineGen(torch.nn.Module):
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
sine_waves = sine_waves * uv + noise
|
||||
return sine_waves, uv, noise
|
||||
|
||||
|
||||
class SourceModuleHnNSF(torch.nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
sampling_rate: sampling_rate in Hz
|
||||
harmonic_num: number of harmonic above F0 (default: 0)
|
||||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||
note that amplitude of noise in unvoiced is decided
|
||||
by sine_amp
|
||||
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length, 1)
|
||||
uv (batchsize, length, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0):
|
||||
super(SourceModuleHnNSF, self).__init__()
|
||||
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
|
||||
# to produce sine waveforms
|
||||
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
||||
sine_amp, add_noise_std, voiced_threshod)
|
||||
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length, 1)
|
||||
"""
|
||||
# source for harmonic branch
|
||||
with torch.no_grad():
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
||||
sine_wavs = sine_wavs.transpose(1, 2)
|
||||
uv = uv.transpose(1, 2)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
return sine_merge, noise, uv
|
||||
return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
|
||||
|
||||
|
||||
class SineGen2(torch.nn.Module):
|
||||
@@ -242,7 +208,8 @@ class SineGen2(torch.nn.Module):
|
||||
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0,
|
||||
flag_for_pulse=False):
|
||||
flag_for_pulse=False,
|
||||
causal=False):
|
||||
super(SineGen2, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
@@ -252,6 +219,11 @@ class SineGen2(torch.nn.Module):
|
||||
self.voiced_threshold = voiced_threshold
|
||||
self.flag_for_pulse = flag_for_pulse
|
||||
self.upsample_scale = upsample_scale
|
||||
self.causal = causal
|
||||
if causal is True:
|
||||
self.rand_ini = torch.rand(1, 9)
|
||||
self.rand_ini[:, 0] = 0
|
||||
self.sine_waves = torch.rand(1, 300 * 24000, 9)
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
@@ -267,9 +239,12 @@ class SineGen2(torch.nn.Module):
|
||||
rad_values = (f0_values / self.sampling_rate) % 1
|
||||
|
||||
# initial phase noise (no noise for fundamental component)
|
||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
if self.training is False and self.causal is True:
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
|
||||
else:
|
||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
|
||||
# instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
|
||||
if not self.flag_for_pulse:
|
||||
@@ -279,7 +254,7 @@ class SineGen2(torch.nn.Module):
|
||||
|
||||
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||
scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
|
||||
sines = torch.sin(phase)
|
||||
else:
|
||||
# If necessary, make sure that the first time step of every
|
||||
@@ -331,7 +306,10 @@ class SineGen2(torch.nn.Module):
|
||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||
# . for voiced regions is self.noise_std
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
if self.training is False and self.causal is True:
|
||||
noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
|
||||
else:
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
@@ -339,7 +317,7 @@ class SineGen2(torch.nn.Module):
|
||||
return sine_waves, uv, noise
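
# Illustrative sketch of why the fixed buffers above matter: with torch.randn_like, the
# overlapping prefix of two chunked calls would receive different noise each time and the
# chunks would not stitch together; slicing one pre-generated buffer yields identical
# values for the same absolute positions. Buffer size and shapes are illustrative.
import torch

fixed_noise = torch.rand(1, 300 * 24000, 9)       # generated once, e.g. at module init
short = fixed_noise[:, :48000]                    # first chunk
longer = fixed_noise[:, :72000]                   # later, longer call
print(torch.equal(short, longer[:, :48000]))      # True: shared prefix is bit-identical
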
|
||||
|
||||
|
||||
class SourceModuleHnNSF2(torch.nn.Module):
|
||||
class SourceModuleHnNSF(torch.nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
@@ -358,19 +336,24 @@ class SourceModuleHnNSF2(torch.nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0):
|
||||
super(SourceModuleHnNSF2, self).__init__()
|
||||
add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
|
||||
super(SourceModuleHnNSF, self).__init__()
|
||||
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
|
||||
# to produce sine waveforms
|
||||
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
|
||||
sine_amp, add_noise_std, voiced_threshod)
|
||||
if sinegen_type == '1':
|
||||
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
|
||||
else:
|
||||
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
|
||||
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
self.causal = causal
|
||||
if causal is True:
|
||||
self.uv = torch.rand(1, 300 * 24000, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
@@ -385,7 +368,10 @@ class SourceModuleHnNSF2(torch.nn.Module):
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
if self.training is False and self.causal is True:
|
||||
noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
|
||||
else:
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
return sine_merge, noise, uv
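
# Illustrative construction of the merged source module above: sinegen_type='1' selects
# the original SineGen (used for the 22.05 kHz model), '2' selects SineGen2, and
# causal=True enables the fixed noise buffers needed for chunked inference. The concrete
# rates and scales below are assumptions, not values taken from a shipped config.
import torch

m_source = SourceModuleHnNSF(sampling_rate=24000,
                             upsample_scale=480,             # assumed prod(upsample_rates) * hop_len
                             harmonic_num=8,
                             sine_amp=0.1,
                             add_noise_std=0.003,
                             voiced_threshod=10,
                             sinegen_type='2',
                             causal=True)
f0_upsampled = torch.rand(1, 24000, 1) * 200                 # (B, samples, 1), as fed by the generator
sine_merge, noise, uv = m_source(f0_upsampled)
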
|
||||
|
||||
|
||||
@@ -425,15 +411,16 @@ class HiFTGenerator(nn.Module):
|
||||
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
# NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
|
||||
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
|
||||
self.m_source = this_SourceModuleHnNSF(
|
||||
# NOTE in CosyVoice2, we use the original SineGen implementation
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=sampling_rate,
|
||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||
harmonic_num=nb_harmonics,
|
||||
sine_amp=nsf_alpha,
|
||||
add_noise_std=nsf_sigma,
|
||||
voiced_threshod=nsf_voiced_threshold)
|
||||
voiced_threshod=nsf_voiced_threshold,
|
||||
sinegen_type='1' if self.sampling_rate == 22050 else '2',
|
||||
causal=False)
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
||||
|
||||
self.conv_pre = weight_norm(
|
||||
@@ -580,3 +567,180 @@ class HiFTGenerator(nn.Module):
|
||||
s[:, :, :cache_source.shape[2]] = cache_source
|
||||
generated_speech = self.decode(x=speech_feat, s=s)
|
||||
return generated_speech, s
|
||||
|
||||
|
||||
class CausalHiFTGenerator(HiFTGenerator):
|
||||
"""
|
||||
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
||||
https://arxiv.org/abs/2309.09493
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 80,
|
||||
base_channels: int = 512,
|
||||
nb_harmonics: int = 8,
|
||||
sampling_rate: int = 22050,
|
||||
nsf_alpha: float = 0.1,
|
||||
nsf_sigma: float = 0.003,
|
||||
nsf_voiced_threshold: float = 10,
|
||||
upsample_rates: List[int] = [8, 8],
|
||||
upsample_kernel_sizes: List[int] = [16, 16],
|
||||
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
source_resblock_kernel_sizes: List[int] = [7, 11],
|
||||
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||
lrelu_slope: float = 0.1,
|
||||
audio_limit: float = 0.99,
|
||||
conv_pre_look_right: int = 4,
|
||||
f0_predictor: torch.nn.Module = None,
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
self.out_channels = 1
|
||||
self.nb_harmonics = nb_harmonics
|
||||
self.sampling_rate = sampling_rate
|
||||
self.istft_params = istft_params
|
||||
self.lrelu_slope = lrelu_slope
|
||||
self.audio_limit = audio_limit
|
||||
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=sampling_rate,
|
||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||
harmonic_num=nb_harmonics,
|
||||
sine_amp=nsf_alpha,
|
||||
add_noise_std=nsf_sigma,
|
||||
voiced_threshod=nsf_voiced_threshold,
|
||||
sinegen_type='1' if self.sampling_rate == 22050 else '2',
|
||||
causal=True)
|
||||
self.upsample_rates = upsample_rates
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
||||
|
||||
self.conv_pre = weight_norm(
|
||||
CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
|
||||
)
|
||||
|
||||
# Up
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
CausalConv1dUpsample(
|
||||
base_channels // (2**i),
|
||||
base_channels // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Down
|
||||
self.source_downs = nn.ModuleList()
|
||||
self.source_resblocks = nn.ModuleList()
|
||||
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
||||
downsample_cum_rates = np.cumprod(downsample_rates)
|
||||
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
||||
if u == 1:
|
||||
self.source_downs.append(
|
||||
CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
|
||||
)
|
||||
else:
|
||||
self.source_downs.append(
|
||||
CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
|
||||
)
|
||||
|
||||
self.source_resblocks.append(
|
||||
ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = base_channels // (2**(i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(ResBlock(ch, k, d, causal=True))
|
||||
|
||||
self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
||||
self.conv_pre_look_right = conv_pre_look_right
|
||||
self.f0_predictor = f0_predictor
|
||||
|
||||
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
|
||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||
if finalize is True:
|
||||
x = self.conv_pre(x)
|
||||
else:
|
||||
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
|
||||
s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
|
||||
s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
|
||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
x = self.ups[i](x)
|
||||
|
||||
if i == self.num_upsamples - 1:
|
||||
x = self.reflection_pad(x)
|
||||
|
||||
# fusion
|
||||
si = self.source_downs[i](s_stft)
|
||||
si = self.source_resblocks[i](si)
|
||||
x = x + si
|
||||
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
||||
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant
|
||||
|
||||
x = self._istft(magnitude, phase)
|
||||
if finalize is False:
|
||||
x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
|
||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||
return x
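
# Bookkeeping sketch for the streaming decode above: one mel frame corresponds to
# prod(upsample_rates) * istft hop_len output samples, which is also how the chunked test
# further below aligns generated audio with the offline result. The 480 samples/frame
# figure assumes a 24 kHz configuration such as upsample_rates=[8, 5, 3] with hop_len=4.
import numpy as np

upsample_rates = [8, 5, 3]                              # assumed config
istft_hop_len = 4                                       # assumed config
samples_per_mel_frame = int(np.prod(upsample_rates) * istft_hop_len)
print(samples_per_mel_frame)                            # 480, so chunk i starts at sample i * 480
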
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||
# mel->f0 NOTE: f0_predictor precision is crucial for causal inference; move self.f0_predictor to cpu if necessary
|
||||
self.f0_predictor.to(torch.float64)
|
||||
f0 = self.f0_predictor(speech_feat.to(torch.float64), finalize=finalize).to(speech_feat)
|
||||
# f0->source
|
||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
s, _, _ = self.m_source(s)
|
||||
s = s.transpose(1, 2)
|
||||
if finalize is True:
|
||||
generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
|
||||
else:
|
||||
generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
|
||||
return generated_speech, s
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
|
||||
model = configs['hift']
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_len, chunk_size, context_size = 300, 30, 8
|
||||
mel = torch.rand(1, 80, max_len).to(device)
|
||||
pred_gt, _ = model.inference(mel)
|
||||
for i in range(0, max_len, chunk_size):
|
||||
finalize = True if i + chunk_size + context_size >= max_len else False
|
||||
pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
|
||||
pred_chunk = pred_chunk[:, i * 480:]
|
||||
print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,11 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import queue
|
||||
import os, queue
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Optional, Callable, List, Generator
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
@@ -27,6 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
||||
from cosyvoice.utils.common import th_accuracy
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
|
||||
|
||||
|
||||
class TransformerLM(torch.nn.Module):
|
||||
@@ -56,8 +58,9 @@ class TransformerLM(torch.nn.Module):
|
||||
)
|
||||
|
||||
# 2. build speech token language model related modules
|
||||
self.sos_eos = 0
|
||||
self.sos = 0
|
||||
self.task_id = 1
|
||||
self.eos_token = self.speech_token_size
|
||||
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||
self.llm = llm
|
||||
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
||||
@@ -85,10 +88,10 @@ class TransformerLM(torch.nn.Module):
|
||||
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
||||
def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
||||
lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
||||
for i in range(len(text_token))]
|
||||
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
||||
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
||||
@@ -126,15 +129,15 @@ class TransformerLM(torch.nn.Module):
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(1)
|
||||
|
||||
# 3. eos and task_id
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 4. encode speech_token
|
||||
speech_token = self.speech_embedding(speech_token)
|
||||
|
||||
# 5. unpad and pad
|
||||
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
|
||||
lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
|
||||
task_id_emb, speech_token, speech_token_len)
|
||||
|
||||
# 6. run lm forward
|
||||
@@ -154,7 +157,7 @@ class TransformerLM(torch.nn.Module):
|
||||
num_trials, max_trials = 0, 100
|
||||
while True:
|
||||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
||||
if (not ignore_eos) or (top_ids < self.speech_token_size):
|
||||
break
|
||||
num_trials += 1
|
||||
if num_trials > max_trials:
|
||||
@@ -193,13 +196,13 @@ class TransformerLM(torch.nn.Module):
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
@@ -215,11 +218,8 @@ class TransformerLM(torch.nn.Module):
|
||||
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
||||
device=lm_input.device)).to(torch.bool))
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
# force continue decode first token
|
||||
if i == 0:
|
||||
logp[:, self.speech_token_size] = -float('inf')
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
|
||||
if top_ids == self.eos_token:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
@@ -276,9 +276,10 @@ class Qwen2LM(TransformerLM):
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
# 2. build speech token language model related modules
|
||||
self.sos_eos = 0
|
||||
self.sos = 0
|
||||
self.task_id = 1
|
||||
self.fill_token = 2
|
||||
self.eos_token = speech_token_size
|
||||
self.fill_token = speech_token_size + 2
|
||||
|
||||
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||
self.llm = llm
|
||||
@@ -300,19 +301,29 @@ class Qwen2LM(TransformerLM):
|
||||
# 5. vllm related
|
||||
self.stop_token_ids = [speech_token_size + i for i in range(3)]
|
||||
self.vllm_output_queue = {}
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
|
||||
|
||||
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
|
||||
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
|
||||
lm_target, lm_input = [], []
|
||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
|
||||
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
|
||||
# NOTE add instruct_token in CosyVoice3
|
||||
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
|
||||
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
|
||||
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
|
||||
else:
|
||||
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
|
||||
instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
|
||||
instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
|
||||
for i in range(len(text_token)):
|
||||
# bistream sequence
|
||||
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
|
||||
this_lm_target, this_lm_input = [], []
|
||||
this_lm_target.append(IGNORE_ID)
|
||||
this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
|
||||
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
|
||||
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
|
||||
this_lm_input.append(instruct_token_emb[i])
|
||||
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
|
||||
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
|
||||
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
|
||||
@@ -320,22 +331,21 @@ class Qwen2LM(TransformerLM):
|
||||
assert len(this_speech_token) == self.mix_ratio[1]
|
||||
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
|
||||
this_lm_target += this_speech_token
|
||||
this_lm_target.append(self.speech_token_size + 2)
|
||||
this_lm_target.append(self.fill_token)
|
||||
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
|
||||
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
|
||||
else:
|
||||
this_lm_target += [-1] * len(this_text_token)
|
||||
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
|
||||
this_lm_target.append(self.speech_token_size)
|
||||
this_lm_target.append(self.eos_token)
|
||||
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
|
||||
this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
|
||||
this_lm_input.append(task_id_emb.squeeze(dim=0))
|
||||
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
|
||||
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
|
||||
# unistream sequence
|
||||
else:
|
||||
this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
|
||||
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
|
||||
self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
|
||||
this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
|
||||
this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
|
||||
lm_target.append(this_lm_target)
|
||||
lm_input.append(this_lm_input)
|
||||
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
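
# Illustrative sketch of the bistream target layout built above, with token ids only (the
# real code works on embeddings in parallel): text positions are masked with IGNORE_ID,
# every full block of mix_ratio[1] speech tokens is closed with fill_token, and the tail
# ends with eos_token. The id values below are placeholders.
import math

IGNORE_ID = -1

def bistream_targets(text_len, speech, mix_ratio=(5, 15), fill_token=6563, eos_token=6561):
    targets = [IGNORE_ID]                                              # sos position
    for j in range(math.ceil((text_len + 1) / mix_ratio[0])):
        n_text = min(mix_ratio[0], text_len - j * mix_ratio[0])
        block = speech[j * mix_ratio[1]:(j + 1) * mix_ratio[1]]
        if n_text == mix_ratio[0] and len(block) == mix_ratio[1]:      # full text/speech block
            targets += [IGNORE_ID] * (mix_ratio[0] - 1) + block + [fill_token]
        else:                                                          # tail: rest of text, rest of speech
            targets += [IGNORE_ID] * n_text + speech[j * mix_ratio[1]:] + [eos_token]
            break
    return targets

print(len(bistream_targets(text_len=7, speech=list(range(40)))))
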
|
||||
@@ -357,17 +367,25 @@ class Qwen2LM(TransformerLM):
|
||||
"""
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
if 'speech_token' not in batch:
|
||||
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token_emb = self.speech_embedding(speech_token)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
|
||||
speech_token, speech_token_emb, speech_token_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
@@ -392,6 +410,10 @@ class Qwen2LM(TransformerLM):
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
|
||||
@@ -401,8 +423,8 @@ class Qwen2LM(TransformerLM):
|
||||
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
|
||||
speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
|
||||
task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
@@ -420,8 +442,8 @@ class Qwen2LM(TransformerLM):
|
||||
rejected_lm_mask = rejected_lm_target == IGNORE_ID
|
||||
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
|
||||
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
|
||||
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
|
||||
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
|
||||
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -445,13 +467,13 @@ class Qwen2LM(TransformerLM):
|
||||
text = self.llm.model.model.embed_tokens(text)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
@@ -500,11 +522,9 @@ class Qwen2LM(TransformerLM):
|
||||
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
if top_ids > self.speech_token_size:
|
||||
continue
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
@@ -526,20 +546,20 @@ class Qwen2LM(TransformerLM):
|
||||
|
||||
device = prompt_text.device
|
||||
# 1. prepare input
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_eos_emb], dim=1)
|
||||
lm_input = torch.concat([sos_emb], dim=1)
|
||||
|
||||
# 2. iterate text
|
||||
out_tokens = []
|
||||
cache = None
|
||||
# NOTE initialize text_cache with prompt_text, since prompt_speech_token/prompt_text is practically never below 15/5
|
||||
text_cache = self.llm.model.model.embed_tokens(prompt_text)
|
||||
next_fill_index = -1
|
||||
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
|
||||
for this_text in text:
|
||||
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
|
||||
# prompt_speech_token_emb not empty, try append to lm_input
|
||||
@@ -554,12 +574,12 @@ class Qwen2LM(TransformerLM):
|
||||
break
|
||||
# no prompt_speech_token_emb remain, can decode some speech token
|
||||
if prompt_speech_token_emb.size(1) == 0:
|
||||
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
|
||||
if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
|
||||
logging.info('get fill token, need to append more text token')
|
||||
if text_cache.size(1) >= self.mix_ratio[0]:
|
||||
lm_input_text = text_cache[:, :self.mix_ratio[0]]
|
||||
logging.info('append {} text token'.format(lm_input_text.size(1)))
|
||||
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
|
||||
if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
|
||||
lm_input = lm_input_text
|
||||
else:
|
||||
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
|
||||
@@ -574,16 +594,16 @@ class Qwen2LM(TransformerLM):
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
|
||||
top_ids = self.speech_token_size + 2
|
||||
top_ids = self.fill_token
|
||||
next_fill_index += (self.mix_ratio[1] + 1)
|
||||
else:
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
|
||||
if top_ids == self.speech_token_size + 2:
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
|
||||
if top_ids == self.fill_token:
|
||||
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
|
||||
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
|
||||
out_tokens.append(top_ids)
|
||||
if top_ids >= self.speech_token_size:
|
||||
if top_ids == self.speech_token_size + 2:
|
||||
if top_ids == self.fill_token:
|
||||
break
|
||||
else:
|
||||
raise ValueError('should not get token {}'.format(top_ids))
|
||||
@@ -599,13 +619,142 @@ class Qwen2LM(TransformerLM):
|
||||
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
|
||||
out_tokens.append(top_ids)
|
||||
if top_ids >= self.speech_token_size:
|
||||
if top_ids == self.speech_token_size:
|
||||
if top_ids == self.eos_token:
|
||||
break
|
||||
else:
|
||||
raise ValueError('should not get token {}'.format(top_ids))
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
|
||||
class CosyVoice3LM(Qwen2LM):
|
||||
def __init__(
|
||||
self,
|
||||
llm_input_size: int,
|
||||
llm_output_size: int,
|
||||
speech_token_size: int,
|
||||
llm: torch.nn.Module,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
# 2. build speech token language model related modules
|
||||
self.sos = speech_token_size + 0
|
||||
self.eos_token = speech_token_size + 1
|
||||
self.task_id = speech_token_size + 2
|
||||
self.fill_token = speech_token_size + 3
|
||||
|
||||
self.llm = llm
|
||||
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
|
||||
self.criterion_ce = LabelSmoothingLoss(
|
||||
size=speech_token_size + 200,
|
||||
padding_idx=IGNORE_ID,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
# 3. [Optional] build speech token related modules
|
||||
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
self.sampling = sampling
|
||||
self.mix_ratio = mix_ratio
|
||||
|
||||
# 5. vllm related
|
||||
self.stop_token_ids = [speech_token_size + i for i in range(200)]
|
||||
self.vllm_output_queue = {}
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
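
# Summary sketch of the special-token layout above: CosyVoice3 places all control ids
# directly after the speech codebook, while the Qwen2LM/CosyVoice2 setup keeps sos and
# task_id in a separate 2-entry llm_embedding and uses speech_token_size and
# speech_token_size + 2 for eos and fill. The codebook size below is an assumption.
speech_token_size = 6561                       # assumed codebook size
cosyvoice3_ids = {
    'sos': speech_token_size + 0,
    'eos_token': speech_token_size + 1,
    'task_id': speech_token_size + 2,
    'fill_token': speech_token_size + 3,
}
print(cosyvoice3_ids)
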
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text: (B, L, D)
|
||||
text_lengths: (B,)
|
||||
audio: (B, T, N) or (B, T)
|
||||
audio_lengths: (B,)
|
||||
"""
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
if 'speech_token' not in batch:
|
||||
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
|
||||
# NOTE should append instruct_token to sequence, not implemented yet
|
||||
instruct_token = batch['instruct_token'].to(device)
|
||||
instruct_token_len = batch['instruct_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token_emb = self.speech_embedding(speech_token)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
|
||||
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
loss = self.criterion_ce(logits, lm_target.to(device))
|
||||
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
|
||||
return {'loss': loss, 'acc': acc}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_len: torch.Tensor,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
text = self.llm.model.model.embed_tokens(text)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
|
||||
# 5. step by step decode
|
||||
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||
yield token
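
# Illustrative sketch of the decode-length bounds computed above: the number of speech
# tokens to generate is bounded by the length of the new text (prompt excluded) times the
# min/max token-text ratios; the figures below are only an example.
text_len, prompt_text_len = 40, 10
min_token_text_ratio, max_token_text_ratio = 2, 20
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)    # 60
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)    # 600
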
|
||||
|
||||
@@ -238,7 +238,7 @@ def get_tokenizer(
|
||||
)
|
||||
|
||||
|
||||
class QwenTokenizer():
|
||||
class CosyVoice2Tokenizer():
|
||||
def __init__(self, token_path, skip_special_tokens=True):
|
||||
super().__init__()
|
||||
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
||||
@@ -271,9 +271,57 @@ class QwenTokenizer():
|
||||
return text
|
||||
|
||||
|
||||
class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
|
||||
def __init__(self, token_path, skip_special_tokens=True):
|
||||
# NOTE: non-chat model, all these special tokens remain randomly initialized.
|
||||
special_tokens = {
|
||||
'eos_token': '<|endoftext|>',
|
||||
'pad_token': '<|endoftext|>',
|
||||
'additional_special_tokens': [
|
||||
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||
'[quick_breath]',
|
||||
"<laughter>", "</laughter>",
|
||||
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||
"[lipsmack]", "[mn]", "<|endofsystem|>",
|
||||
"[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
|
||||
"[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
|
||||
"[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
|
||||
"[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
|
||||
"[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
|
||||
"[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
|
||||
"[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
|
||||
"[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
|
||||
"[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
|
||||
"[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
|
||||
"[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
|
||||
"[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
|
||||
"[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
|
||||
"[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
|
||||
"[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
|
||||
"[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
|
||||
"[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
|
||||
"[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
|
||||
"[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
|
||||
"[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
|
||||
]
|
||||
}
|
||||
self.special_tokens = special_tokens
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
||||
self.tokenizer.add_special_tokens(special_tokens)
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_qwen_tokenizer(
|
||||
token_path: str,
|
||||
skip_special_tokens: bool
|
||||
) -> QwenTokenizer:
|
||||
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||
skip_special_tokens: bool,
|
||||
version: str = 'cosyvoice2'
|
||||
):
|
||||
if version == 'cosyvoice2':
|
||||
return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||
elif version == 'cosyvoice3':
|
||||
return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||
else:
|
||||
raise ValueError
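
# Illustrative usage of the versioned factory above; the token_path is a placeholder for
# a local Qwen tokenizer directory and not a path shipped with this change.
tokenizer = get_qwen_tokenizer(token_path='./pretrained_models/CosyVoice-BlankEN',
                               skip_special_tokens=True,
                               version='cosyvoice3')
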
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
@@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
|
||||
x.masked_fill_(~mask_pad, 0.0)
|
||||
|
||||
return x.transpose(1, 2), new_cache
|
||||
|
||||
|
||||
# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
|
||||
class CausalConv1d(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
causal_type: str = 'left',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
||||
kernel_size, stride=1,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert stride == 1
|
||||
self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
|
||||
assert causal_type in ['left', 'right']
|
||||
self.causal_type = causal_type
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
|
||||
input_timestep = x.shape[2]
|
||||
if cache.size(2) == 0:
|
||||
cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
|
||||
assert cache.size(2) == self.causal_padding
|
||||
if self.causal_type == 'left':
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
else:
|
||||
x = torch.concat([x, cache], dim=2)
|
||||
x = super(CausalConv1d, self).forward(x)
|
||||
assert x.shape[2] == input_timestep
|
||||
return x
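
# Quick check of the left-causal cache mechanism above: passing the previous chunk's last
# causal_padding input frames as cache makes chunked processing match the full-sequence
# result. Shapes below are arbitrary.
import torch

conv = CausalConv1d(4, 4, kernel_size=3, causal_type='left')
x = torch.rand(1, 4, 20)
full = conv(x)
first = conv(x[:, :, :10])                                            # zero cache at stream start
second = conv(x[:, :, 10:], cache=x[:, :, 10 - conv.causal_padding:10])
print((full - torch.cat([first, second], dim=2)).abs().max().item())  # ~0
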
|
||||
|
||||
|
||||
class CausalConv1dDownSample(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
|
||||
kernel_size, stride,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert stride != 1 and dilation == 1
|
||||
assert kernel_size % stride == 0
|
||||
self.causal_padding = stride - 1
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if cache.size(2) == 0:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
else:
|
||||
assert cache.size(2) == self.causal_padding
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
x = super(CausalConv1dDownSample, self).forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv1dUpsample(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
|
||||
kernel_size, 1,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert dilation == 1
|
||||
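# after nearest-neighbour upsampling by `stride`, (kernel_size - 1) frames of left padding/cache keep the conv causal and preserve the upsampled length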
self.causal_padding = kernel_size - 1
|
||||
self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = self.upsample(x)
|
||||
input_timestep = x.shape[2]
|
||||
if cache.size(2) == 0:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
else:
|
||||
assert cache.size(2) == self.causal_padding
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
x = super(CausalConv1dUpsample, self).forward(x)
|
||||
assert input_timestep == x.shape[2]
|
||||
return x
|
||||
|
||||
@@ -64,17 +64,18 @@ class Upsample1D(nn.Module):
|
||||
|
||||
|
||||
class PreLookaheadLayer(nn.Module):
|
||||
def __init__(self, channels: int, pre_lookahead_len: int = 1):
|
||||
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
self.conv1 = nn.Conv1d(
|
||||
channels, channels,
|
||||
in_channels, channels,
|
||||
kernel_size=pre_lookahead_len + 1,
|
||||
stride=1, padding=0,
|
||||
)
|
||||
self.conv2 = nn.Conv1d(
|
||||
channels, channels,
|
||||
channels, in_channels,
|
||||
kernel_size=3, stride=1, padding=0,
|
||||
)
|
||||
|
||||
@@ -199,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
# convolution module definition
|
||||
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
||||
cnn_module_norm, causal)
|
||||
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
|
||||
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
ConformerEncoderLayer(
|
||||
output_size,
|
||||
|
||||
@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
||||
RelPositionMultiHeadedAttention)
|
||||
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
||||
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
||||
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
|
||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
|
||||
from cosyvoice.hifigan.generator import HiFTGenerator
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
||||
from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
|
||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
|
||||
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
|
||||
|
||||
|
||||
COSYVOICE_ACTIVATION_CLASSES = {
|
||||
@@ -80,4 +80,6 @@ def get_model_type(configs):
|
||||
return CosyVoiceModel
|
||||
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||
return CosyVoice2Model
|
||||
if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
|
||||
return CosyVoice3Model
|
||||
raise TypeError('No valid model type found!')
|
||||
|
||||
@@ -25,6 +25,33 @@ import torch
|
||||
|
||||
IGNORE_ID = -1
|
||||
|
||||
instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
|
||||
"You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
|
||||
"You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
|
||||
|
||||
|
||||
def pad_list(xs: List[torch.Tensor], pad_value: int):
|
||||
"""Perform padding for the list of tensors.
|
||||
@@ -112,6 +139,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
|
||||
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||
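# repetition-aware sampling: count how many of the last win_size decoded tokens equal the nucleus-sampled candidate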
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
|
||||
if rep_num >= win_size * tau_r:
|
||||
weighted_scores[top_ids] = -float('inf')
|
||||
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||
return top_ids
|
||||
|
||||
@@ -130,12 +158,12 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||
break
|
||||
prob = torch.tensor(prob).to(weighted_scores)
|
||||
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
||||
top_ids = indices[prob.multinomial(1, replacement=True)]
|
||||
top_ids = indices[prob.multinomial(1, replacement=True)].item()
|
||||
return top_ids
|
||||
|
||||
|
||||
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
|
||||
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
|
||||
return top_ids
|
||||
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class Executor:
|
||||
for k, v in info_dict['loss_dict'].items():
|
||||
if k not in total_loss_dict:
|
||||
total_loss_dict[k] = []
|
||||
total_loss_dict[k].append(v.item() * num_utts)
|
||||
total_loss_dict[k].append(v.mean().item() * num_utts)
|
||||
log_per_step(None, info_dict)
|
||||
for k, v in total_loss_dict.items():
|
||||
total_loss_dict[k] = sum(v) / total_num_utts
|
||||
|
||||
@@ -41,11 +41,11 @@ def read_json_lists(list_file):
|
||||
return results
|
||||
|
||||
|
||||
def load_wav(wav, target_sr):
|
||||
def load_wav(wav, target_sr, min_sr=16000):
|
||||
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
|
||||
speech = speech.mean(dim=0, keepdim=True)
|
||||
if sample_rate != target_sr:
|
||||
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
||||
assert sample_rate >= min_sr, 'wav sample rate {} must be no less than {}'.format(sample_rate, min_sr)
|
||||
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
||||
return speech
|
||||
|
||||
@@ -88,30 +88,18 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
# NOTE bistream inference is not supported, as only the speech token embedding/head is kept
|
||||
def export_cosyvoice2_vllm(model, model_path, device):
|
||||
if os.path.exists(model_path):
|
||||
return
|
||||
pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
vocab_size = model.speech_embedding.num_embeddings
|
||||
feature_size = model.speech_embedding.embedding_dim
|
||||
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
dtype = torch.bfloat16
|
||||
# lm_head
|
||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
|
||||
with torch.no_grad():
|
||||
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
|
||||
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
|
||||
new_lm_head.weight[vocab_size:] = 0
|
||||
new_lm_head.bias[vocab_size:] = 0
|
||||
model.llm.model.lm_head = new_lm_head
|
||||
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
|
||||
use_bias = True if model.llm_decoder.bias is not None else False
|
||||
model.llm.model.lm_head = model.llm_decoder
|
||||
# embed_tokens
|
||||
embed_tokens = model.llm.model.model.embed_tokens
|
||||
with torch.no_grad():
|
||||
new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
|
||||
new_codec_embed.weight[vocab_size:] = 0
|
||||
model.llm.model.set_input_embeddings(new_codec_embed)
|
||||
model.llm.model.set_input_embeddings(model.speech_embedding)
|
||||
model.llm.model.to(device)
|
||||
model.llm.model.to(dtype)
|
||||
tmp_vocab_size = model.llm.model.config.vocab_size
|
||||
@@ -119,11 +107,12 @@ def export_cosyvoice2_vllm(model, model_path, device):
|
||||
del model.llm.model.generation_config.eos_token_id
|
||||
del model.llm.model.config.bos_token_id
|
||||
del model.llm.model.config.eos_token_id
|
||||
model.llm.model.config.vocab_size = pad_vocab_size
|
||||
model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
|
||||
model.llm.model.config.tie_word_embeddings = False
|
||||
model.llm.model.config.use_bias = True
|
||||
model.llm.model.config.use_bias = use_bias
|
||||
model.llm.model.save_pretrained(model_path)
|
||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||
if use_bias is True:
|
||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||
model.llm.model.set_input_embeddings(embed_tokens)
|
||||
|
||||
cosyvoice/utils/onnx.py (new file, 54 lines)
@@ -0,0 +1,54 @@
|
||||
import onnxruntime
|
||||
import torch
import random
|
||||
import os
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
|
||||
class SpeechTokenExtractor():
|
||||
def __init__(self, model_path):
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
|
||||
sess_options=option,
|
||||
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
|
||||
|
||||
def inference(self, feat, feat_lengths, device):
|
||||
speech_token = self.speech_tokenizer_session.run(None,
|
||||
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||
feat.transpose(1, 2).detach().cpu().numpy(),
|
||||
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||
feat_lengths.detach().cpu().numpy()})[0]
|
||||
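# the tokenizer emits one speech token per 4 feature frames, hence feat_lengths / 4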
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
|
||||
|
||||
|
||||
class EmbeddingExtractor():
|
||||
def __init__(self, model_path):
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.max_len = 10 * 16000
|
||||
self.campplus_session = onnxruntime.InferenceSession(model_path,
|
||||
sess_options=option,
|
||||
providers=["CPUExecutionProvider"])
|
||||
|
||||
def inference(self, speech):
|
||||
if speech.shape[1] > self.max_len:
|
||||
start_index = random.randint(0, speech.shape[1] - self.max_len)
|
||||
speech = speech[:, start_index: start_index + self.max_len]
|
||||
feat = kaldi.fbank(speech,
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
embedding = self.campplus_session.run(None,
|
||||
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||
return torch.tensor(embedding).to(speech.device)
|
||||
|
||||
# singleton mode, only initialized once
|
||||
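# set the environment variable onnx_path to the directory containing campplus.onnx to switch on online feature extraction (online_feature=True)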
onnx_path = os.environ.get('onnx_path')
|
||||
if onnx_path is not None:
|
||||
embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
|
||||
else:
|
||||
embedding_extractor, online_feature = None, False
|
||||
@@ -53,7 +53,7 @@ def init_distributed(args):
|
||||
def init_dataset_and_dataloader(args, configs, gan, dpo):
|
||||
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
|
||||
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
|
||||
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
|
||||
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
|
||||
|
||||
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
|
||||
train_data_loader = DataLoader(train_dataset,
|
||||
@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
|
||||
|
||||
def check_modify_and_save_config(args, configs):
|
||||
if args.train_engine == "torch_ddp":
|
||||
configs['train_conf']["dtype"] = 'fp32'
|
||||
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
|
||||
else:
|
||||
with open(args.deepspeed_config, 'r') as fin:
|
||||
ds_configs = json.load(fin)
|
||||
@@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['optim_d'] == 'adam':
|
||||
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
|
||||
elif configs['train_conf']['optim_d'] == 'adamw':
|
||||
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler_d'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
||||
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
|
||||
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
|
||||
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler_d = ConstantLR(optimizer_d)
|
||||
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
|
||||
dtype = torch.float32
|
||||
|
||||
if info_dict['train_engine'] == 'torch_ddp':
|
||||
autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
|
||||
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
|
||||
else:
|
||||
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
||||
|
||||
|
||||
@@ -23,6 +23,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Optional
|
||||
from packaging.version import parse as vparse
|
||||
import vllm
|
||||
|
||||
# vLLM 0.11.0+ only supports the V1 engine
|
||||
VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
|
||||
if VLLM_V1_ENGINE_ONLY:
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm.model_executor.models.qwen2 import *
|
||||
|
||||
|
||||
@@ -87,10 +96,14 @@ class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sampling_metadata: Optional[SamplingMetadata] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
if VLLM_V1_ENGINE_ONLY:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
self.lm_head.bias)
|
||||
else:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
@@ -4,7 +4,7 @@ ARG VENV_NAME="cosyvoice"
|
||||
ENV VENV=$VENV_NAME
|
||||
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
|
||||
|
||||
ENV DEBIAN_FRONTEN=noninteractive
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
SHELL ["/bin/bash", "--login", "-c"]
|
||||
|
||||
@@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
|
||||
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
|
||||
RUN conda activate ${VENV} && cd CosyVoice && \
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||
|
||||
WORKDIR /workspace/CosyVoice
|
||||
|
||||
example.py (new file, 112 lines)
@@ -0,0 +1,112 @@
|
||||
import sys
|
||||
sys.path.append('third_party/Matcha-TTS')
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
import torchaudio
|
||||
|
||||
|
||||
def cosyvoice_example():
|
||||
""" CosyVoice Usage, check https://fun-audio-llm.github.io/ for more details
|
||||
"""
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||
# sft usage
|
||||
print(cosyvoice.list_available_spks())
|
||||
# change stream=True for chunk stream inference
|
||||
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
||||
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
|
||||
# zero_shot usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
# cross_lingual usage, <|zh|><|en|><|ja|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
|
||||
'./asset/cross_lingual_prompt.wav')):
|
||||
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
# vc usage
|
||||
for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
||||
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男',
|
||||
'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.<|endofprompt|>')):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
|
||||
def cosyvoice2_example():
|
||||
""" CosyVoice2 Usage, check https://funaudiollm.github.io/cosyvoice2/ for more details
|
||||
"""
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
|
||||
|
||||
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
||||
# zero_shot usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# save zero_shot spk for future usage
|
||||
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', 'my_zero_shot_spk') is True
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk')):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
cosyvoice.save_spkinfo()
|
||||
|
||||
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# instruct usage
|
||||
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话<|endofprompt|>', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# bistream usage: you can use a generator as input, which is useful when feeding text from an upstream LLM
|
||||
# NOTE you should still apply some basic sentence splitting, because the llm cannot handle arbitrarily long text
|
||||
def text_generator():
|
||||
yield '收到好友从远方寄来的生日礼物,'
|
||||
yield '那份意外的惊喜与深深的祝福'
|
||||
yield '让我心中充满了甜蜜的快乐,'
|
||||
yield '笑容如花儿般绽放。'
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
|
||||
def cosyvoice3_example():
|
||||
""" CosyVoice3 Usage, check https://funaudiollm.github.io/cosyvoice3/ for more details
|
||||
"""
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||
# zero_shot usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L280
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>[breath]因为他们那一辈人[breath]在乡里面住的要习惯一点,[breath]邻居都很活络,[breath]嗯,都很熟悉。[breath]',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# instruct usage, for supported control, check cosyvoice/utils/common.py#L28
|
||||
for i, j in enumerate(cosyvoice.inference_instruct2('好少咯,一般系放嗰啲国庆啊,中秋嗰啲可能会咯。', 'You are a helpful assistant. 请用广东话表达。<|endofprompt|>',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# hotfix usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('高管也通过电话、短信、微信等方式对报道[j][ǐ]予好评。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# NOTE for Japanese usage, you must transliterate the text to katakana.
|
||||
# 歴史的世界においては、過去は単に過ぎ去ったものではない、プラトンのいう如く非有が有である。 -> レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン ノ イウ ゴトク ヒ ユー ガ ユー デ アル。
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン ノ イウ ゴトク ヒ ユー ガ ユー デ アル。',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('japanese_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
|
||||
def main():
|
||||
# cosyvoice_example()
|
||||
# cosyvoice2_example()
|
||||
cosyvoice3_example()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
examples/grpo/cosyvoice2/Dockerfile (new file, 6 lines)
@@ -0,0 +1,6 @@
|
||||
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||
COPY requirements.txt /myworkspace/requirements.txt
|
||||
RUN pip install -r /myworkspace/requirements.txt
|
||||
RUN pip install -U nvidia-pytriton
|
||||
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
|
||||
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .
|
||||
examples/grpo/cosyvoice2/README.md (new file, 125 lines)
@@ -0,0 +1,125 @@
|
||||
# CosyVoice2 LLM Reinforcement Learning Recipe
|
||||
|
||||
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Environment Setup](#environment-setup)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Reward Function & ASR Server](#reward-function--asr-server)
|
||||
- [Training](#training)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Export Model](#export-model)
|
||||
- [Results](#results)
|
||||
- [Acknowledgement](#acknowledgement)
|
||||
|
||||
## Environment Setup
|
||||
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
|
||||
```bash
|
||||
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||
```
|
||||
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
|
||||
|
||||
## Data Preparation
|
||||
|
||||
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"text": "An example sentence to be synthesized."
|
||||
}
|
||||
```
|
||||
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
|
||||
|
||||
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
|
||||
|
||||
```bash
|
||||
bash run.sh 0 0
|
||||
```
|
||||
Prepare two JSONL files, `train.jsonl` and `test.jsonl`, before running this stage.
|
||||
The script will then generate two Parquet files:
|
||||
|
||||
```
|
||||
data/parquet_tiny/train.parquet
|
||||
data/parquet_tiny/test.parquet
|
||||
```
|
||||
|
||||
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
|
||||
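For reference, a sketch of the wrapped sample as written to the parquet files (the exact logic is in `prepare_data.py`, included below); the field values here are illustrative:

```python
# illustrative row structure produced by prepare_data.py's process_fn
row = {
    "data_source": "train.jsonl_test.jsonl",
    "prompt": [
        {"role": "user", "content": "An example sentence to be synthesized."},
        {"role": "assistant", "content": ""},
    ],
    "ability": "text-to-speech",
    "reward_model": {"style": "rule", "ground_truth": "An example sentence to be synthesized."},
    "extra_info": {"split": "train", "index": 0, "text": "An example sentence to be synthesized."},
}
```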
|
||||
|
||||
## Reward Function & ASR Server
|
||||
|
||||
To compute rewards, we run a lightweight server that:
|
||||
|
||||
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
|
||||
2. Transcribes the waveform with **SenseVoice** ASR.
|
||||
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
|
||||
|
||||
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
|
||||
|
||||
```bash
|
||||
bash run.sh 1 1
|
||||
# Triton server listens on ports 8000/8001/8002
|
||||
```
|
||||
|
||||
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
|
||||
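As a rough sketch of step 3, assuming a simple linear clip (the exact mapping lives in `reward_tts.py`):

```python
def cer_to_reward(cer: float) -> float:
    # map a pinyin-level error rate to a reward in [0, 1]; zero errors -> 1.0 (illustrative)
    return max(0.0, 1.0 - cer)
```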
|
||||
## Training
|
||||
|
||||
Run stage `2` to start GRPO training:
|
||||
|
||||
```bash
|
||||
bash run.sh 2 2
|
||||
```
|
||||
|
||||
Key CLI arguments passed to `verl.trainer.main_ppo`:
|
||||
|
||||
* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
|
||||
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
|
||||
* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
|
||||
|
||||
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
|
||||
> [!TIP]
|
||||
> Note: the lm_head bias is disabled during training to make the model compatible with vLLM and Transformers' Qwen model.
|
||||
|
||||
## Evaluation
|
||||
|
||||
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
|
||||
|
||||
```bash
|
||||
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
|
||||
```
|
||||
|
||||
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
|
||||
|
||||
```bash
|
||||
bash run.sh 4 4
|
||||
```
|
||||
|
||||
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
|
||||
|
||||
> [!TIP]
|
||||
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
|
||||
|
||||
## Export Model
|
||||
|
||||
To use the RL-trained model with the official CosyVoice repository:
|
||||
|
||||
```bash
|
||||
bash run.sh 5 5
|
||||
```
|
||||
|
||||
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
|
||||
> [!TIP]
|
||||
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
|
||||
|
||||
## Results
|
||||
|
||||
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|
||||
|-------|------------------------|------------------------------|---------|
|
||||
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
|
||||
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).
|
||||
examples/grpo/cosyvoice2/huggingface_to_pretrained.py (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
python3 huggingface_to_pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||
"""
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-cosyvoice2-llm-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The RL trained CosyVoice2 model path in HuggingFace format",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
default="./llm.pt",
|
||||
help="The path to save the llm.pt",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
|
||||
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
|
||||
cosyvoice2_token_size = 6561 + 3
|
||||
llm_embedding_vocab_size = 2
|
||||
|
||||
hf_tensors = {}
|
||||
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
if k.startswith("lm_head.bias"):
|
||||
# the RL-trained model disables the lm_head bias, so skip it if present
|
||||
continue
|
||||
new_k = "llm.model." + k
|
||||
hf_tensors[new_k] = f.get_tensor(k)
|
||||
if k.startswith("lm_head"):
|
||||
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||
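# the RL-trained checkpoint has no lm_head bias, so restore a zero bias for the CosyVoice2 pretrained format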
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
|
||||
if k.startswith("model.embed_tokens"):
|
||||
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
|
||||
|
||||
# use tie_word_embeddings=True
|
||||
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
|
||||
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
|
||||
|
||||
torch.save(hf_tensors, args.output_path)
|
||||
examples/grpo/cosyvoice2/infer_dataset.py (new file, 397 lines)
@@ -0,0 +1,397 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
dataset=zero_shot_zh
|
||||
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
|
||||
|
||||
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
torchrun --nproc_per_node=8 \
|
||||
infer_dataset.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $llm_path/merged_hf_model \
|
||||
--token2wav-path $token2wav_path \
|
||||
--split-name ${dataset} || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
import soundfile as sf
|
||||
import s3tokenizer
|
||||
from functools import partial
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
|
||||
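# rendered prompt: '<|im_start|>user\nConvert the text to speech: {text}<|im_end|>\n<|im_start|>assistant\n<|SPEECH_GENERATION_START|>{speech tokens}'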
|
||||
|
||||
def audio_decode_cosyvoice2(
|
||||
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||
):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
"""
|
||||
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||
"empty", prompt_text, prompt_speech_16k, 24000
|
||||
)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token_len=torch.tensor(
|
||||
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def extract_speech_ids(speech_tokens_str):
|
||||
"""Extract speech IDs from token strings like <|s_23456|>"""
|
||||
speech_ids = []
|
||||
for token_str in speech_tokens_str:
|
||||
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
||||
num_str = token_str[4:-2]
|
||||
num = int(num_str)
|
||||
speech_ids.append(num)
|
||||
else:
|
||||
print(f"Unexpected token: {token_str}")
|
||||
return speech_ids
|
||||
|
||||
|
||||
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
|
||||
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
|
||||
speech_id_str = ""
|
||||
for token in cosy2_tokens:
|
||||
speech_id_str += f"<|s_{token}|>"
|
||||
return speech_id_str
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="dir to save result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=1, help="workers for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefetch", type=int, default=5, help="prefetch for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model-name-or-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="LLM model path (includes both model and tokenizer)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="CosyVoice2 token2wav model path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-text",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt text for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-speech-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the prompt speech for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="top p for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="temperature for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=50,
|
||||
help="top k for sampling",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||
"""Simplified data collator for batch_size=1 processing"""
|
||||
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||
mels, prompt_audio_cosy2tokens_list = [], []
|
||||
for item in batch:
|
||||
prompt_text, target_text = (
|
||||
item["prompt_text"],
|
||||
item["target_text"],
|
||||
)
|
||||
prompt_text_list.append(prompt_text)
|
||||
# Combine prompt and target text
|
||||
full_text = prompt_text + target_text
|
||||
|
||||
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||
ref_audio_org, ref_sr = (
|
||||
item["prompt_audio"]["array"],
|
||||
item["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
|
||||
print(ref_audio_org.shape)
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
prompt_audio_list.append(ref_audio)
|
||||
|
||||
if "prompt_audio_cosy2_tokens" in item:
|
||||
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||
else:
|
||||
# convert to float first
|
||||
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||
|
||||
if len(mels) > 0:
|
||||
mels, mels_lens = s3tokenizer.padding(mels)
|
||||
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||
for i in range(len(codes)):
|
||||
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
|
||||
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||
# Create chat template for LLM generation
|
||||
chat = [
|
||||
{"role": "user", "content": full_text},
|
||||
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||
]
|
||||
if 'system' in tokenizer.chat_template:
|
||||
tokenizer.chat_template = TEMPLATE
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
tokenize=True,
|
||||
return_tensors='pt',
|
||||
continue_final_message=True
|
||||
)
|
||||
input_ids_list.append(input_ids.squeeze(0))
|
||||
|
||||
# For batch_size=1, no need to pad
|
||||
if len(input_ids_list) == 1:
|
||||
input_ids = input_ids_list[0].unsqueeze(0)
|
||||
else:
|
||||
# Handle batch > 1 if needed
|
||||
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||
input_ids_list = [
|
||||
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
|
||||
for input_ids in input_ids_list
|
||||
]
|
||||
input_ids = torch.stack(input_ids_list)
|
||||
|
||||
ids = [item["id"] for item in batch]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"ids": ids,
|
||||
"prompt_text": prompt_text_list,
|
||||
"prompt_audio_list": prompt_audio_list,
|
||||
}
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group("nccl")
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
# Load LLM model and tokenizer directly
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
cosyvoice_codec = CosyVoice2(
|
||||
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
if args.prompt_speech_path:
|
||||
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||
else:
|
||||
prompt_speech_16k = None
|
||||
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
|
||||
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||
)
|
||||
|
||||
total_steps = len(dataset)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.no_grad():
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
|
||||
# Generate speech tokens using LLM
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=2048, # Max length for generation
|
||||
do_sample=True,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
)
|
||||
|
||||
# Process each sample in the batch
|
||||
for i in range(len(batch["ids"])):
|
||||
# Extract generated tokens (excluding input)
|
||||
input_length = input_ids[i].shape[0]
|
||||
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
|
||||
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# Extract speech IDs from token strings like <|s_23456|>
|
||||
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||
|
||||
if len(speech_ids) == 0:
|
||||
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||
continue
|
||||
|
||||
# Convert to tensor for CosyVoice2
|
||||
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
|
||||
|
||||
if args.prompt_text is not None:
|
||||
current_prompt_text = args.prompt_text
|
||||
current_prompt_audio = prompt_speech_16k
|
||||
else:
|
||||
current_prompt_text = batch["prompt_text"][i]
|
||||
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||
|
||||
if current_prompt_audio is not None:
|
||||
# Generate audio using CosyVoice2
|
||||
audio_hat = audio_decode_cosyvoice2(
|
||||
audio_tokens,
|
||||
current_prompt_text,
|
||||
current_prompt_audio,
|
||||
cosyvoice_codec,
|
||||
)
|
||||
|
||||
# Convert to numpy and save
|
||||
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||
target_sample_rate = 24000
|
||||
|
||||
utt = batch["ids"][i]
|
||||
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||
|
||||
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
|
||||
else:
|
||||
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * len(batch["ids"]))
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
examples/grpo/cosyvoice2/prepare_data.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocess the Text to Speech dataset to parquet format
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import datasets
|
||||
|
||||
from verl.utils.hdfs_io import copy, makedirs
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
|
||||
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
|
||||
parser.add_argument("--local_dir", default=None, required=True)
|
||||
parser.add_argument("--hdfs_dir", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load datasets from local JSON files
|
||||
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
|
||||
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
|
||||
|
||||
# add a row to each data item that represents a unique id
|
||||
def make_map_fn(split):
|
||||
def process_fn(example, idx):
|
||||
text = example.pop("text")
|
||||
|
||||
# use cosyvoice2 official huggingface compatible checkpoint template
|
||||
question = text
|
||||
answer = ""
|
||||
|
||||
data = {
|
||||
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
|
||||
"prompt": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": answer,
|
||||
},
|
||||
],
|
||||
"ability": "text-to-speech",
|
||||
"reward_model": {"style": "rule", "ground_truth": text},
|
||||
"extra_info": {
|
||||
"split": split,
|
||||
"index": idx,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
return process_fn
|
||||
|
||||
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
|
||||
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
|
||||
|
||||
local_dir = args.local_dir
|
||||
hdfs_dir = args.hdfs_dir
|
||||
|
||||
print(train_dataset)
|
||||
print(test_dataset)
|
||||
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
|
||||
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
|
||||
|
||||
if hdfs_dir is not None:
|
||||
makedirs(hdfs_dir)
|
||||
|
||||
copy(src=local_dir, dst=hdfs_dir)
|
||||
examples/grpo/cosyvoice2/pretrained_to_huggingface.py (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example downstream usage of the exported checkpoint (instruct TTS with a separate inference script):
|
||||
python3 infer.py \
|
||||
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
--prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||
--prompt-speech-path ./assets/prompt_audio.wav \
|
||||
--model-path ./transformers_cosyvoice2_llm \
|
||||
--input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
|
||||
"""
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--pretrained-cosyvoice2-path",
|
||||
type=str,
|
||||
default="/workspace/CosyVoice2-0.5B",
|
||||
help="Token2Wav path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default='./transformers_cosyvoice2_llm',
|
||||
help="The path to save the model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
cosy2_model = CosyVoice2(
|
||||
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
|
||||
llm = cosy2_model.model.llm.llm.model
|
||||
|
||||
speech_embedding = cosy2_model.model.llm.speech_embedding
|
||||
llm_decoder = cosy2_model.model.llm.llm_decoder
|
||||
llm_embedding = cosy2_model.model.llm.llm_embedding
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
|
||||
special_tokens = {
|
||||
'eos_token': '<|endoftext|>',
|
||||
'pad_token': '<|endoftext|>',
|
||||
'additional_special_tokens': [
|
||||
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||
'[quick_breath]',
|
||||
"<laughter>", "</laughter>",
|
||||
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||
"[lipsmack]", "[mn]"
|
||||
]
|
||||
}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
|
||||
original_tokenizer_vocab_size = len(tokenizer)
|
||||
cosyvoice2_token_size = 6561
|
||||
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
||||
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
|
||||
]
|
||||
num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||
|
||||
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
|
||||
vocab_size = llm.get_input_embeddings().weight.shape[0]
|
||||
|
||||
feature_size = speech_embedding.embedding_dim
|
||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
|
||||
|
||||
with torch.no_grad():
|
||||
# zero-initialize the new lm_head weights
|
||||
new_lm_head.weight.data.zero_()
|
||||
# make bias value -inf
|
||||
new_lm_head.bias.data.fill_(-float('inf'))
|
||||
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
|
||||
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
|
||||
|
||||
llm.lm_head = new_lm_head
|
||||
input_embeddings = llm.get_input_embeddings()
|
||||
|
||||
with torch.no_grad():
|
||||
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
|
||||
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
|
||||
|
||||
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
|
||||
llm.generation_config.eos_token_id = eos_token_ids
|
||||
llm.generation_config.temperature = 1.0
|
||||
llm.generation_config.top_p = 0.8
|
||||
llm.generation_config.top_k = 25
|
||||
|
||||
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
|
||||
llm.config.vocab_size = vocab_size
|
||||
llm.config.tie_word_embeddings = False
|
||||
llm.config.use_bias = True
|
||||
llm.to(torch.bfloat16)
|
||||
llm.save_pretrained(args.save_path)
|
||||
|
||||
TEMPLATE = (
|
||||
"{%- for message in messages %}"
|
||||
"{%- if message['role'] == 'user' %}"
|
||||
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
|
||||
"{%- elif message['role'] == 'assistant' %}"
|
||||
"{{- message['content']}}"
|
||||
"{%- endif %}"
|
||||
"{%- endfor %}"
|
||||
)
|
||||
tokenizer.chat_template = TEMPLATE
|
||||
tokenizer.save_pretrained(args.save_path)
|
||||
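As a sanity check of the exported checkpoint, the sketch below (not part of this diff; paths are the defaults above) reloads the saved tokenizer and LLM with Transformers and samples speech tokens through the chat template defined above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

save_path = "./transformers_cosyvoice2_llm"  # default --save-path
tokenizer = AutoTokenizer.from_pretrained(save_path)
model = AutoModelForCausalLM.from_pretrained(save_path, torch_dtype=torch.bfloat16).eval()

# The chat template wraps the user text as "<|sos|>" + text + "<|task_id|>".
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "用四川话说<|endofprompt|>扁担长,板凳宽。"}],
    tokenize=False,
)
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, do_sample=True, max_new_tokens=128)

# The continuation should consist of "<|s_i|>" speech tokens followed by one of the eos tokens.
print(tokenizer.decode(output[0, inputs["input_ids"].shape[1]:]))

The "<|s_i|>" pieces in the decoded continuation are the CosyVoice2 speech tokens that reward_tts.py parses back into integer IDs.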
31
examples/grpo/cosyvoice2/requirements.txt
Normal file
@@ -0,0 +1,31 @@
conformer==0.3.2
diffusers==0.29.0
gdown==5.1.0
gradio
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0
protobuf==4.25
pydantic==2.7.0
pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
wget==3.2
WeTextProcessing==1.0.3
s3tokenizer
tensorrt
sherpa_onnx
jiwer
zhon
numpy==1.25.2
pypinyin
openai-whisper
233
examples/grpo/cosyvoice2/reward_tts.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Reward calculation for CosyVoice2-0.5B.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import argparse
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
|
||||
|
||||
|
||||
def _parse_ids(token_str: str) -> List[int]:
|
||||
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
|
||||
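# Example (illustrative, not part of the original file):
#   _parse_ids("<|s_12|><|s_345|><|eos1|>")  ->  [12, 345]
# Only "<|s_N|>" speech tokens are kept; eos/control tokens are ignored.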
|
||||
|
||||
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
|
||||
"""Send token IDs and ground-truth text to the Triton server and get reward."""
|
||||
|
||||
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
|
||||
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
|
||||
|
||||
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
|
||||
|
||||
payload = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "TOKENS",
|
||||
"shape": list(tokens_arr.shape),
|
||||
"datatype": "INT32",
|
||||
"data": tokens_arr.tolist(),
|
||||
},
|
||||
{
|
||||
"name": "TOKEN_LENS",
|
||||
"shape": list(lens_arr.shape),
|
||||
"datatype": "INT32",
|
||||
"data": lens_arr.tolist(),
|
||||
},
|
||||
{
|
||||
"name": "GT_TEXT",
|
||||
"shape": [1, 1],
|
||||
"datatype": "BYTES",
|
||||
"data": [ground_truth],
|
||||
},
|
||||
]
|
||||
}
|
||||
rsp = requests.post(
|
||||
REWARD_SERVER_URL,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
verify=False,
|
||||
params={"request_id": "0"},
|
||||
)
|
||||
rsp.raise_for_status()
|
||||
result = rsp.json()
|
||||
|
||||
try:
|
||||
# Reward is returned as the first output
|
||||
return float(result["outputs"][0]["data"][0])
|
||||
except (KeyError, IndexError, TypeError):
|
||||
return 0.0
|
||||
|
||||
|
||||
def compute_score(
|
||||
data_source: str,
|
||||
solution_str: str,
|
||||
ground_truth: str,
|
||||
extra_info: dict | None = None,
|
||||
*,
|
||||
debug_dump: bool = False,
|
||||
) -> float:
|
||||
"""Return reward in [0, 1] using the Triton ASR service.
|
||||
|
||||
The reward is based on the pinyin-level WER between the ASR transcript
|
||||
produced from *solution_str* and the provided *ground_truth* text.
|
||||
"""
|
||||
|
||||
# Decode token IDs
|
||||
ids = _parse_ids(solution_str)
|
||||
|
||||
# Query remote server for reward
|
||||
try:
|
||||
reward = _remote_reward(ids, ground_truth)
|
||||
except Exception as e:
|
||||
reward = 0.0
|
||||
|
||||
if debug_dump:
|
||||
print(
|
||||
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
|
||||
# CLI quick test
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def get_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test TTS CER scoring with data from JSONL file",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input", "-i",
|
||||
type=str,
|
||||
default="data/emilia_zh-cosy-tiny-test.jsonl",
|
||||
help="Path to input JSONL file"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-samples", "-n",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of samples to process (default: all)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-interactive",
|
||||
action="store_true",
|
||||
help="Run in non-interactive mode (process all samples without prompts)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Enable debug mode"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def load_jsonl(file_path: str):
|
||||
"""Load data from jsonl file."""
|
||||
data = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
data.append(json.loads(line.strip()))
|
||||
return data
|
||||
|
||||
def code_to_solution_str(code_list: List[int]) -> str:
|
||||
"""Convert code list to solution string format."""
|
||||
return ''.join([f"<|s_{code}|>" for code in code_list])
|
||||
|
||||
# Parse command line arguments
|
||||
args = get_args()
|
||||
|
||||
try:
|
||||
# Load data from jsonl file
|
||||
print(f"Loading data from: {args.input}")
|
||||
data_list = load_jsonl(args.input)
|
||||
print(f"Loaded {len(data_list)} samples")
|
||||
|
||||
# Limit samples if specified
|
||||
if args.max_samples is not None:
|
||||
data_list = data_list[:args.max_samples]
|
||||
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
|
||||
|
||||
# Process each sample
|
||||
begin_time = time.time()
|
||||
for i, sample in enumerate(data_list):
|
||||
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
|
||||
print(f"Index: {sample.get('index', 'unknown')}")
|
||||
print(f"Text: {sample['text']}")
|
||||
|
||||
# Extract required fields
|
||||
code_list = sample['code']
|
||||
ground_truth = sample['text']
|
||||
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
|
||||
|
||||
# Convert code list to solution string
|
||||
solution_str = code_to_solution_str(code_list)
|
||||
print(f"Solution tokens: {len(code_list)} tokens")
|
||||
if args.debug:
|
||||
print(f"Solution string: {solution_str}")
|
||||
else:
|
||||
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
|
||||
|
||||
# Call compute_score function
|
||||
try:
|
||||
score = compute_score(
|
||||
data_source=data_source,
|
||||
solution_str=solution_str,
|
||||
ground_truth=ground_truth,
|
||||
extra_info=None,
|
||||
debug_dump=args.debug
|
||||
)
|
||||
print(f"Final Score: {score:.4f}")
|
||||
except Exception as e:
|
||||
print(f"Error computing score: {e}")
|
||||
|
||||
# Ask user if they want to continue (for interactive mode)
|
||||
if not args.no_interactive and i < len(data_list) - 1:
|
||||
try:
|
||||
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
|
||||
if response == 'q':
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopped by user")
|
||||
break
|
||||
|
||||
print(f"\nProcessed {min(i+1, len(data_list))} samples")
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - begin_time} seconds")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File not found - {args.input}")
|
||||
print("Please check the file path or use --input to specify correct path")
|
||||
print("Run with --help for usage information")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
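For reference, the reward returned by the Triton service is a pinyin-level WER squashed into [0, 1]; the standalone sketch below (an assumption that mirrors the computation shown later in token2wav_asr_server.py, using pypinyin and jiwer from requirements.txt and skipping the token2wav synthesis and ASR steps) shows just that mapping.

import numpy as np
from jiwer import wer
from pypinyin import Style, lazy_pinyin


def pinyin_reward(gt_text: str, hyp_text: str) -> float:
    """Pinyin-level WER between ground truth and hypothesis, mapped into [0, 1]."""
    def to_pinyin(s: str) -> str:
        return " ".join(lazy_pinyin(s, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True))

    c = float(wer(to_pinyin(gt_text), to_pinyin(hyp_text)))
    return float(np.clip(1.0 - np.tanh(3.0 * c), 0.0, 1.0))


print(pinyin_reward("扁担长板凳宽", "扁担长板凳宽"))  # 1.0 for a perfect transcript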
159
examples/grpo/cosyvoice2/run.sh
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=4
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export PYTHONPATH=/workspace/CosyVoice
|
||||
model_scope_model_path=./CosyVoice2-0.5B
|
||||
sft_model_path=./transformers_cosyvoice2_llm
|
||||
|
||||
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
|
||||
log "stage -2: install dependencies locally if pre-built docker image is not available"
|
||||
conda create -n cosyvoice2 python=3.10 -y
|
||||
conda activate cosyvoice2
|
||||
# install verl
|
||||
git clone https://github.com/yuekaizhang/verl.git -b thread
|
||||
cd verl
|
||||
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
|
||||
pip install --no-deps -e .
|
||||
cd -
|
||||
# install requirements
|
||||
pip install -r requirements.txt
|
||||
pip install -U nvidia-pytriton
|
||||
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
|
||||
fi
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
|
||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
|
||||
python3 pretrained_to_huggingface.py \
|
||||
--pretrained-cosyvoice2-path $model_scope_model_path \
|
||||
--save-path $sft_model_path
|
||||
|
||||
# Or, you could use the following command to download the huggingface compatible checkpoint
|
||||
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
|
||||
|
||||
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
|
||||
fi
|
||||
|
||||
data_dir=data/parquet_aishell3
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: prepare data into verl format"
|
||||
mkdir -p $data_dir
|
||||
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
|
||||
# total 88035 samples
|
||||
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
|
||||
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
|
||||
python prepare_data.py \
|
||||
--train_file data/train.jsonl \
|
||||
--test_file data/test.jsonl \
|
||||
--local_dir $data_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "stage 1: start token2wav asr server for reward function"
|
||||
python3 token2wav_asr_server.py --number-of-devices 8
|
||||
fi
|
||||
|
||||
exp_name=official_llm_aishell3_grpo
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "stage 2: grpo train"
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
export MKL_SERVICE_FORCE_INTEL=TRUE
|
||||
n_gpus_per_node=8
|
||||
micro_batch_size=4
|
||||
train_batch_size=32
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files=$data_dir/train.parquet \
|
||||
data.val_files=$data_dir/test.parquet \
|
||||
data.train_batch_size=$train_batch_size \
|
||||
data.max_prompt_length=1024 \
|
||||
data.max_response_length=512 \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.model.use_remove_padding=False \
|
||||
actor_rollout_ref.model.path=$sft_model_path \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
|
||||
actor_rollout_ref.actor.use_kl_loss=False \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||
actor_rollout_ref.rollout.do_sample=true \
|
||||
actor_rollout_ref.rollout.temperature=0.8 \
|
||||
actor_rollout_ref.rollout.top_p=0.95 \
|
||||
actor_rollout_ref.rollout.top_k=25 \
|
||||
actor_rollout_ref.rollout.n=4 \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
|
||||
reward_model.reward_manager=prime \
|
||||
custom_reward_function.path=reward_tts.py \
|
||||
custom_reward_function.name=compute_score \
|
||||
trainer.project_name='cosyvoice2_grpo' \
|
||||
trainer.experiment_name=$exp_name \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.n_gpus_per_node=$n_gpus_per_node \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=100 \
|
||||
trainer.test_freq=100 \
|
||||
trainer.resume_mode='auto' \
|
||||
trainer.total_epochs=1 \
|
||||
trainer.val_before_train=False
|
||||
fi
|
||||
|
||||
steps=(100 200 300 400 500)
|
||||
for step in ${steps[@]}; do
|
||||
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "stage 3: merge the model"
|
||||
python -m verl.model_merger merge \
|
||||
--backend fsdp \
|
||||
--local_dir $llm_path/actor \
|
||||
--target_dir $llm_path/merged_hf_model || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "stage 4: Test the model"
|
||||
dataset=zero_shot_zh # from CosyVoice3 test set
|
||||
# dataset=test_zh # from seed_tts test set
|
||||
output_dir=./outputs_${exp_name}_${step}_${dataset}
|
||||
|
||||
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||
model_path=$llm_path/merged_hf_model
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
torchrun --nproc_per_node=8 \
|
||||
infer_dataset.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $model_path \
|
||||
--token2wav-path $token2wav_path \
|
||||
--split-name ${dataset} || exit 1
|
||||
|
||||
bash scripts/compute_wer.sh $output_dir ${dataset}
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "stage 5: Convert the RL trained model to CosyVoice repo format"
|
||||
python3 huggingface_to_pretrained.py \
|
||||
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
|
||||
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
|
||||
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
|
||||
# Please be careful or use the huggingface format inference code.
|
||||
fi
|
||||
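Before launching stage 2, it is worth confirming that the stage 1 reward service is reachable; a minimal check (an assumption, not part of the recipe; run from examples/grpo/cosyvoice2 with the token2wav_asr server listening on localhost:8000) is to call the custom reward function directly.

from reward_tts import compute_score

score = compute_score(
    data_source="sanity_check",
    solution_str="<|s_1|><|s_2|><|s_3|>",
    ground_truth="你好",
    debug_dump=True,
)
print(f"reward = {score:.4f}")  # a value in [0, 1]; near 0 for random tokens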
33
examples/grpo/cosyvoice2/scripts/compute_wer.sh
Normal file
@@ -0,0 +1,33 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
  exit 1
fi
split_name=$2
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14

if [ ! -d $model_path ]; then
  pip install sherpa-onnx
  wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
  mkdir models
  tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
fi

python3 scripts/offline-decode-files.py \
  --tokens=$model_path/tokens.txt \
  --paraformer=$model_path/model.int8.onnx \
  --num-threads=2 \
  --decoding-method=greedy_search \
  --debug=false \
  --sample-rate=24000 \
  --log-dir $wav_dir \
  --feature-dim=80 \
  --split-name $split_name \
  --name sherpa_onnx \
  $wav_files

# python3 scripts/paraformer-pytriton-client.py \
#   --log-dir $wav_dir \
#   --split-name $split_name \
#   $wav_files
754
examples/grpo/cosyvoice2/scripts/offline-decode-files.py
Normal file
@@ -0,0 +1,754 @@
|
||||
# Copyright (c) 2023 by manyeyes
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||
file(s) with a non-streaming model.
|
||||
|
||||
(1) For paraformer
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/paraformer.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(2) For transducer models from icefall
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(3) For CTC models from NeMo
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||
|
||||
(4) For Whisper models
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||
--whisper-task=transcribe \
|
||||
--num-threads=1 \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||
|
||||
(5) For CTC models from WeNet
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||
|
||||
(6) For tdnn models of the yesno recipe from icefall
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--sample-rate=8000 \
|
||||
--feature-dim=23 \
|
||||
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||
used in this file.
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Iterable, TextIO, Union
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
from datasets import load_dataset
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import kaldialign
|
||||
from zhon.hanzi import punctuation
|
||||
import string
|
||||
punctuation_all = punctuation + string.punctuation
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
def remove_punctuation(text: str) -> str:
|
||||
for x in punctuation_all:
|
||||
if x == '\'':
|
||||
continue
|
||||
text = text.replace(x, '')
|
||||
return text
|
||||
|
||||
|
||||
def store_transcripts(
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
File to save the results to.
|
||||
texts:
|
||||
An iterable of tuples. The first element is the cur_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
If it is a multi-talker ASR system, the ref and hyp may also be lists of
|
||||
strings.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf8") as f:
|
||||
for cut_id, ref, hyp in texts:
|
||||
if char_level:
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
print(f"{cut_id}:\tref={ref}", file=f)
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
compute_CER: bool = False,
|
||||
sclite_mode: bool = False,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts.
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted result.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`,
|
||||
but it is predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the cut_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
Returns:
|
||||
Return the overall error rate (WER, or CER when compute_CER is True) in percent as a float.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
|
||||
if compute_CER:
|
||||
for i, res in enumerate(results):
|
||||
cut_id, ref, hyp = res
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for _cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
ref_len = sum([len(r) for _, r, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
||||
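# Example usage of the two helpers above (illustrative only, not part of the
# original script): compute a character-level error report for a toy result list.
#
#   results = [("utt1", "扁担长板凳宽", "扁担长板凳宽"),
#              ("utt2", "吃葡萄不吐葡萄皮", "吃葡萄不吐葡萄")]
#   store_transcripts("recogs-demo.txt", results, char_level=True)
#   with open("errs-demo.txt", "w") as f:
#       cer = write_error_stats(f, "demo", results, compute_CER=True)
#   print(cer)  # 7.14: one deleted character over 14 reference characters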
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hotwords-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The file containing hotwords, one words/phrases per line, like
|
||||
HELLO WORLD
|
||||
你好世界
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hotwords-score",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="""
|
||||
The hotword score of each token for biasing word/phrase. Used only if
|
||||
--hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
Used only when hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-vocab",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||
and modeling-unit is bpe or cjkchar+bpe.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--paraformer",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from Paraformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nemo-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from NeMo CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tdnn-model",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to whisper encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to whisper decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-language",
|
||||
default="",
|
||||
type=str,
|
||||
help="""It specifies the spoken language in the input audio file.
|
||||
Example values: en, fr, de, zh, jp.
|
||||
Available languages for multilingual models can be found at
|
||||
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
If not specified, we infer the language from the input audio file.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-task",
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
help="""For multilingual models, if you specify translate, the output
|
||||
will be in English.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-tail-paddings",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="""Number of tail padding frames.
|
||||
We have removed the 30-second constraint from whisper, so you need to
|
||||
choose the amount of tail padding frames by yourself.
|
||||
Use -1 to use a default value for tail padding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="True to show debug messages",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="""Sample rate of the feature extractor. Must match the one
|
||||
expected by the model. Note: The input sound files can have a
|
||||
different sample rate from this argument.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature-dim",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Feature dimension. Must match the one expected by the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory containing the input sound files to decode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory containing the input sound files to decode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--label",
|
||||
type=str,
|
||||
default=None,
|
||||
help="wav_base_name label",
|
||||
)
|
||||
|
||||
# Dataset related arguments for loading labels when label file is not provided
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="yuekai/seed_tts_cosy2",
|
||||
help="Huggingface dataset name for loading labels",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="Dataset split name for loading labels",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and can be of type
|
||||
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples,
|
||||
which are normalized to the range [-1, 1].
|
||||
- Sample rate of the wave file.
|
||||
"""
|
||||
|
||||
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||
assert (
|
||||
samples.ndim == 1
|
||||
), f"Expected single channel, but got {samples.ndim} channels."
|
||||
|
||||
samples_float32 = samples.astype(np.float32)
|
||||
|
||||
return samples_float32, sample_rate
|
||||
|
||||
|
||||
def normalize_text_alimeeting(text: str) -> str:
|
||||
"""
|
||||
Text normalization similar to M2MeT challenge baseline.
|
||||
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||
"""
|
||||
import re
|
||||
text = text.replace('\u00A0', '') # test_hard
|
||||
text = text.replace(" ", "")
|
||||
text = text.replace("<sil>", "")
|
||||
text = text.replace("<%>", "")
|
||||
text = text.replace("<->", "")
|
||||
text = text.replace("<$>", "")
|
||||
text = text.replace("<#>", "")
|
||||
text = text.replace("<_>", "")
|
||||
text = text.replace("<space>", "")
|
||||
text = text.replace("`", "")
|
||||
text = text.replace("&", "")
|
||||
text = text.replace(",", "")
|
||||
if re.search("[a-zA-Z]", text):
|
||||
text = text.upper()
|
||||
text = text.replace("A", "A")
|
||||
text = text.replace("a", "A")
|
||||
text = text.replace("b", "B")
|
||||
text = text.replace("c", "C")
|
||||
text = text.replace("k", "K")
|
||||
text = text.replace("t", "T")
|
||||
text = text.replace(",", "")
|
||||
text = text.replace("丶", "")
|
||||
text = text.replace("。", "")
|
||||
text = text.replace("、", "")
|
||||
text = text.replace("?", "")
|
||||
text = remove_punctuation(text)
|
||||
return text
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert args.num_threads > 0, args.num_threads
|
||||
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=args.paraformer,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
|
||||
streams, results = [], []
|
||||
total_duration = 0
|
||||
|
||||
for i, wave_filename in enumerate(args.sound_files):
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
streams.append(s)
|
||||
if i % 10 == 0:
|
||||
recognizer.decode_streams(streams)
|
||||
results += [s.result.text for s in streams]
|
||||
streams = []
|
||||
print(f"Processed {i} files")
|
||||
# process the last batch
|
||||
if streams:
|
||||
recognizer.decode_streams(streams)
|
||||
results += [s.result.text for s in streams]
|
||||
end_time = time.time()
|
||||
print("Done!")
|
||||
|
||||
results_dict = {}
|
||||
for wave_filename, result in zip(args.sound_files, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
wave_basename = Path(wave_filename).stem
|
||||
results_dict[wave_basename] = result
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
# Load labels either from file or from dataset
|
||||
labels_dict = {}
|
||||
|
||||
if args.label:
|
||||
# Load labels from file (original functionality)
|
||||
print(f"Loading labels from file: {args.label}")
|
||||
with open(args.label, "r") as f:
|
||||
for line in f:
|
||||
# fields = line.strip().split(" ")
|
||||
# fields = [item for item in fields if item]
|
||||
# assert len(fields) == 4
|
||||
# prompt_text, prompt_audio, text, audio_path = fields
|
||||
|
||||
fields = line.strip().split("|")
|
||||
fields = [item for item in fields if item]
|
||||
assert len(fields) == 4
|
||||
audio_path, prompt_text, prompt_audio, text = fields
|
||||
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
|
||||
else:
|
||||
# Load labels from dataset (new functionality)
|
||||
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
|
||||
if 'zero' in args.split_name:
|
||||
dataset_name = "yuekai/CV3-Eval"
|
||||
else:
|
||||
dataset_name = "yuekai/seed_tts_cosy2"
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
for item in dataset:
|
||||
audio_id = item["id"]
|
||||
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
|
||||
|
||||
print(f"Loaded {len(labels_dict)} labels from dataset")
|
||||
|
||||
# Perform evaluation if labels are available
|
||||
if labels_dict:
|
||||
|
||||
final_results = []
|
||||
for key, value in results_dict.items():
|
||||
if key in labels_dict:
|
||||
final_results.append((key, labels_dict[key], value))
|
||||
else:
|
||||
print(f"Warning: No label found for {key}, skipping...")
|
||||
|
||||
if final_results:
|
||||
store_transcripts(
|
||||
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
|
||||
)
|
||||
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
|
||||
write_error_stats(f, "test-set", final_results, enable_log=True)
|
||||
|
||||
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
|
||||
print(f.readline()) # WER
|
||||
print(f.readline()) # Detailed errors
|
||||
else:
|
||||
print("No matching labels found for evaluation")
|
||||
else:
|
||||
print("No labels available for evaluation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
346
examples/grpo/cosyvoice2/token2wav_asr_server.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pytriton server for token2wav conversion and ASR"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from omnisense.models import OmniSenseVoiceSmall
|
||||
from pytriton.proxy.types import Request
|
||||
from pytriton.triton import Triton, TritonConfig
|
||||
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
|
||||
from pytriton.decorators import batch
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
from typing import Any, List
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.signal import resample
|
||||
import sys
|
||||
import random
|
||||
import re
|
||||
from jiwer import wer
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||
|
||||
# Chinese text normalizer (cached globally)
|
||||
zh_tn_model = ZhNormalizer(
|
||||
cache_dir="./cache",
|
||||
remove_erhua=False,
|
||||
remove_interjections=False,
|
||||
remove_puncts=True,
|
||||
overwrite_cache=True,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
logger = logging.getLogger("token2wav_asr_server")
|
||||
|
||||
|
||||
class _ASR_Server:
|
||||
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||
|
||||
def __init__(self, device_id: int):
|
||||
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||
|
||||
@batch
|
||||
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
|
||||
"""
|
||||
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||
LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray are accepted for backward compatibility and are not used
|
||||
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||
"""
|
||||
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
|
||||
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
|
||||
|
||||
results = self._model.transcribe_single_batch(
|
||||
wavs,
|
||||
language="zh",
|
||||
textnorm="woitn",
|
||||
)
|
||||
texts = [result.text for result in results]
|
||||
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||
return {"TRANSCRIPTS": transcripts}
|
||||
|
||||
|
||||
def audio_decode_cosyvoice2(
|
||||
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||
):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
"""
|
||||
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||
"empty", prompt_text, prompt_speech_16k, 24000
|
||||
)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token_len=torch.tensor(
|
||||
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def get_random_prompt_from_dataset(dataset):
|
||||
"""
|
||||
Get random prompt text and speech from the pre-loaded dataset.
|
||||
Returns (prompt_text, prompt_speech_16k)
|
||||
"""
|
||||
random_idx = random.randint(0, len(dataset) - 1)
|
||||
sample = dataset[random_idx]
|
||||
|
||||
# Extract audio data
|
||||
audio_data = sample["audio"]
|
||||
audio_array = audio_data["array"]
|
||||
sample_rate = audio_data["sampling_rate"]
|
||||
|
||||
# Convert audio to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
num_samples = int(len(audio_array) * (16000 / sample_rate))
|
||||
audio_array = resample(audio_array, num_samples)
|
||||
|
||||
# Convert to torch tensor
|
||||
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
|
||||
prompt_text = sample["text"]
|
||||
# remove space in prompt_text
|
||||
prompt_text = prompt_text.replace(" ", "")
|
||||
return prompt_text, prompt_speech_16k
|
||||
|
||||
|
||||
class _Token2Wav_ASR:
|
||||
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||
|
||||
def __init__(self, device_id: int):
|
||||
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
|
||||
|
||||
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
|
||||
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
|
||||
# current CUDA context to the desired card before the object is created.
|
||||
# Afterwards, all parameters loaded with the generic "cuda" device will
|
||||
# reside on this GPU. We keep the selected id in `self.device_id` and
|
||||
# will set the context again for every forward call to avoid race
|
||||
# conditions when several instances are used in the same process.
|
||||
|
||||
self.device_id = device_id
|
||||
|
||||
# Construct the TTS codec decoder under the correct CUDA device context
|
||||
with torch.cuda.device(self.device_id):
|
||||
self.codec_decoder = CosyVoice2(
|
||||
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
|
||||
@batch
|
||||
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
|
||||
"""
|
||||
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||
"""
|
||||
# Ensure the default CUDA device is set correctly for this invocation
|
||||
torch.cuda.set_device(self.device_id)
|
||||
|
||||
if self.device_id == 0:
|
||||
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
|
||||
|
||||
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
|
||||
|
||||
# Decode ground-truth text strings (BYTES → str)
|
||||
if GT_TEXT.ndim == 2:
|
||||
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||
else:
|
||||
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||
|
||||
wavs = []
|
||||
for tokens in tokens_list:
|
||||
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
|
||||
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
|
||||
audio_hat = audio_decode_cosyvoice2(
|
||||
audio_tokens,
|
||||
prompt_text,
|
||||
prompt_speech_16k,
|
||||
self.codec_decoder,
|
||||
)
|
||||
# resample from 24 kHz to 16 kHz using scipy.signal.resample
|
||||
audio_hat = audio_hat.squeeze(0).float().cpu()
|
||||
audio_hat = audio_hat.numpy()
|
||||
num_samples = int(len(audio_hat) * (16000 / 24000))
|
||||
audio_hat = resample(audio_hat, num_samples)
|
||||
wavs.append(audio_hat)
|
||||
|
||||
results = self.asr_model.transcribe_single_batch(
|
||||
wavs,
|
||||
language="zh",
|
||||
textnorm="woitn",
|
||||
)
|
||||
texts = [result.text for result in results]
|
||||
|
||||
# ---------------- Reward computation ----------------
|
||||
rewards = []
|
||||
for gt_text, hyp_text in zip(gt_texts, texts):
|
||||
gt_norm = zh_tn_model.normalize(gt_text).lower()
|
||||
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
|
||||
|
||||
gt_pinyin = lazy_pinyin(
|
||||
gt_norm,
|
||||
style=Style.TONE3,
|
||||
tone_sandhi=True,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
hyp_pinyin = lazy_pinyin(
|
||||
hyp_norm,
|
||||
style=Style.TONE3,
|
||||
tone_sandhi=True,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
|
||||
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
|
||||
reward_val = 1.0 - np.tanh(3.0 * c)
|
||||
reward_val = max(0.0, min(1.0, reward_val))
|
||||
rewards.append(reward_val)
|
||||
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
|
||||
|
||||
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
|
||||
|
||||
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
|
||||
|
||||
|
||||
def _infer_function_factory(device_ids: List[int], model_name: str):
|
||||
"""Creates a list of inference functions, one for each requested device ID."""
|
||||
infer_funcs = []
|
||||
for device_id in device_ids:
|
||||
if model_name == "sensevoice":
|
||||
infer_funcs.append(_ASR_Server(device_id=device_id))
|
||||
else:
|
||||
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
|
||||
return infer_funcs
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--max-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size of request.",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--number-of-instances-per-device",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of model instances to load.",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--number-of-devices",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of devices to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="token2wav_asr",
|
||||
choices=["token2wav_asr", "sensevoice"],
|
||||
help="Model name.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
log_level = logging.DEBUG if args.verbose else logging.INFO
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
|
||||
|
||||
triton_config = TritonConfig(
|
||||
http_port=8000,
|
||||
grpc_port=8001,
|
||||
metrics_port=8002,
|
||||
)
|
||||
|
||||
device_ids = list(range(args.number_of_devices))
|
||||
device_ids = device_ids * args.number_of_instances_per_device
|
||||
|
||||
with Triton(config=triton_config) as triton:
|
||||
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
|
||||
if args.model_name == "sensevoice":
|
||||
triton.bind(
|
||||
model_name="sensevoice",
|
||||
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||
inputs=[
|
||||
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
|
||||
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
|
||||
],
|
||||
outputs=[
|
||||
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||
],
|
||||
config=ModelConfig(
|
||||
max_batch_size=args.max_batch_size,
|
||||
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
else:
|
||||
triton.bind(
|
||||
model_name="token2wav_asr",
|
||||
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||
inputs=[
|
||||
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
|
||||
],
|
||||
outputs=[
|
||||
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
|
||||
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||
],
|
||||
config=ModelConfig(
|
||||
max_batch_size=args.max_batch_size,
|
||||
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
logger.info("Serving inference")
|
||||
triton.serve()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
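For completeness, the "token2wav_asr" endpoint bound above can also be exercised with PyTriton's Python client instead of the raw HTTP request used in reward_tts.py; the sketch below is an assumption based on the tensor names and dtypes bound above and on the documented pytriton ModelClient API.

import numpy as np
from pytriton.client import ModelClient

tokens = np.array([[100, 200, 300]], dtype=np.int32)         # (batch, num_tokens)
token_lens = np.array([[3]], dtype=np.int32)                 # (batch, 1)
gt_text = np.array([["你好".encode("utf-8")]], dtype=object)  # (batch, 1), BYTES

with ModelClient("localhost:8000", "token2wav_asr") as client:
    result = client.infer_batch(TOKENS=tokens, TOKEN_LENS=token_lens, GT_TEXT=gt_text)

print(result["REWARDS"], result["TRANSCRIPTS"])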
@@ -1,257 +0,0 @@
|
||||
# set random seed, so that you may reproduce your result.
|
||||
__set_seed1: !apply:random.seed [1986]
|
||||
__set_seed2: !apply:numpy.random.seed [1986]
|
||||
__set_seed3: !apply:torch.manual_seed [1986]
|
||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||
|
||||
# fixed params
|
||||
sample_rate: 22050
|
||||
text_encoder_input_size: 512
|
||||
llm_input_size: 1024
|
||||
llm_output_size: 1024
|
||||
spk_embed_dim: 192
|
||||
|
||||
# model params
|
||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||
# for system/third_party class/function, we do not require this.
|
||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
||||
llm_input_size: !ref <llm_input_size>
|
||||
llm_output_size: !ref <llm_output_size>
|
||||
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
speech_token_size: 4096
|
||||
length_normalized_loss: True
|
||||
lsm_weight: 0
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
input_size: !ref <text_encoder_input_size>
|
||||
output_size: 1024
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 3
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
use_dynamic_chunk: False
|
||||
use_dynamic_left_chunk: False
|
||||
static_chunk_size: 1
|
||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
||||
input_size: !ref <llm_input_size>
|
||||
output_size: !ref <llm_output_size>
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 7
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: 'linear_legacy'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
static_chunk_size: 1
|
||||
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||
top_p: 0.8
|
||||
top_k: 25
|
||||
win_size: 10
|
||||
tau_r: 0.1
|
||||
|
||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||
input_size: 512
|
||||
output_size: 80
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
output_type: 'mel'
|
||||
vocab_size: 4096
|
||||
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
only_mask_loss: True
|
||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
output_size: 512
|
||||
attention_heads: 4
|
||||
linear_units: 1024
|
||||
num_blocks: 3
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
input_size: 512
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
||||
channels: 80
|
||||
sampling_ratios: [1, 1, 1, 1]
|
||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
||||
in_channels: 240
|
||||
n_spks: 1
|
||||
spk_emb_dim: 80
|
||||
cfm_params: !new:omegaconf.DictConfig
|
||||
content:
|
||||
sigma_min: 1e-06
|
||||
solver: 'euler'
|
||||
t_scheduler: 'cosine'
|
||||
training_cfg_rate: 0.2
|
||||
inference_cfg_rate: 0.7
|
||||
reg_loss_type: 'l1'
|
||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
||||
in_channels: 320
|
||||
out_channels: 80
|
||||
channels: [256, 256]
|
||||
dropout: 0.0
|
||||
attention_head_dim: 64
|
||||
n_blocks: 4
|
||||
num_mid_blocks: 8
|
||||
num_heads: 8
|
||||
act_fn: 'gelu'
|
||||
|
||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
in_channels: 80
|
||||
base_channels: 512
|
||||
nb_harmonics: 8
|
||||
sampling_rate: !ref <sample_rate>
|
||||
nsf_alpha: 0.1
|
||||
nsf_sigma: 0.003
|
||||
nsf_voiced_threshold: 10
|
||||
upsample_rates: [8, 8]
|
||||
upsample_kernel_sizes: [16, 16]
|
||||
istft_params:
|
||||
n_fft: 16
|
||||
hop_len: 4
|
||||
resblock_kernel_sizes: [3, 7, 11]
|
||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
source_resblock_kernel_sizes: [7, 11]
|
||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
||||
lrelu_slope: 0.1
|
||||
audio_limit: 0.99
|
||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
||||
num_class: 1
|
||||
in_channels: 80
|
||||
cond_channels: 512
|
||||
|
||||
# gan related module
|
||||
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
fmin: 0
|
||||
fmax: null
|
||||
center: False
|
||||
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||
generator: !ref <hift>
|
||||
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
|
||||
mel_spec_transform: [
|
||||
!ref <mel_spec_transform1>
|
||||
]
|
||||
|
||||
# processor functions
|
||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
||||
multilingual: True
|
||||
num_languages: 100
|
||||
language: 'en'
|
||||
task: 'transcribe'
|
||||
allowed_special: 'all'
|
||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||
get_tokenizer: !ref <get_tokenizer>
|
||||
allowed_special: !ref <allowed_special>
|
||||
filter: !name:cosyvoice.dataset.processor.filter
|
||||
max_length: 40960
|
||||
min_length: 0
|
||||
token_max_length: 200
|
||||
token_min_length: 1
|
||||
resample: !name:cosyvoice.dataset.processor.resample
|
||||
resample_rate: !ref <sample_rate>
|
||||
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||
truncate_length: 24576 # must be a multiplier of hop_size
|
||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
fmin: 0
|
||||
fmax: 8000
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||
sample_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||
normalize: True
|
||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||
shuffle_size: 1000
|
||||
sort: !name:cosyvoice.dataset.processor.sort
|
||||
sort_size: 500 # sort_size should be less than shuffle_size
|
||||
batch: !name:cosyvoice.dataset.processor.batch
|
||||
batch_type: 'dynamic'
|
||||
max_frames_in_batch: 12000
|
||||
padding: !name:cosyvoice.dataset.processor.padding
|
||||
use_spk_embedding: False # change to True during sft
|
||||
|
||||
# dataset processor pipeline
|
||||
data_pipeline: [
|
||||
!ref <parquet_opener>,
|
||||
!ref <tokenize>,
|
||||
!ref <filter>,
|
||||
!ref <resample>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
data_pipeline_gan: [
|
||||
!ref <parquet_opener>,
|
||||
!ref <tokenize>,
|
||||
!ref <filter>,
|
||||
!ref <resample>,
|
||||
!ref <truncate>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <compute_f0>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
|
||||
# llm flow train conf
|
||||
train_conf:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.002 # change to 0.001 if you want to train flow from scratch
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 2
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
|
||||
# gan train conf
|
||||
train_conf_gan:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0002 # use small lr for gan training
|
||||
scheduler: constantlr
|
||||
optim_d: adam
|
||||
optim_conf_d:
|
||||
lr: 0.0002 # use small lr for gan training
|
||||
scheduler_d: constantlr
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
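The `!new:`/`!name:`/`!ref` tags in configs like the one removed above follow the hyperpyyaml convention used throughout this repo. As a rough sketch (the config path and the keys pulled out below are illustrative, not taken from this diff), such a file is materialized like this:

```python
# Minimal sketch: instantiate the modules declared in a CosyVoice-style yaml
# with hyperpyyaml. The path below is illustrative.
from hyperpyyaml import load_hyperpyyaml

with open("conf/cosyvoice.yaml", "r") as f:
    configs = load_hyperpyyaml(f)  # !new: tags become constructed objects

llm, flow, hift = configs["llm"], configs["flow"], configs["hift"]
print(type(llm).__name__, type(flow).__name__, type(hift).__name__)
```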
@@ -1 +0,0 @@
|
||||
../../../cosyvoice
|
||||
@@ -40,6 +40,10 @@ def main():
|
||||
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
|
||||
for k, v in spk2utt.items():
|
||||
f.write('{} {}\n'.format(k, ' '.join(v)))
|
||||
if args.instruct != '':
|
||||
with open('{}/instruct'.format(args.des_dir), 'w') as f:
|
||||
for k, v in utt2text.items():
|
||||
f.write('{} {}\n'.format(k, args.instruct))
|
||||
return
|
||||
|
||||
|
||||
@@ -49,7 +53,8 @@ if __name__ == "__main__":
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
parser.add_argument('--instruct',
|
||||
type=str,
|
||||
default='')
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
|
||||
@@ -27,7 +27,7 @@ fi
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_embedding.py --dir data/$x \
|
||||
../../../tools/extract_embedding.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||
done
|
||||
fi
|
||||
@@ -35,7 +35,7 @@ fi
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_speech_token.py --dir data/$x \
|
||||
../../../tools/extract_speech_token.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
||||
done
|
||||
fi
|
||||
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x/parquet
|
||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--src_dir data/$x \
|
||||
--des_dir data/$x/parquet
|
||||
@@ -60,7 +60,7 @@ num_workers=2
|
||||
prefetch=100
|
||||
train_engine=torch_ddp
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
||||
echo "Run train. We only support llm traning for now"
|
||||
if [ $train_engine == 'deepspeed' ]; then
|
||||
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||
fi
|
||||
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
for model in llm flow hifigan; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||
cosyvoice/bin/train.py \
|
||||
../../../cosyvoice/bin/train.py \
|
||||
--train_engine $train_engine \
|
||||
--config conf/cosyvoice.yaml \
|
||||
--train_data data/train.data.list \
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../tools
|
||||
@@ -139,7 +139,7 @@ tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||
get_tokenizer: !ref <get_tokenizer>
|
||||
allowed_special: !ref <allowed_special>
|
||||
filter: !name:cosyvoice.dataset.processor.filter
|
||||
max_length: 40960
|
||||
max_length: 6000
|
||||
min_length: 100
|
||||
token_max_length: 200
|
||||
token_min_length: 1
|
||||
@@ -158,7 +158,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
token_mel_ratio: 2
|
||||
num_frames: 960
|
||||
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
|
||||
num_frames: 960
|
||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||
sample_rate: !ref <sample_rate>
|
||||
hop_size: 480
|
||||
@@ -183,6 +185,7 @@ data_pipeline: [
|
||||
!ref <resample>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <compute_whisper_fbank>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../cosyvoice
|
||||
@@ -24,27 +24,12 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_embedding.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_speech_token.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
|
||||
done
|
||||
fi
|
||||
|
||||
# NOTE embedding/token extraction is not necessary now as we support online feature extraction
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x/parquet
|
||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--src_dir data/$x \
|
||||
--des_dir data/$x/parquet
|
||||
@@ -60,22 +45,22 @@ num_workers=2
|
||||
prefetch=100
|
||||
train_engine=torch_ddp
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
||||
echo "Run train. We only support llm traning for now"
|
||||
if [ $train_engine == 'deepspeed' ]; then
|
||||
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||
fi
|
||||
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||
# NOTE will update llm/hift training later
|
||||
for model in llm flow hifigan; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||
cosyvoice/bin/train.py \
|
||||
../../../cosyvoice/bin/train.py \
|
||||
--train_engine $train_engine \
|
||||
--config conf/cosyvoice2.yaml \
|
||||
--train_data data/train.data.list \
|
||||
--cv_data data/dev.data.list \
|
||||
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||
--onnx_path $pretrained_model_dir \
|
||||
--model $model \
|
||||
--checkpoint $pretrained_model_dir/$model.pt \
|
||||
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
|
||||
|
||||
@@ -36,7 +36,7 @@ fi
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_embedding.py --dir data/$x \
|
||||
../../../tools/extract_embedding.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||
done
|
||||
fi
|
||||
@@ -44,7 +44,7 @@ fi
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_speech_token.py --dir data/$x \
|
||||
../../../tools/extract_speech_token.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
|
||||
done
|
||||
fi
|
||||
@@ -53,7 +53,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x/parquet
|
||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--dpo \
|
||||
--src_dir data/$x \
|
||||
@@ -70,7 +70,7 @@ num_workers=2
|
||||
prefetch=100
|
||||
train_engine=torch_ddp
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
||||
echo "Run train. We only support llm traning for now"
|
||||
if [ $train_engine == 'deepspeed' ]; then
|
||||
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||
fi
|
||||
@@ -80,11 +80,12 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
for model in llm; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||
cosyvoice/bin/train.py \
|
||||
../../../cosyvoice/bin/train.py \
|
||||
--train_engine $train_engine \
|
||||
--config conf/cosyvoice2.yaml \
|
||||
--train_data data/train.data.list \
|
||||
--cv_data data/dev.data.list \
|
||||
--onnx_path $pretrained_model_dir \
|
||||
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||
--model $model \
|
||||
--checkpoint $pretrained_model_dir/$model.pt \
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../tools
|
||||
@@ -5,22 +5,28 @@ __set_seed3: !apply:torch.manual_seed [1986]
|
||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||
|
||||
# fixed params
|
||||
sample_rate: 24000 # 16000 for llm, 24000 for cfm
|
||||
sample_rate: 24000
|
||||
llm_input_size: 896
|
||||
llm_output_size: 896
|
||||
spk_embed_dim: 192
|
||||
qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN'
|
||||
qwen_pretrain_path: ''
|
||||
token_frame_rate: 25
|
||||
token_mel_ratio: 2
|
||||
|
||||
# stream related params
|
||||
chunk_size: 25 # streaming inference chunk size, in token
|
||||
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
|
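A quick arithmetic check on these streaming settings (a worked example using values that appear in this config, not something stated in the file itself): 25 tokens per chunk at 25 tokens/s is one second of speech tokens, which expands to 50 mel frames and, at a 480-sample hop and 24 kHz, one second of audio per chunk.

```python
# Worked check of the chunk-size settings above (values copied from this config).
token_frame_rate = 25   # speech tokens per second
chunk_size = 25         # tokens per streaming chunk
token_mel_ratio = 2     # mel frames per speech token
hop_size, sample_rate = 480, 24000

chunk_token_seconds = chunk_size / token_frame_rate                   # 1.0 s of tokens
mel_frames_per_chunk = chunk_size * token_mel_ratio                   # 50 mel frames
chunk_audio_seconds = mel_frames_per_chunk * hop_size / sample_rate   # 1.0 s of audio
print(chunk_token_seconds, mel_frames_per_chunk, chunk_audio_seconds)
```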
||||
|
||||
# model params
|
||||
# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users can find every corresponding class/function from this single yaml.
|
||||
# for system/third_party class/function, we do not require this.
|
||||
llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
|
||||
llm: !new:cosyvoice.llm.llm.CosyVoice3LM
|
||||
llm_input_size: !ref <llm_input_size>
|
||||
llm_output_size: !ref <llm_output_size>
|
||||
speech_token_size: 6561
|
||||
length_normalized_loss: True
|
||||
lsm_weight: 0
|
||||
dpo: True
|
||||
mix_ratio: [5, 15]
|
||||
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
|
||||
pretrain_path: !ref <qwen_pretrain_path>
|
||||
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||
@@ -28,31 +34,21 @@ llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
|
||||
top_k: 25
|
||||
win_size: 10
|
||||
tau_r: 0.1
|
||||
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
|
||||
input_size: 512
|
||||
|
||||
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithDiT
|
||||
input_size: 80
|
||||
output_size: 80
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
output_type: 'mel'
|
||||
vocab_size: 6561
|
||||
input_frame_rate: 25
|
||||
input_frame_rate: !ref <token_frame_rate>
|
||||
only_mask_loss: True
|
||||
token_mel_ratio: 2
|
||||
token_mel_ratio: !ref <token_mel_ratio>
|
||||
pre_lookahead_len: 3
|
||||
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
|
||||
output_size: 512
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
input_size: 512
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
pre_lookahead_layer: !new:cosyvoice.transformer.upsample_encoder.PreLookaheadLayer
|
||||
in_channels: 80
|
||||
channels: 1024
|
||||
pre_lookahead_len: 3
|
||||
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
|
||||
in_channels: 240
|
||||
n_spks: 1
|
||||
@@ -65,19 +61,20 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
|
||||
training_cfg_rate: 0.2
|
||||
inference_cfg_rate: 0.7
|
||||
reg_loss_type: 'l1'
|
||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
||||
in_channels: 320
|
||||
estimator: !new:cosyvoice.flow.DiT.dit.DiT
|
||||
dim: 1024
|
||||
depth: 22
|
||||
heads: 16
|
||||
dim_head: 64
|
||||
ff_mult: 2
|
||||
mel_dim: 80
|
||||
mu_dim: 80
|
||||
spk_dim: 80
|
||||
out_channels: 80
|
||||
causal: True
|
||||
channels: [256]
|
||||
dropout: 0.0
|
||||
attention_head_dim: 64
|
||||
n_blocks: 4
|
||||
num_mid_blocks: 12
|
||||
num_heads: 8
|
||||
act_fn: 'gelu'
|
||||
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
|
||||
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
|
||||
|
||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
hift: !new:cosyvoice.hifigan.generator.CausalHiFTGenerator
|
||||
in_channels: 80
|
||||
base_channels: 512
|
||||
nb_harmonics: 8
|
||||
@@ -96,18 +93,19 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
lrelu_slope: 0.1
|
||||
audio_limit: 0.99
|
||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
||||
conv_pre_look_right: 4
|
||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.CausalConvRNNF0Predictor
|
||||
num_class: 1
|
||||
in_channels: 80
|
||||
cond_channels: 512
|
||||
|
||||
# gan related module
|
||||
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
n_fft: 1920
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
hop_size: 480
|
||||
win_size: 1920
|
||||
fmin: 0
|
||||
fmax: null
|
||||
center: False
|
||||
@@ -115,45 +113,47 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||
generator: !ref <hift>
|
||||
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
|
||||
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
|
||||
mel_spec_transform: [
|
||||
!ref <mel_spec_transform1>
|
||||
]
|
||||
|
||||
# processor functions
|
||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
||||
multilingual: True
|
||||
num_languages: 100
|
||||
language: 'en'
|
||||
task: 'transcribe'
|
||||
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
|
||||
token_path: !ref <qwen_pretrain_path>
|
||||
skip_special_tokens: True
|
||||
version: cosyvoice3
|
||||
allowed_special: 'all'
|
||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||
get_tokenizer: !ref <get_tokenizer>
|
||||
allowed_special: !ref <allowed_special>
|
||||
filter: !name:cosyvoice.dataset.processor.filter
|
||||
max_length: 40960
|
||||
min_length: 0
|
||||
max_length: 6000
|
||||
min_length: 100
|
||||
token_max_length: 200
|
||||
token_min_length: 1
|
||||
resample: !name:cosyvoice.dataset.processor.resample
|
||||
resample_rate: !ref <sample_rate>
|
||||
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||
truncate_length: 24576 # must be a multiplier of hop_size
|
||||
truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
|
||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
n_fft: 1920
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
hop_size: 480
|
||||
win_size: 1920
|
||||
fmin: 0
|
||||
fmax: 8000
|
||||
fmax: null
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
num_frames: 960
|
||||
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
|
||||
num_frames: 960
|
||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||
sample_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
hop_size: 480
|
||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||
normalize: True
|
||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||
@@ -162,10 +162,10 @@ sort: !name:cosyvoice.dataset.processor.sort
|
||||
sort_size: 500 # sort_size should be less than shuffle_size
|
||||
batch: !name:cosyvoice.dataset.processor.batch
|
||||
batch_type: 'dynamic'
|
||||
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
|
||||
max_frames_in_batch: 2000
|
||||
padding: !name:cosyvoice.dataset.processor.padding
|
||||
use_spk_embedding: True # change to True during sft
|
||||
dpo: True
|
||||
use_spk_embedding: False # change to True during sft
|
||||
|
||||
|
||||
# dataset processor pipeline
|
||||
data_pipeline: [
|
||||
@@ -175,6 +175,7 @@ data_pipeline: [
|
||||
!ref <resample>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <compute_whisper_fbank>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
@@ -199,10 +200,10 @@ data_pipeline_gan: [
|
||||
train_conf:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.00001 # change to 1e-5 during sft
|
||||
scheduler: warmuplr # change to constantlr during sft
|
||||
lr: 1e-5 # change to 1e-5 during sft
|
||||
scheduler: constantlr # change to constantlr during sft
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
warmup_steps: 2500
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 2
|
||||
@@ -223,4 +224,4 @@ train_conf_gan:
|
||||
grad_clip: 5
|
||||
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
save_per_step: -1
|
||||
42
examples/libritts/cosyvoice3/conf/ds_stage2.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 100,
|
||||
"gradient_clipping": 5,
|
||||
"fp16": {
|
||||
"enabled": false,
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 16,
|
||||
"loss_scale_window": 256,
|
||||
"hysteresis": 2,
|
||||
"consecutive_hysteresis": false,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": false
|
||||
},
|
||||
"zero_force_ds_cpu_optimizer": false,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "none",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": false,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients" : true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 0.001,
|
||||
"weight_decay": 0.0001,
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
}
|
||||
}
|
||||
1
examples/libritts/cosyvoice3/local
Symbolic link
@@ -0,0 +1 @@
|
||||
../cosyvoice/local
|
||||
1
examples/libritts/cosyvoice3/path.sh
Symbolic link
@@ -0,0 +1 @@
|
||||
../cosyvoice/path.sh
|
||||
97
examples/libritts/cosyvoice3/run.sh
Normal file
@@ -0,0 +1,97 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
stage=-1
|
||||
stop_stage=3
|
||||
|
||||
data_url=www.openslr.org/resources/60
|
||||
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
|
||||
pretrained_model_dir=../../../pretrained_models/Fun-CosyVoice3-0.5B
|
||||
|
||||
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
echo "Data Download"
|
||||
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
||||
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x
|
||||
# NOTE in CosyVoice3, we add instruct in sequence
|
||||
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct "You are a helpful assistant.<|endofprompt|>"
|
||||
done
|
||||
fi
|
||||
|
||||
# NOTE embedding/token extraction is not necessary now as we support online feature extraction
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x/parquet
|
||||
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--src_dir data/$x \
|
||||
--des_dir data/$x/parquet
|
||||
done
|
||||
fi
|
||||
|
||||
# train llm
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
job_id=1986
|
||||
dist_backend="nccl"
|
||||
num_workers=2
|
||||
prefetch=100
|
||||
train_engine=torch_ddp
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Run train. We only support llm traning for now"
|
||||
if [ $train_engine == 'deepspeed' ]; then
|
||||
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||
fi
|
||||
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||
for model in llm flow hifigan; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||
../../../cosyvoice/bin/train.py \
|
||||
--train_engine $train_engine \
|
||||
--config conf/cosyvoice3.yaml \
|
||||
--train_data data/train.data.list \
|
||||
--cv_data data/dev.data.list \
|
||||
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||
--onnx_path $pretrained_model_dir \
|
||||
--model $model \
|
||||
--checkpoint $pretrained_model_dir/$model.pt \
|
||||
--model_dir `pwd`/exp/cosyvoice3/$model/$train_engine \
|
||||
--tensorboard_dir `pwd`/tensorboard/cosyvoice3/$model/$train_engine \
|
||||
--ddp.dist_backend $dist_backend \
|
||||
--num_workers ${num_workers} \
|
||||
--prefetch ${prefetch} \
|
||||
--pin_memory \
|
||||
--use_amp \
|
||||
--deepspeed_config ./conf/ds_stage2.json \
|
||||
--deepspeed.save_states model+optimizer
|
||||
done
|
||||
fi
|
||||
|
||||
# average model
|
||||
average_num=5
|
||||
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||
for model in llm flow hifigan; do
|
||||
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||
python cosyvoice/bin/average_model.py \
|
||||
--dst_model $decode_checkpoint \
|
||||
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||
--num ${average_num} \
|
||||
--val_best
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||
fi
|
||||
@@ -1 +0,0 @@
|
||||
../../../cosyvoice
|
||||
@@ -27,7 +27,7 @@ fi
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||
for x in dev test train; do
|
||||
tools/extract_embedding.py --dir data/$x \
|
||||
../../../tools/extract_embedding.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||
done
|
||||
fi
|
||||
@@ -35,7 +35,7 @@ fi
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||
for x in dev test train; do
|
||||
tools/extract_speech_token.py --dir data/$x \
|
||||
../../../tools/extract_speech_token.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
||||
done
|
||||
fi
|
||||
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||
for x in dev test train; do
|
||||
mkdir -p data/$x/parquet
|
||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--src_dir data/$x \
|
||||
--des_dir data/$x/parquet
|
||||
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
for model in llm flow hifigan; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
||||
cosyvoice/bin/train.py \
|
||||
../../../cosyvoice/bin/train.py \
|
||||
--train_engine $train_engine \
|
||||
--config conf/cosyvoice.yaml \
|
||||
--train_data data/train.data.list \
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../tools
|
||||
@@ -17,6 +17,7 @@ lightning==2.2.4
|
||||
matplotlib==3.7.5
|
||||
modelscope==1.20.0
|
||||
networkx==3.1
|
||||
numpy==1.26.4
|
||||
omegaconf==2.3.0
|
||||
onnx==1.16.0
|
||||
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
|
||||
@@ -29,12 +30,13 @@ pyworld==0.3.4
|
||||
rich==13.7.1
|
||||
soundfile==0.12.1
|
||||
tensorboard==2.14.0
|
||||
tensorrt-cu12==10.0.1; sys_platform == 'linux'
|
||||
tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
|
||||
tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
|
||||
tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
|
||||
tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
|
||||
tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
|
||||
torch==2.3.1
|
||||
torchaudio==2.3.1
|
||||
transformers==4.40.1
|
||||
transformers==4.51.3
|
||||
x-transformers==2.11.24
|
||||
uvicorn==0.30.0
|
||||
wetext==0.0.4
|
||||
wget==3.2
|
||||
|
||||
@@ -9,5 +9,5 @@ RUN apt-get -y install git unzip git-lfs g++
|
||||
RUN git lfs install
|
||||
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
# here we use python==3.10 because we cannot find an image that has both python3.8 and torch2.0.1-cu118 installed
|
||||
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
|
||||
@@ -24,7 +24,7 @@ import numpy as np
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
|
||||
app = FastAPI()
|
||||
@@ -88,14 +88,8 @@ if __name__ == '__main__':
|
||||
default=50000)
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
default='iic/CosyVoice-300M',
|
||||
default='iic/CosyVoice2-0.5B',
|
||||
help='local path or modelscope repo id')
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
cosyvoice = CosyVoice(args.model_dir)
|
||||
except Exception:
|
||||
try:
|
||||
cosyvoice = CosyVoice2(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
cosyvoice = AutoModel(model_dir=args.model_dir)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
|
||||
@@ -25,7 +25,7 @@ import numpy as np
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
@@ -33,13 +33,7 @@ logging.basicConfig(level=logging.DEBUG,
|
||||
|
||||
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
||||
def __init__(self, args):
|
||||
try:
|
||||
self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc)
|
||||
except Exception:
|
||||
try:
|
||||
self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
self.cosyvoice = AutoModel(model_dir=args.model_dir)
|
||||
logging.info('grpc service initialized')
|
||||
|
||||
def Inference(self, request, context):
|
||||
@@ -90,7 +84,7 @@ if __name__ == '__main__':
|
||||
default=4)
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
default='iic/CosyVoice-300M',
|
||||
default='iic/CosyVoice2-0.5B',
|
||||
help='local path or modelscope repo id')
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
|
||||
8
runtime/triton_trtllm/Dockerfile.server
Normal file
@@ -0,0 +1,8 @@
|
||||
FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
|
||||
LABEL maintainer="zhangyuekai@foxmail.com"
|
||||
|
||||
RUN apt-get update && apt-get install -y cmake
|
||||
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
|
||||
COPY ./requirements.txt /workspace/requirements.txt
|
||||
RUN pip install -r /workspace/requirements.txt
|
||||
WORKDIR /workspace
|
||||
141
runtime/triton_trtllm/README.DIT.md
Normal file
@@ -0,0 +1,141 @@
|
||||
## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
|
||||
|
||||
Contributed by Yuekai Zhang (NVIDIA).
|
||||
|
||||
This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
|
||||
|
||||
### Quick Start
|
||||
|
||||
Launch the service directly with Docker Compose:
|
||||
```sh
|
||||
docker compose -f docker-compose.dit.yml up
|
||||
```
|
||||
|
||||
### Build the Docker Image
|
||||
|
||||
To build the image from scratch:
|
||||
```sh
|
||||
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
||||
```
|
||||
|
||||
### Run a Docker Container
|
||||
```sh
|
||||
your_mount_dir=/mnt:/mnt
|
||||
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
|
||||
```
|
||||
|
||||
### Understanding `run_stepaudio2_dit_token2wav.sh`
|
||||
|
||||
The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
|
||||
|
||||
You can run a subset of stages with:
|
||||
```sh
|
||||
bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
|
||||
```
|
||||
- `<start_stage>`: The stage to start from.
|
||||
- `<stop_stage>`: The stage to stop after.
|
||||
|
||||
**Stages:**
|
||||
|
||||
- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
|
||||
- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
|
||||
- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
|
||||
- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
|
||||
- **Stage 3**: Launches the Triton Inference Server for the Token2Wav module and uses `trtllm-serve` to deploy the CosyVoice2 LLM.
|
||||
- **Stage 4**: Runs the gRPC benchmark client for performance testing.
|
||||
- **Stage 5**: Runs the offline TTS inference benchmark test.
|
||||
- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
|
||||
- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
|
||||
- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
|
||||
### Export Models and Launch Server
|
||||
|
||||
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
||||
```sh
|
||||
# This command runs stages 0, 1, 2, and 3
|
||||
bash run_stepaudio2_dit_token2wav.sh 0 3
|
||||
```
|
||||
|
||||
### Benchmark with client-server mode
|
||||
|
||||
To benchmark the running Triton server, run stage 4:
|
||||
```sh
|
||||
bash run_stepaudio2_dit_token2wav.sh 4 4
|
||||
|
||||
# You can customize parameters such as the number of tasks inside the script.
|
||||
```
|
||||
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
|
||||
|
||||
#### Total Request Latency
|
||||
|
||||
| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
|
||||
| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
|
||||
| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 |
|
||||
| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 |
|
||||
| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 |
|
||||
| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 |
|
||||
| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 |
|
||||
| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 |
|
||||
|
||||
#### First Chunk Latency
|
||||
|
||||
| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
|
||||
| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
|
||||
| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 |
|
||||
| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 |
|
||||
| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 |
|
||||
| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 |
|
||||
| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 |
|
||||
| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 |
|
||||
|
||||
### Benchmark with offline inference mode
|
||||
For the offline inference benchmark, run stage 5:
|
||||
```sh
|
||||
bash run_stepaudio2_dit_token2wav.sh 5 5
|
||||
```
|
||||
|
||||
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
|
||||
|
||||
#### Offline TTS (CosyVoice2 0.5B LLM + Step-Audio2 DiT Token2Wav)
|
||||
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|
||||
|---------|------------|------------------|-----------------------|--|
|
||||
| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 |
|
||||
|
||||
|
||||
### Disaggregated Server
|
||||
When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
|
||||
|
||||
The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
|
||||
|
||||
| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
|
||||
|---|---|---|---|---|---|---|
|
||||
| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
|
||||
| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
|
||||
| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
|
||||
| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
|
||||
| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
|
||||
| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
|
||||
| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
|
||||
| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
|
||||
| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
|
||||
| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
|
||||
| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
|
||||
| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
|
||||
| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
|
||||
| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
|
||||
| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
|
||||
| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
|
||||
| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
|
||||
| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
|
||||
| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
|
||||
| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
|
||||
| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
|
||||
| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
|
||||
| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
|
||||
| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
|
||||
|
||||
The `concurrent_tasks_per_gpu` column is calculated as:

`concurrent_tasks_per_gpu = concurrent_task_per_instance * num_token2wav_instances_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
|
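For instance, plugging a few of the table's rows into this formula (a quick consistency check using the two-instances-per-GPU setup described above):

```python
# Reproduce the concurrent_tasks_per_gpu column from the formula above.
def concurrent_tasks_per_gpu(tasks_per_instance, token2wav_gpus,
                             instances_per_gpu=2, llm_gpus=1):
    return tasks_per_instance * instances_per_gpu * token2wav_gpus / (token2wav_gpus + llm_gpus)

print(concurrent_tasks_per_gpu(1, 2))   # ~1.33 (row: 2 GPUs, 1 task per instance)
print(concurrent_tasks_per_gpu(3, 4))   # 4.80  (row: 4 GPUs, 3 tasks per instance)
print(concurrent_tasks_per_gpu(6, 3))   # 9.00  (row: 3 GPUs, 6 tasks per instance)
```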
||||
|
||||
### Acknowledgements
|
||||
|
||||
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
|
||||
146
runtime/triton_trtllm/README.md
Normal file
@@ -0,0 +1,146 @@
|
||||
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
|
||||
|
||||
Contributed by Yuekai Zhang (NVIDIA).
|
||||
|
||||
### Quick Start
|
||||
|
||||
Launch the service directly with Docker Compose:
|
||||
```sh
|
||||
docker compose up
|
||||
```
|
||||
|
||||
### Build the Docker Image
|
||||
|
||||
To build the image from scratch:
|
||||
```sh
|
||||
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
||||
```
|
||||
|
||||
### Run a Docker Container
|
||||
```sh
|
||||
your_mount_dir=/mnt:/mnt
|
||||
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
|
||||
```
|
||||
|
||||
### Understanding `run.sh`
|
||||
|
||||
The `run.sh` script orchestrates the entire workflow through numbered stages.
|
||||
|
||||
You can run a subset of stages with:
|
||||
```sh
|
||||
bash run.sh <start_stage> <stop_stage> [service_type]
|
||||
```
|
||||
- `<start_stage>`: The stage to start from (0-6).
|
||||
- `<stop_stage>`: The stage to stop after (0-6).
|
||||
|
||||
**Stages:**
|
||||
|
||||
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
|
||||
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
|
||||
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
|
||||
- **Stage 3**: Launches the Triton Inference Server.
|
||||
- **Stage 4**: Runs the single-utterance HTTP client for testing.
|
||||
- **Stage 5**: Runs the gRPC benchmark client.
|
||||
- **Stage 6**: Runs the offline inference benchmark test.
|
||||
|
||||
### Export Models and Launch Server
|
||||
|
||||
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
||||
```sh
|
||||
# This command runs stages 0, 1, 2, and 3
|
||||
bash run.sh 0 3
|
||||
```
|
||||
> [!TIP]
|
||||
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
|
||||
|
||||
### Single-Utterance HTTP Client
|
||||
|
||||
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
|
||||
```sh
|
||||
bash run.sh 4 4
|
||||
```
|
||||
|
||||
### Benchmark with client-server mode
|
||||
|
||||
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
|
||||
```sh
|
||||
bash run.sh 5 5 # [streaming|offline]
|
||||
|
||||
# You can also customize parameters such as the number of tasks and the dataset split:
|
||||
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
|
||||
```
|
||||
> [!TIP]
|
||||
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
|
||||
|
||||
### Benchmark with offline inference mode
|
||||
For the offline inference benchmark, use the commands below:
|
||||
```sh
|
||||
# install FlashCosyVoice for token2wav batching
|
||||
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
|
||||
# cd /workspace/FlashCosyVoice
|
||||
# pip install -e .
|
||||
# cd -
|
||||
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
|
||||
|
||||
bash run.sh 6 6
|
||||
|
||||
# You can also switch to huggingface backend by setting backend=hf
|
||||
```
|
||||
|
||||
|
||||
### Benchmark Results
|
||||
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
|
||||
|
||||
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
|
||||
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
||||
|---|---|---|---|---|
|
||||
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
|
||||
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
|
||||
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
|
||||
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
|
||||
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
|
||||
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
|
||||
|
||||
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
|
||||
|
||||
**Client-Server Mode: Offline TTS (Full Sentence Latency)**
|
||||
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
||||
|---|---|---|---|---|---|
|
||||
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
|
||||
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
|
||||
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
|
||||
|
||||
**Offline Inference Mode: Hugging Face LLM vs. TensorRT-LLM**
|
||||
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|
||||
|---------|------------|------------------|-----------------------|--|
|
||||
| HF | 1 | 39.26 | 44.31 | 0.2494 |
|
||||
| HF | 2 | 30.54 | 35.62 | 0.2064 |
|
||||
| HF | 4 | 18.63 | 23.90 | 0.1421 |
|
||||
| HF | 8 | 11.22 | 16.45 | 0.0947 |
|
||||
| HF | 16 | 8.42 | 13.78 | 0.0821 |
|
||||
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
|
||||
| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
|
||||
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
|
||||
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
|
||||
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
|
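RTF here is wall-clock synthesis time divided by the duration of the generated audio; as a rough check against the table, using the roughly 170 seconds of total audio mentioned above for the 26-pair test set:

```python
# Rough RTF check for the TRTLLM batch-size-16 row above.
total_time_seconds = 6.63   # from the table
audio_seconds = 170.0       # approximate duration of the test set (stated above)
print(round(total_time_seconds / audio_seconds, 4))  # ~0.039, in line with the table's 0.0386
```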
||||
### OpenAI-Compatible Server
|
||||
|
||||
To launch an OpenAI-compatible API service, run the following commands:
|
||||
```sh
|
||||
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
|
||||
cd Triton-OpenAI-Speech
|
||||
pip install -r requirements.txt
|
||||
|
||||
# After the Triton service is running, start the FastAPI bridge:
|
||||
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
|
||||
|
||||
# Test the service with curl:
|
||||
bash test/test_cosyvoice.sh
|
||||
```
|
||||
> [!NOTE]
|
||||
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
|
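As a rough illustration only: assuming the bridge exposes the standard OpenAI `/v1/audio/speech` route on the port used above, a request could look like the sketch below. The endpoint path, model name, and voice name are assumptions; `test/test_cosyvoice.sh` in the Triton-OpenAI-Speech repo is the authoritative example.

```python
# Hypothetical request against the OpenAI-compatible bridge started above.
# Route, model, and voice names are assumptions, not taken from this repo.
import requests

resp = requests.post(
    "http://localhost:10086/v1/audio/speech",
    json={
        "model": "cosyvoice",              # placeholder model name
        "voice": "default",                # placeholder reference-voice id
        "input": "Hello from CosyVoice.",
        "response_format": "wav",
    },
    timeout=120,
)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)
```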
||||
|
||||
### Acknowledgements
|
||||
|
||||
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
|
||||
|
||||
922
runtime/triton_trtllm/client_grpc.py
Normal file
@@ -0,0 +1,922 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2023 Nvidia (authors: Yuekai Zhang)
|
||||
# 2023 Recurrent.ai (authors: Songtao Shi)
|
||||
# See LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a dataset from Hugging Face and sends it to the server
|
||||
for decoding, in parallel.
|
||||
|
||||
Usage:
|
||||
num_task=2
|
||||
|
||||
# For offline F5-TTS
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name f5_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name test_zh \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
|
||||
# For offline Spark-TTS-0.5B
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name spark_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name wenetspeech4tts \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
import functools
|
||||
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import tritonclient
|
||||
import tritonclient.grpc.aio as grpcclient_aio
|
||||
import tritonclient.grpc as grpcclient_sync
|
||||
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
|
||||
|
||||
|
||||
class UserData:
|
||||
def __init__(self):
|
||||
self._completed_requests = queue.Queue()
|
||||
self._first_chunk_time = None
|
||||
self._second_chunk_time = None
|
||||
self._start_time = None
|
||||
|
||||
def record_start_time(self):
|
||||
self._start_time = time.time()
|
||||
|
||||
def get_first_chunk_latency(self):
|
||||
if self._first_chunk_time and self._start_time:
|
||||
return self._first_chunk_time - self._start_time
|
||||
return None
|
||||
|
||||
def get_second_chunk_latency(self):
|
||||
if self._first_chunk_time and self._second_chunk_time:
|
||||
return self._second_chunk_time - self._first_chunk_time
|
||||
return None
|
||||
|
||||
|
||||
def callback(user_data, result, error):
|
||||
if not error:
|
||||
if user_data._first_chunk_time is None:
|
||||
user_data._first_chunk_time = time.time()
|
||||
elif user_data._second_chunk_time is None:
|
||||
user_data._second_chunk_time = time.time()
|
||||
|
||||
if error:
|
||||
user_data._completed_requests.put(error)
|
||||
else:
|
||||
user_data._completed_requests.put(result)
|
||||
|
||||
|
||||
def stream_callback(user_data_map, result, error):
|
||||
request_id = None
|
||||
if error:
|
||||
print(f"An error occurred in the stream callback: {error}")
|
||||
else:
|
||||
request_id = result.get_response().id
|
||||
|
||||
if request_id:
|
||||
user_data = user_data_map.get(request_id)
|
||||
if user_data:
|
||||
callback(user_data, result, error)
|
||||
else:
|
||||
print(f"Warning: Could not find user_data for request_id {request_id}")
|
||||
|
||||
|
||||
def write_triton_stats(stats, summary_file):
|
||||
with open(summary_file, "w") as summary_f:
|
||||
model_stats = stats["model_stats"]
|
||||
for model_state in model_stats:
|
||||
if "last_inference" not in model_state:
|
||||
continue
|
||||
summary_f.write(f"model name is {model_state['name']} \n")
|
||||
model_inference_stats = model_state["inference_stats"]
|
||||
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
||||
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
||||
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
||||
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
||||
summary_f.write(
|
||||
f"queue time {total_queue_time_s:<5.2f} s, "
|
||||
f"compute infer time {total_infer_time_s:<5.2f} s, "
|
||||
f"compute input time {total_input_time_s:<5.2f} s, "
|
||||
f"compute output time {total_output_time_s:<5.2f} s \n"
|
||||
)
|
||||
model_batch_stats = model_state["batch_stats"]
|
||||
for batch in model_batch_stats:
|
||||
batch_size = int(batch["batch_size"])
|
||||
compute_input = batch["compute_input"]
|
||||
compute_output = batch["compute_output"]
|
||||
compute_infer = batch["compute_infer"]
|
||||
batch_count = int(compute_infer["count"])
|
||||
if batch_count == 0:
|
||||
continue
|
||||
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
||||
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
||||
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
||||
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
||||
summary_f.write(
|
||||
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
|
||||
f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
|
||||
f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
|
||||
f"{compute_infer_time_ms / batch_count:.2f} ms, "
|
||||
f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
|
||||
f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
|
||||
)
|
||||
summary_f.write(
|
||||
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
|
||||
)
|
||||
summary_f.write(
|
||||
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
|
||||
)
|
||||
|
||||
|
||||
def subtract_stats(stats_after, stats_before):
|
||||
"""Subtracts two Triton inference statistics objects."""
|
||||
stats_diff = json.loads(json.dumps(stats_after))
|
||||
|
||||
model_stats_before_map = {
|
||||
s["name"]: {
|
||||
"version": s["version"],
|
||||
"last_inference": s.get("last_inference", 0),
|
||||
"inference_count": s.get("inference_count", 0),
|
||||
"execution_count": s.get("execution_count", 0),
|
||||
"inference_stats": s.get("inference_stats", {}),
|
||||
"batch_stats": s.get("batch_stats", []),
|
||||
}
|
||||
for s in stats_before["model_stats"]
|
||||
}
|
||||
|
||||
for model_stat_after in stats_diff["model_stats"]:
|
||||
model_name = model_stat_after["name"]
|
||||
if model_name in model_stats_before_map:
|
||||
model_stat_before = model_stats_before_map[model_name]
|
||||
|
||||
model_stat_after["inference_count"] = str(
|
||||
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
|
||||
)
|
||||
model_stat_after["execution_count"] = str(
|
||||
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
|
||||
)
|
||||
|
||||
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
|
||||
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
|
||||
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
|
||||
if "ns" in model_stat_after["inference_stats"][key]:
|
||||
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
|
||||
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
|
||||
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
|
||||
if "count" in model_stat_after["inference_stats"][key]:
|
||||
count_after = int(model_stat_after["inference_stats"][key]["count"])
|
||||
count_before = int(model_stat_before["inference_stats"][key]["count"])
|
||||
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
|
||||
|
||||
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
|
||||
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
|
||||
for batch_stat_after in model_stat_after["batch_stats"]:
|
||||
bs = batch_stat_after["batch_size"]
|
||||
if bs in batch_stats_before_map:
|
||||
batch_stat_before = batch_stats_before_map[bs]
|
||||
for key in ["compute_input", "compute_infer", "compute_output"]:
|
||||
if key in batch_stat_after and key in batch_stat_before:
|
||||
count_after = int(batch_stat_after[key]["count"])
|
||||
count_before = int(batch_stat_before[key]["count"])
|
||||
batch_stat_after[key]["count"] = str(count_after - count_before)
|
||||
|
||||
ns_after = int(batch_stat_after[key]["ns"])
|
||||
ns_before = int(batch_stat_before[key]["ns"])
|
||||
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
|
||||
return stats_diff
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-addr",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Address of the server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=8001,
|
||||
help="Grpc port of the triton server, default is 8001",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-audio",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-text",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-text",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--huggingface-dataset",
|
||||
type=str,
|
||||
default="yuekai/seed_tts",
|
||||
help="dataset name in huggingface dataset hub",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="dataset split name, default is 'test'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=[
|
||||
"f5_tts",
|
||||
"spark_tts",
|
||||
"cosyvoice2",
|
||||
"cosyvoice2_dit"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-tasks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of concurrent tasks for sending",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Controls how frequently we print the log.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-wer",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="""True to compute WER.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=False,
|
||||
default="./tmp",
|
||||
help="log directory",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="offline",
|
||||
choices=["offline", "streaming"],
|
||||
help="Select offline or streaming benchmark mode."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-overlap-duration",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-spk2info-cache",
|
||||
type=str,
|
||||
default="False",
|
||||
help="Use spk2info cache for reference audio.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
else:
|
||||
waveform, sample_rate = sf.read(wav_path)
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||
waveform = resample(waveform, num_samples)
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
def prepare_request_input_output(
|
||||
protocol_client,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
padding_duration: int = None,
|
||||
use_spk2info_cache: bool = False
|
||||
):
|
||||
"""Prepares inputs for Triton inference (offline or streaming)."""
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
|
||||
if padding_duration:
|
||||
duration = len(waveform) / sample_rate
|
||||
if reference_text:
|
||||
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
||||
else:
|
||||
estimated_target_duration = duration
|
||||
|
||||
required_total_samples = padding_duration * sample_rate * (
|
||||
(int(estimated_target_duration + duration) // padding_duration) + 1
|
||||
)
|
||||
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
||||
samples[0, : len(waveform)] = waveform
|
||||
else:
|
||||
samples = waveform.reshape(1, -1).astype(np.float32)
|
||||
|
||||
inputs = [
|
||||
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
||||
protocol_client.InferInput(
|
||||
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
||||
),
|
||||
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
||||
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
||||
]
|
||||
inputs[0].set_data_from_numpy(samples)
|
||||
inputs[1].set_data_from_numpy(lengths)
|
||||
|
||||
input_data_numpy = np.array([reference_text], dtype=object)
|
||||
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||
inputs[2].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
input_data_numpy = np.array([target_text], dtype=object)
|
||||
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||
if use_spk2info_cache:
|
||||
inputs = inputs[-1:]
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def run_sync_streaming_inference(
|
||||
sync_triton_client: tritonclient.grpc.InferenceServerClient,
|
||||
model_name: str,
|
||||
inputs: list,
|
||||
outputs: list,
|
||||
request_id: str,
|
||||
user_data: UserData,
|
||||
chunk_overlap_duration: float,
|
||||
save_sample_rate: int,
|
||||
audio_save_path: str,
|
||||
):
|
||||
"""Helper function to run the blocking sync streaming call."""
|
||||
start_time_total = time.time()
|
||||
user_data.record_start_time()
|
||||
|
||||
sync_triton_client.async_stream_infer(
|
||||
model_name,
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
outputs=outputs,
|
||||
enable_empty_final_response=True,
|
||||
)
|
||||
|
||||
audios = []
|
||||
while True:
|
||||
try:
|
||||
result = user_data._completed_requests.get(timeout=200)
|
||||
if isinstance(result, InferenceServerException):
|
||||
print(f"Received InferenceServerException: {result}")
|
||||
return None, None, None, None
|
||||
response = result.get_response()
|
||||
final = response.parameters["triton_final_response"].bool_param
|
||||
if final is True:
|
||||
break
|
||||
|
||||
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
||||
if audio_chunk.size > 0:
|
||||
audios.append(audio_chunk)
|
||||
else:
|
||||
print("Warning: received empty audio chunk.")
|
||||
|
||||
except queue.Empty:
|
||||
print(f"Timeout waiting for response for request id {request_id}")
|
||||
return None, None, None, None
|
||||
|
||||
end_time_total = time.time()
|
||||
total_request_latency = end_time_total - start_time_total
|
||||
first_chunk_latency = user_data.get_first_chunk_latency()
|
||||
second_chunk_latency = user_data.get_second_chunk_latency()
|
||||
|
||||
if audios:
|
||||
if model_name == "spark_tts":
|
||||
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
||||
fade_out = np.linspace(1, 0, cross_fade_samples)
|
||||
fade_in = np.linspace(0, 1, cross_fade_samples)
|
||||
reconstructed_audio = None
|
||||
|
||||
if not audios:
|
||||
print("Warning: No audio chunks received.")
|
||||
reconstructed_audio = np.array([], dtype=np.float32)
|
||||
elif len(audios) == 1:
|
||||
reconstructed_audio = audios[0]
|
||||
else:
|
||||
reconstructed_audio = audios[0][:-cross_fade_samples]
|
||||
for i in range(1, len(audios)):
|
||||
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
||||
audios[i - 1][-cross_fade_samples:] * fade_out)
|
||||
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
||||
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
||||
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
||||
|
||||
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
||||
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||
else:
|
||||
print("Warning: No audio chunks received or reconstructed.")
|
||||
actual_duration = 0
|
||||
else:
|
||||
reconstructed_audio = np.concatenate(audios)
|
||||
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||
|
||||
else:
|
||||
print("Warning: No audio chunks received.")
|
||||
actual_duration = 0
|
||||
|
||||
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
|
||||
|
||||
|
||||
async def send_streaming(
|
||||
manifest_item_list: list,
|
||||
name: str,
|
||||
server_url: str,
|
||||
protocol_client: types.ModuleType,
|
||||
log_interval: int,
|
||||
model_name: str,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
chunk_overlap_duration: float = 0.1,
|
||||
padding_duration: int = None,
|
||||
use_spk2info_cache: bool = False,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
task_id = int(name[5:])
|
||||
sync_triton_client = None
|
||||
user_data_map = {}
|
||||
|
||||
try:
|
||||
print(f"{name}: Initializing sync client for streaming...")
|
||||
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
|
||||
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
|
||||
|
||||
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
|
||||
|
||||
try:
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||
|
||||
inputs, outputs = prepare_request_input_output(
|
||||
protocol_client,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration,
|
||||
use_spk2info_cache=use_spk2info_cache
|
||||
)
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
user_data = UserData()
|
||||
user_data_map[request_id] = user_data
|
||||
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
|
||||
run_sync_streaming_inference,
|
||||
sync_triton_client,
|
||||
model_name,
|
||||
inputs,
|
||||
outputs,
|
||||
request_id,
|
||||
user_data,
|
||||
chunk_overlap_duration,
|
||||
save_sample_rate,
|
||||
audio_save_path
|
||||
)
|
||||
|
||||
if total_request_latency is not None:
|
||||
print(
|
||||
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
|
||||
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
|
||||
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
|
||||
)
|
||||
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
|
||||
total_duration += actual_duration
|
||||
else:
|
||||
print(f"{name}: Item {i} failed.")
|
||||
|
||||
del user_data_map[request_id]
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
if sync_triton_client:
|
||||
try:
|
||||
print(f"{name}: Closing stream and sync client...")
|
||||
sync_triton_client.stop_stream()
|
||||
sync_triton_client.close()
|
||||
except Exception as e:
|
||||
print(f"{name}: Error closing sync client: {e}")
|
||||
|
||||
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
|
||||
return total_duration, latency_data
|
||||
|
||||
|
||||
async def send(
|
||||
manifest_item_list: list,
|
||||
name: str,
|
||||
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
||||
protocol_client: types.ModuleType,
|
||||
log_interval: int,
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
use_spk2info_cache: bool = False,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
task_id = int(name[5:])
|
||||
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||
|
||||
inputs, outputs = prepare_request_input_output(
|
||||
protocol_client,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration,
|
||||
use_spk2info_cache=use_spk2info_cache
|
||||
)
|
||||
sequence_id = 100000000 + i + task_id * 10
|
||||
start = time.time()
|
||||
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
||||
|
||||
audio = response.as_numpy("waveform").reshape(-1)
|
||||
actual_duration = len(audio) / save_sample_rate
|
||||
|
||||
end = time.time() - start
|
||||
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, actual_duration))
|
||||
total_duration += actual_duration
|
||||
|
||||
return total_duration, latency_data
|
||||
|
||||
|
||||
def load_manifests(manifest_path):
|
||||
with open(manifest_path, "r") as f:
|
||||
manifest_list = []
|
||||
for line in f:
|
||||
assert len(line.strip().split("|")) == 4
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
utt = Path(utt).stem
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
||||
manifest_list.append(
|
||||
{
|
||||
"audio_filepath": prompt_wav,
|
||||
"reference_text": prompt_text,
|
||||
"target_text": gt_text,
|
||||
"target_audio_path": utt,
|
||||
}
|
||||
)
|
||||
return manifest_list
|
||||
|
||||
|
||||
def split_data(data, k):
|
||||
n = len(data)
|
||||
if n < k:
|
||||
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
||||
k = n
|
||||
|
||||
quotient = n // k
|
||||
remainder = n % k
|
||||
|
||||
result = []
|
||||
start = 0
|
||||
for i in range(k):
|
||||
if i < remainder:
|
||||
end = start + quotient + 1
|
||||
else:
|
||||
end = start + quotient
|
||||
|
||||
result.append(data[start:end])
|
||||
start = end
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def main():
|
||||
args = get_args()
|
||||
url = f"{args.server_addr}:{args.server_port}"
|
||||
|
||||
triton_client = None
|
||||
protocol_client = None
|
||||
if args.mode == "offline":
|
||||
print("Initializing gRPC client for offline mode...")
|
||||
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||
protocol_client = grpcclient_aio
|
||||
elif args.mode == "streaming":
|
||||
print("Initializing gRPC client for streaming mode...")
|
||||
protocol_client = grpcclient_sync
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {args.mode}")
|
||||
|
||||
if args.reference_audio:
|
||||
args.num_tasks = 1
|
||||
args.log_interval = 1
|
||||
manifest_item_list = [
|
||||
{
|
||||
"reference_text": args.reference_text,
|
||||
"target_text": args.target_text,
|
||||
"audio_filepath": args.reference_audio,
|
||||
"target_audio_path": "test",
|
||||
}
|
||||
]
|
||||
elif args.huggingface_dataset:
|
||||
import datasets
|
||||
|
||||
dataset = datasets.load_dataset(
|
||||
args.huggingface_dataset,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
manifest_item_list = []
|
||||
for i in range(len(dataset)):
|
||||
manifest_item_list.append(
|
||||
{
|
||||
"audio_filepath": dataset[i]["prompt_audio"],
|
||||
"reference_text": dataset[i]["prompt_text"],
|
||||
"target_audio_path": dataset[i]["id"],
|
||||
"target_text": dataset[i]["target_text"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
manifest_item_list = load_manifests(args.manifest_path)
|
||||
|
||||
stats_client = None
|
||||
stats_before = None
|
||||
try:
|
||||
print("Initializing temporary async client for fetching stats...")
|
||||
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||
print("Fetching inference statistics before running tasks...")
|
||||
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||
except Exception as e:
|
||||
print(f"Could not retrieve statistics before running tasks: {e}")
|
||||
|
||||
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
||||
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
||||
|
||||
os.makedirs(args.log_dir, exist_ok=True)
|
||||
args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
|
||||
tasks = []
|
||||
start_time = time.time()
|
||||
for i in range(num_tasks):
|
||||
if args.mode == "offline":
|
||||
task = asyncio.create_task(
|
||||
send(
|
||||
manifest_item_list[i],
|
||||
name=f"task-{i}",
|
||||
triton_client=triton_client,
|
||||
protocol_client=protocol_client,
|
||||
log_interval=args.log_interval,
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
use_spk2info_cache=args.use_spk2info_cache,
|
||||
)
|
||||
)
|
||||
elif args.mode == "streaming":
|
||||
task = asyncio.create_task(
|
||||
send_streaming(
|
||||
manifest_item_list[i],
|
||||
name=f"task-{i}",
|
||||
server_url=url,
|
||||
protocol_client=protocol_client,
|
||||
log_interval=args.log_interval,
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=10,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
chunk_overlap_duration=args.chunk_overlap_duration,
|
||||
use_spk2info_cache=args.use_spk2info_cache,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
ans_list = await asyncio.gather(*tasks)
|
||||
|
||||
end_time = time.time()
|
||||
elapsed = end_time - start_time
|
||||
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
for ans in ans_list:
|
||||
if ans:
|
||||
total_duration += ans[0]
|
||||
latency_data.extend(ans[1])
|
||||
else:
|
||||
print("Warning: A task returned None, possibly due to an error.")
|
||||
|
||||
if total_duration == 0:
|
||||
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
|
||||
rtf = float('inf')
|
||||
else:
|
||||
rtf = elapsed / total_duration
|
||||
|
||||
s = f"Mode: {args.mode}\n"
|
||||
s += f"RTF: {rtf:.4f}\n"
|
||||
s += f"total_duration: {total_duration:.3f} seconds\n"
|
||||
s += f"({total_duration / 3600:.2f} hours)\n"
|
||||
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
||||
|
||||
if latency_data:
|
||||
if args.mode == "offline":
|
||||
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
||||
if latency_list:
|
||||
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
||||
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"latency_variance: {latency_variance:.2f}\n"
|
||||
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No latency data collected for offline mode.\n"
|
||||
|
||||
elif args.mode == "streaming":
|
||||
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
|
||||
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
|
||||
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
|
||||
|
||||
s += "\n--- Total Request Latency ---\n"
|
||||
if total_latency_list:
|
||||
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
|
||||
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
|
||||
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No total request latency data collected.\n"
|
||||
|
||||
s += "\n--- First Chunk Latency ---\n"
|
||||
if first_chunk_latency_list:
|
||||
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
|
||||
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
|
||||
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
||||
|
||||
s += "\n--- Second Chunk Latency ---\n"
|
||||
if second_chunk_latency_list:
|
||||
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
|
||||
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
|
||||
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
|
||||
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
||||
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
||||
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
||||
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
||||
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
|
||||
else:
|
||||
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
|
||||
else:
|
||||
s += "No latency data collected.\n"
|
||||
|
||||
print(s)
|
||||
if args.manifest_path:
|
||||
name = Path(args.manifest_path).stem
|
||||
elif args.split_name:
|
||||
name = args.split_name
|
||||
elif args.reference_audio:
|
||||
name = Path(args.reference_audio).stem
|
||||
else:
|
||||
name = "results"
|
||||
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
||||
f.write(s)
|
||||
|
||||
try:
|
||||
if stats_client and stats_before:
|
||||
print("Fetching inference statistics after running tasks...")
|
||||
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||
|
||||
print("Calculating statistics difference...")
|
||||
stats = subtract_stats(stats_after, stats_before)
|
||||
|
||||
print("Fetching model config...")
|
||||
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
||||
|
||||
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
||||
|
||||
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
||||
json.dump(metadata, f, indent=4)
|
||||
else:
|
||||
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Could not retrieve statistics or config: {e}")
|
||||
finally:
|
||||
if stats_client:
|
||||
try:
|
||||
print("Closing temporary async stats client...")
|
||||
await stats_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing async stats client: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def run_main():
|
||||
try:
|
||||
await main()
|
||||
except Exception as e:
|
||||
print(f"An error occurred in main: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
asyncio.run(run_main())
172
runtime/triton_trtllm/client_http.py
Normal file
@@ -0,0 +1,172 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import numpy as np
import argparse


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-url",
        type=str,
        default="localhost:8000",
        help="Address of the server",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default="../../example/prompt_audio.wav",
        help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
        help="",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
        help="",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="spark_tts",
        choices=[
            "f5_tts",
            "spark_tts",
            "cosyvoice2"],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--output-audio",
        type=str,
        default="output.wav",
        help="Path to save the output audio",
    )
    return parser.parse_args()


def prepare_request(
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,
    audio_save_dir: str = "./",
):
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)
    if padding_duration:
        # padding to nearest 10 seconds
        samples = np.zeros(
            (
                1,
                padding_duration
                * sample_rate
                * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
            ),
            dtype=np.float32,
        )

        samples[0, : len(waveform)] = waveform
    else:
        samples = waveform

    samples = samples.reshape(1, -1).astype(np.float32)

    data = {
        "inputs": [
            {
                "name": "reference_wav",
                "shape": samples.shape,
                "datatype": "FP32",
                "data": samples.tolist()
            },
            {
                "name": "reference_wav_len",
                "shape": lengths.shape,
                "datatype": "INT32",
                "data": lengths.tolist(),
            },
            {
                "name": "reference_text",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [reference_text]
            },
            {
                "name": "target_text",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [target_text]
            }
        ]
    }

    return data


if __name__ == "__main__":
    args = get_args()
    server_url = args.server_url
    if not server_url.startswith(("http://", "https://")):
        server_url = f"http://{server_url}"

    url = f"{server_url}/v2/models/{args.model_name}/infer"
    waveform, sr = sf.read(args.reference_audio)
    assert sr == 16000, "sample rate hardcoded in server"

    samples = np.array(waveform, dtype=np.float32)
    data = prepare_request(samples, args.reference_text, args.target_text)

    rsp = requests.post(
        url,
        headers={"Content-Type": "application/json"},
        json=data,
        verify=False,
        params={"request_id": '0'}
    )
    result = rsp.json()
    audio = result["outputs"][0]["data"]
    audio = np.array(audio, dtype=np.float32)
    if args.model_name == "spark_tts":
        sample_rate = 16000
    else:
        sample_rate = 24000
    sf.write(args.output_audio, audio, sample_rate, "PCM_16")
20
runtime/triton_trtllm/docker-compose.dit.yml
Normal file
@@ -0,0 +1,20 @@
services:
  tts:
    image: soar97/triton-cosyvoice:25.06
    shm_size: '1gb'
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
      - MODEL_ID=${MODEL_ID}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    command: >
      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"
20
runtime/triton_trtllm/docker-compose.yml
Normal file
@@ -0,0 +1,20 @@
services:
  tts:
    image: soar97/triton-cosyvoice:25.06
    shm_size: '1gb'
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
      - MODEL_ID=${MODEL_ID}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    command: >
      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"
97
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
Normal file
@@ -0,0 +1,97 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

import os
import numpy as np
import s3tokenizer
torch.set_num_threads(1)
ORIGINAL_VOCAB_SIZE = 151663


class TritonPythonModel:
    """Triton Python model for audio tokenization.

    This model takes reference audio input and extracts semantic tokens
    using s3tokenizer.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}

        self.device = torch.device("cuda")
        model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
        self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing tokenized outputs
        """
        mels = []

        # Process each request in batch
        for request in requests:
            # Extract input tensors
            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(
                request, "reference_wav_len").as_numpy().item()

            wav_array = torch.from_numpy(wav_array).to(self.device)
            # Prepare inputs
            wav = wav_array[:, :wav_len].squeeze(0)
            mels.append(s3tokenizer.log_mel_spectrogram(wav))

        mels, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
        codes = codes.clone() + ORIGINAL_VOCAB_SIZE

        responses = []
        for i in range(len(requests)):
            prompt_speech_tokens = codes[i, :codes_lens[i].item()]
            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
                "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[prompt_speech_tokens_tensor])
            responses.append(inference_response)

        return responses
@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "audio_tokenizer"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
  {
    key: "model_dir",
    value: {string_value:"${model_dir}"}
  }
]

input [
  {
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
  },
  {
    name: "reference_wav_len"
    data_type: TYPE_INT32
    dims: [1]
  }
]
output [
  {
    name: "prompt_speech_tokens"
    data_type: TYPE_INT32
    dims: [-1]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
454
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py
Normal file
@@ -0,0 +1,454 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import torchaudio
|
||||
|
||||
|
||||
from matcha.utils.audio import mel_spectrogram
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for Spark TTS.
|
||||
|
||||
This model orchestrates the end-to-end TTS pipeline by coordinating
|
||||
between audio tokenizer, LLM, and vocoder components.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
self.logger = pb_utils.Logger
|
||||
# Parse model parameters
|
||||
self.model_config = json.loads(args['model_config'])
|
||||
parameters = self.model_config['parameters']
|
||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||
self.logger.log_info(f"model_params:{model_params}")
|
||||
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
||||
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
||||
|
||||
# Initialize tokenizer
|
||||
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
||||
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
|
||||
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
||||
|
||||
self.token_frame_rate = 25
|
||||
self.flow_pre_lookahead_len = 3
|
||||
self.token_hop_len = 15
|
||||
|
||||
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
|
||||
if not os.path.exists(spk_info_path):
|
||||
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||
self.default_spk_info = spk_info["001"]
|
||||
|
||||
def forward_llm(self, input_ids):
|
||||
"""
|
||||
Prepares the response from the language model based on the provided
|
||||
inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
||||
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
||||
For each response from the language model:
|
||||
- Checks for errors and raise an exception if any are found.
|
||||
- Extracts the "output_ids" tensor from the response.
|
||||
- Determines the finish reason based on the presence of the
|
||||
end-of-sequence token or reaching the maximum length.
|
||||
- Appends the generated token IDs to `output_ids`.
|
||||
- If the finish reason is determined, decodes the output IDs to text
|
||||
and prepares the final response.
|
||||
|
||||
The final response includes the generated text, finish reason,
|
||||
completion tokens, prompt tokens, and total tokens.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
||||
"""
|
||||
# convert input_ids to numpy, with shape [1, sequence_length]
|
||||
input_ids = input_ids.cpu().numpy()
|
||||
max_tokens = 750
|
||||
input_dict = {
|
||||
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
||||
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
||||
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
|
||||
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
||||
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
||||
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
||||
"temperature": np.array([[0.8]], dtype=np.float32),
|
||||
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
|
||||
"random_seed": np.array([[42]], dtype=np.uint64),
|
||||
"input_ids": input_ids,
|
||||
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
||||
}
|
||||
|
||||
# Convert inputs to Triton tensors
|
||||
input_tensor_list = [
|
||||
pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
||||
]
|
||||
|
||||
# Create and execute inference request
|
||||
llm_request = pb_utils.InferenceRequest(
|
||||
model_name="tensorrt_llm",
|
||||
requested_output_names=["output_ids", "sequence_length"],
|
||||
inputs=input_tensor_list,
|
||||
)
|
||||
|
||||
llm_responses = llm_request.exec(decoupled=self.decoupled)
|
||||
if self.decoupled:
|
||||
for llm_response in llm_responses:
|
||||
if llm_response.has_error():
|
||||
raise pb_utils.TritonModelException(llm_response.error().message())
|
||||
|
||||
# Extract and process output
|
||||
output_ids = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "output_ids").as_numpy()
|
||||
seq_lens = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "sequence_length").as_numpy()
|
||||
|
||||
# Get actual output IDs up to the sequence length
|
||||
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
||||
|
||||
yield actual_output_ids
|
||||
else:
|
||||
llm_response = llm_responses
|
||||
if llm_response.has_error():
|
||||
raise pb_utils.TritonModelException(llm_response.error().message())
|
||||
|
||||
# Extract and process output
|
||||
output_ids = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "output_ids").as_numpy()
|
||||
seq_lens = pb_utils.get_output_tensor_by_name(
|
||||
llm_response, "sequence_length").as_numpy()
|
||||
|
||||
# Get actual output IDs up to the sequence length
|
||||
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
||||
|
||||
yield actual_output_ids
|
||||
|
||||
def forward_audio_tokenizer(self, wav, wav_len):
|
||||
"""Forward pass through the audio tokenizer component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
wav_len: Waveform length tensor
|
||||
|
||||
Returns:
|
||||
Tuple of global and semantic tokens
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='audio_tokenizer',
|
||||
requested_output_names=['prompt_speech_tokens'],
|
||||
inputs=[wav, wav_len]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
|
||||
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
|
||||
|
||||
return prompt_speech_tokens
|
||||
|
||||
def forward_speaker_embedding(self, wav):
|
||||
"""Forward pass through the speaker embedding component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
|
||||
Returns:
|
||||
Prompt speaker embedding tensor
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='speaker_embedding',
|
||||
requested_output_names=['prompt_spk_embedding'],
|
||||
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
|
||||
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
|
||||
|
||||
return prompt_spk_embedding
|
||||
|
||||
def forward_token2wav(
|
||||
self,
|
||||
target_speech_tokens: torch.Tensor,
|
||||
request_id: str,
|
||||
prompt_speech_tokens: torch.Tensor = None,
|
||||
prompt_speech_feat: torch.Tensor = None,
|
||||
prompt_spk_embedding: torch.Tensor = None,
|
||||
token_offset: int = None,
|
||||
finalize: bool = None) -> torch.Tensor:
|
||||
"""Forward pass through the vocoder component.
|
||||
|
||||
Args:
    target_speech_tokens: Target speech tokens tensor
    request_id: Request ID passed through to the token2wav model
    prompt_speech_tokens: Prompt speech tokens tensor
    prompt_speech_feat: Prompt speech feat tensor
    prompt_spk_embedding: Prompt spk embedding tensor
    token_offset: Number of tokens already vocoded (streaming mode)
    finalize: Whether this is the final chunk of the request
|
||||
|
||||
Returns:
|
||||
Generated waveform tensor
|
||||
"""
|
||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||
|
||||
inputs_tensor = [target_speech_tokens_tensor]
|
||||
|
||||
if token_offset is not None:
|
||||
assert finalize is not None
|
||||
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
|
||||
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
||||
inputs_tensor.append(token_offset_tensor)
|
||||
inputs_tensor.append(finalize_tensor)
|
||||
|
||||
if prompt_spk_embedding is not None:
|
||||
assert prompt_speech_feat is not None
|
||||
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
||||
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
||||
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
||||
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
|
||||
|
||||
# Create and execute inference request
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='token2wav',
|
||||
requested_output_names=['waveform'],
|
||||
inputs=inputs_tensor,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output waveform
|
||||
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
||||
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
||||
|
||||
return waveform
|
||||
|
||||
def parse_input(self, text, prompt_text, prompt_speech_tokens):
|
||||
total_text = f"{prompt_text}{text}"
|
||||
prompt = self.prompt_template.format(input_text=total_text)
|
||||
input_ids = self.tokenizer.encode(prompt)
|
||||
input_ids = torch.tensor([input_ids], dtype=torch.int32)
|
||||
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
|
||||
return input_ids
|
||||
|
||||
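# --- Illustrative sketch (not part of the original file) ---
# parse_input() above formats the prompt with self.prompt_template (set in
# initialize()), tokenizes it, and appends the prompt speech-token IDs, so the
# LLM is asked to continue the sequence with new speech tokens. The helper
# below shows just the concatenation step with hypothetical IDs; real IDs come
# from AutoTokenizer and the audio_tokenizer model. torch is imported above.
def _build_llm_input_sketch(text_ids, prompt_speech_tokens):
    """Concatenate text-token IDs with prompt speech-token IDs, as parse_input() does."""
    input_ids = torch.tensor([text_ids], dtype=torch.int32)
    return torch.cat([input_ids, prompt_speech_tokens.to(torch.int32)], dim=1)

# Example (hypothetical values): 5 text tokens + 3 prompt speech tokens -> shape (1, 8)
#   _build_llm_input_sketch([101, 102, 103, 104, 105],
#                           torch.tensor([[151700, 151701, 151702]])).shape == (1, 8)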
def _extract_speech_feat(self, speech):
|
||||
speech_feat = mel_spectrogram(
|
||||
speech,
|
||||
n_fft=1920,
|
||||
num_mels=80,
|
||||
sampling_rate=24000,
|
||||
hop_size=480,
|
||||
win_size=1920,
|
||||
fmin=0,
|
||||
fmax=8000).squeeze(
|
||||
dim=0).transpose(
|
||||
0,
|
||||
1).to(
|
||||
self.device)
|
||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||
return speech_feat
|
||||
|
||||
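# --- Illustrative sketch (not part of the original file) ---
# _extract_speech_feat() above computes 80-bin mels at 24 kHz with hop_size=480,
# i.e. 50 mel frames per second, while the speech tokens run at the 25 Hz rate
# used by this runtime (self.token_frame_rate), i.e. 2 mel frames per token.
# That is why execute() below trims the prompt to
# token_len = min(feat_frames // 2, n_tokens) and keeps 2 * token_len frames.
# The arithmetic in isolation:
def _align_prompt_lengths_sketch(num_feat_frames, num_speech_tokens, frames_per_token=2):
    """Return (token_len, feat_len) so the prompt features cover exactly token_len tokens."""
    token_len = min(num_feat_frames // frames_per_token, num_speech_tokens)
    return token_len, frames_per_token * token_len

# Example: a 3 s prompt gives 150 mel frames and about 75 speech tokens,
# so _align_prompt_lengths_sketch(150, 75) == (75, 150).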
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
|
||||
for generated_ids in generated_ids_iter:
|
||||
generated_ids = generated_ids.tolist()
|
||||
if len(generated_ids) == 0:
|
||||
break
|
||||
semantic_token_ids_arr.extend(generated_ids)
|
||||
llm_is_done_flag[0] = True
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated audio
|
||||
"""
|
||||
responses = []
|
||||
|
||||
for request in requests:
|
||||
request_id = request.request_id()
|
||||
# Extract input tensors
|
||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||
|
||||
# Process reference audio through audio tokenizer
|
||||
if wav is not None:
|
||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||
|
||||
wav_tensor = wav.as_numpy()
|
||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||
|
||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||
reference_text = reference_text[0][0].decode('utf-8')
|
||||
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
||||
else:
|
||||
# using pre-cached reference text
|
||||
reference_text = self.default_spk_info["prompt_text"]
|
||||
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||
prompt_speech_feat = None
|
||||
prompt_spk_embedding = None
|
||||
|
||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||
target_text = target_text[0][0].decode('utf-8')
|
||||
|
||||
# Prepare prompt for LLM
|
||||
input_ids = self.parse_input(
|
||||
text=target_text,
|
||||
prompt_text=reference_text,
|
||||
prompt_speech_tokens=prompt_speech_tokens,
|
||||
)
|
||||
|
||||
# Generate semantic tokens with LLM
|
||||
generated_ids_iter = self.forward_llm(input_ids)
|
||||
|
||||
token2wav_request_id = request_id or str(uuid4())
|
||||
if self.decoupled:
|
||||
response_sender = request.get_response_sender()
|
||||
|
||||
semantic_token_ids_arr = []
|
||||
llm_is_done_flag = [False]
|
||||
|
||||
llm_thread = threading.Thread(
|
||||
target=self._llm_gen_thread,
|
||||
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
|
||||
)
|
||||
|
||||
llm_thread.start()
|
||||
|
||||
token_offset, chunk_index = 0, 0
|
||||
start_time = time.time()
|
||||
this_token_hop_len = self.token_hop_len
|
||||
|
||||
while True:
|
||||
pending_num = len(semantic_token_ids_arr) - token_offset
|
||||
|
||||
if llm_is_done_flag[0]:
|
||||
break
|
||||
|
||||
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
||||
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
||||
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
|
||||
sub_tts_speech = self.forward_token2wav(
|
||||
this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
|
||||
prompt_speech_feat, prompt_spk_embedding, token_offset, False
|
||||
)
|
||||
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
|
||||
token_offset += this_token_hop_len
|
||||
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
|
||||
|
||||
if self.dynamic_chunk_strategy == "exponential":
|
||||
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
||||
elif self.dynamic_chunk_strategy == "time_based":
|
||||
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
||||
cost_time = time.time() - start_time
|
||||
duration = token_offset / self.token_frame_rate
|
||||
if chunk_index > 0 and cost_time > 0:
|
||||
avg_chunk_processing_time = cost_time / (chunk_index + 1)
|
||||
if avg_chunk_processing_time > 0:
|
||||
multiples = (duration - cost_time) / avg_chunk_processing_time
|
||||
self.logger.log_info(f"multiples: {multiples}")
|
||||
next_pending_num = len(semantic_token_ids_arr) - token_offset
|
||||
if multiples > 4:
|
||||
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
|
||||
elif multiples > 2:
|
||||
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
|
||||
else:
|
||||
this_token_hop_len = self.token_hop_len
|
||||
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
||||
chunk_index += 1
|
||||
else:
|
||||
time.sleep(0.02)
|
||||
|
||||
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
|
||||
llm_thread.join()
|
||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||
self.logger.log_info("send tritonserver_response_complete_final to end")
|
||||
else:
|
||||
generated_ids = next(generated_ids_iter)
if generated_ids is None or len(generated_ids) == 0:
    raise pb_utils.TritonModelException("Generated IDs is None or empty")
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
|
||||
|
||||
audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
|
||||
|
||||
# Prepare response
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
responses.append(inference_response)
|
||||
|
||||
if not self.decoupled:
|
||||
return responses
|
||||
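The decoupled branch of execute() above drives streaming synthesis: it waits until token_hop_len + flow_pre_lookahead_len new LLM tokens are pending, vocodes a chunk, advances token_offset, and (with the "exponential" strategy) grows the hop for the next chunk, flushing whatever is left with finalize=True once the LLM thread finishes. The standalone sketch below reproduces only that scheduling arithmetic, assuming the 15-token hop, 25 tokens/s frame rate and 3-token lookahead configured in the companion cosyvoice2_dit initialize() (assumed to match here); it is illustrative and not part of the model repository.

def exponential_hop_schedule(total_tokens, token_hop_len=15, token_frame_rate=25, lookahead=3):
    """Yield (token_offset, hop) pairs the way the 'exponential' streaming loop advances."""
    offset, chunk_index, hop = 0, 0, token_hop_len
    while total_tokens - offset >= hop + lookahead:
        yield offset, hop                              # a chunk is vocoded once hop + lookahead tokens are pending
        offset += hop
        hop = token_frame_rate * (2 ** chunk_index)    # next hop grows: 25, 50, 100, ...
        chunk_index += 1
    yield offset, total_tokens - offset                # the tail is flushed with finalize=True

print(list(exponential_hop_schedule(200)))
# [(0, 15), (15, 25), (40, 50), (90, 100), (190, 10)]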
runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "cosyvoice2"
|
||||
backend: "python"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
}
|
||||
model_transaction_policy {
|
||||
decoupled: ${decoupled_mode}
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
key: "llm_tokenizer_dir",
|
||||
value: {string_value:"${llm_tokenizer_dir}"}
|
||||
},
|
||||
{
|
||||
key: "model_dir",
|
||||
value: {string_value:"${model_dir}"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "reference_wav"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "reference_wav_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "reference_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "target_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "waveform"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: ${bls_instance_num}
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
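Every config.pbtxt in this model repository, including the one above, carries ${...} placeholders (max_batch_size, decoupled_mode, tokenizer paths, and so on) that must be substituted before tritonserver loads the repo. How the repository fills them is not shown in this diff; the sketch below only illustrates that the placeholders follow Python's string.Template syntax, so a few lines are enough:

from string import Template

def fill_config(template_text, **values):
    """Substitute ${...} placeholders in a config.pbtxt template; unknown keys are left as-is."""
    return Template(template_text).safe_substitute(**values)

snippet = (
    "max_batch_size: ${triton_max_batch_size}\n"
    "model_transaction_policy {\n"
    "  decoupled: ${decoupled_mode}\n"
    "}\n"
)
print(fill_config(snippet, triton_max_batch_size=16, decoupled_mode="True"))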
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py (new file, 394 lines)
@@ -0,0 +1,394 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Optional, Union
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import torchaudio
|
||||
|
||||
|
||||
from matcha.utils.audio import mel_spectrogram
|
||||
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
def parse_speech_token_string(response_text: str) -> List[int]:
|
||||
"""
|
||||
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
|
||||
"""
|
||||
speech_tokens = response_text.strip().split('><')
|
||||
if len(speech_tokens) > 1:
|
||||
# Add back the missing '<' and '>' for proper parsing
|
||||
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
|
||||
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
|
||||
|
||||
speech_ids = []
|
||||
for token_str in speech_tokens:
|
||||
match = re.match(r'<\|s_(\d+)\|>', token_str)
|
||||
if match:
|
||||
speech_ids.append(int(match.group(1)))
|
||||
return speech_ids
|
||||
|
||||
|
||||
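# --- Illustrative sketch (not part of the original file) ---
# parse_speech_token_string() above maps "<|s_N|>" strings to raw token numbers,
# while _convert_speech_tokens_to_str() in the class below shifts LLM token IDs
# down by ORIGINAL_VOCAB_SIZE to produce those strings, and forward_llm_async()
# adds the offset back when parsing the stream. The two helpers below show that
# round trip in isolation (hypothetical ID values); re is imported above.
def _ids_to_token_str_sketch(ids, offset=ORIGINAL_VOCAB_SIZE):
    """LLM token IDs -> "<|s_N|>" string, mirroring _convert_speech_tokens_to_str()."""
    return "".join(f"<|s_{i - offset}|>" for i in ids)

def _token_str_to_ids_sketch(text, offset=ORIGINAL_VOCAB_SIZE):
    """"<|s_N|>" string -> LLM token IDs, mirroring the parse plus the offset shift."""
    return [int(n) + offset for n in re.findall(r"<\|s_(\d+)\|>", text)]

# Example: _token_str_to_ids_sketch(_ids_to_token_str_sketch([151700, 151800]))
#          == [151700, 151800]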
class TritonPythonModel:
|
||||
"""Triton Python model for Spark TTS.
|
||||
|
||||
This model orchestrates the end-to-end TTS pipeline by coordinating
|
||||
between audio tokenizer, LLM, and vocoder components.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
self.logger = pb_utils.Logger
|
||||
# Parse model parameters
|
||||
self.model_config = json.loads(args['model_config'])
|
||||
parameters = self.model_config['parameters']
|
||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
|
||||
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
|
||||
|
||||
# Initialize tokenizer
|
||||
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
||||
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
|
||||
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
||||
|
||||
self.token_frame_rate = 25
|
||||
self.flow_pre_lookahead_len = 3
|
||||
self.token_hop_len = 15
|
||||
|
||||
self.http_client = httpx.AsyncClient()
|
||||
self.api_base = "http://localhost:8000/v1/chat/completions"
|
||||
self.speaker_cache = {}
|
||||
|
||||
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
|
||||
"""Converts a tensor or list of speech token IDs to a string representation."""
|
||||
if isinstance(speech_tokens, torch.Tensor):
|
||||
# Ensure tensor is on CPU and flattened
|
||||
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
|
||||
|
||||
speech_id_str = ""
|
||||
for token_id in speech_tokens:
|
||||
# Convert token ID back to the speech number N
|
||||
token_num = token_id - ORIGINAL_VOCAB_SIZE
|
||||
speech_id_str += f"<|s_{token_num}|>"
|
||||
return speech_id_str
|
||||
|
||||
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
|
||||
"""
|
||||
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
|
||||
"""
|
||||
full_text = f"{reference_text}{target_text}"
|
||||
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
|
||||
|
||||
chat = [
|
||||
{"role": "user", "content": full_text},
|
||||
{"role": "assistant", "content": prompt_speech_tokens_str}
|
||||
]
|
||||
|
||||
payload = {
|
||||
"model": "trt_engines_bfloat16",
|
||||
"messages": chat,
|
||||
"max_tokens": 750,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.95,
|
||||
"top_k": 50,
|
||||
"repetition_penalty": 1.1,
|
||||
"stop": ["<|eos1|>", "<|eos|>"],
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
buffer = ""
|
||||
async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
line_data = line[len("data: "):].strip()
|
||||
if line_data == "[DONE]":
|
||||
break
|
||||
try:
|
||||
json_data = json.loads(line_data)
|
||||
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||
if content:
|
||||
buffer += content
|
||||
while True:
|
||||
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||
if not match:
|
||||
break
|
||||
|
||||
token_num = int(match.group(1))
|
||||
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||
yield final_id
|
||||
buffer = buffer[match.end():]
|
||||
except json.JSONDecodeError:
|
||||
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
|
||||
continue
|
||||
|
||||
# Process any remaining complete tokens in the buffer after the stream ends
|
||||
while True:
|
||||
match = re.search(r"<\|s_(\d+)\|>", buffer)
|
||||
if not match:
|
||||
break
|
||||
token_num = int(match.group(1))
|
||||
final_id = token_num + ORIGINAL_VOCAB_SIZE
|
||||
yield final_id
|
||||
buffer = buffer[match.end():]
|
||||
|
||||
def forward_audio_tokenizer(self, wav, wav_len):
|
||||
"""Forward pass through the audio tokenizer component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
wav_len: Waveform length tensor
|
||||
|
||||
Returns:
|
||||
Prompt speech tokens tensor
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='audio_tokenizer',
|
||||
requested_output_names=['prompt_speech_tokens'],
|
||||
inputs=[wav, wav_len]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
|
||||
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
|
||||
|
||||
return prompt_speech_tokens
|
||||
|
||||
def forward_speaker_embedding(self, wav):
|
||||
"""Forward pass through the speaker embedding component.
|
||||
|
||||
Args:
|
||||
wav: Input waveform tensor
|
||||
|
||||
Returns:
|
||||
Prompt speaker embedding tensor
|
||||
"""
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='speaker_embedding',
|
||||
requested_output_names=['prompt_spk_embedding'],
|
||||
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
|
||||
)
|
||||
|
||||
inference_response = inference_request.exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output tensors
|
||||
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
|
||||
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
|
||||
|
||||
return prompt_spk_embedding
|
||||
|
||||
async def forward_token2wav(
|
||||
self,
|
||||
index: int,
|
||||
target_speech_tokens: torch.Tensor,
|
||||
request_id: str,
|
||||
reference_wav: object,
|
||||
reference_wav_len: object,
|
||||
finalize: bool = None) -> torch.Tensor:
|
||||
"""Forward pass through the vocoder component.
|
||||
|
||||
Args:
|
||||
index: Index of the request
|
||||
target_speech_tokens: Target speech tokens tensor
|
||||
request_id: Request ID
|
||||
reference_wav: Reference waveform tensor
|
||||
reference_wav_len: Reference waveform length tensor
|
||||
finalize: Whether to finalize the request
|
||||
|
||||
Returns:
|
||||
Generated waveform tensor
|
||||
"""
|
||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
|
||||
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
|
||||
|
||||
# Create and execute inference request
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='token2wav_dit',
|
||||
requested_output_names=[
|
||||
"waveform",
|
||||
],
|
||||
inputs=inputs_tensor,
|
||||
request_id=request_id,
|
||||
parameters={"priority": index + 1},
|
||||
)
|
||||
|
||||
inference_response = await inference_request.async_exec()
|
||||
if inference_response.has_error():
|
||||
raise pb_utils.TritonModelException(inference_response.error().message())
|
||||
|
||||
# Extract and convert output waveform
|
||||
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
||||
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
||||
|
||||
return waveform
|
||||
|
||||
def _extract_speech_feat(self, speech):
|
||||
speech_feat = mel_spectrogram(
|
||||
speech,
|
||||
n_fft=1920,
|
||||
num_mels=80,
|
||||
sampling_rate=24000,
|
||||
hop_size=480,
|
||||
win_size=1920,
|
||||
fmin=0,
|
||||
fmax=8000).squeeze(
|
||||
dim=0).transpose(
|
||||
0,
|
||||
1).to(
|
||||
self.device)
|
||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||
return speech_feat
|
||||
|
||||
async def _process_request(self, request):
|
||||
request_id = request.request_id()
|
||||
|
||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||
reference_text = reference_text[0][0].decode('utf-8')
|
||||
|
||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||
|
||||
if reference_text not in self.speaker_cache:
|
||||
self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
|
||||
prompt_speech_tokens = self.speaker_cache[reference_text]
|
||||
|
||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||
target_text = target_text[0][0].decode('utf-8')
|
||||
|
||||
if self.decoupled:
|
||||
response_sender = request.get_response_sender()
|
||||
|
||||
semantic_token_ids_arr = []
|
||||
token_offset, chunk_index = 0, 0
|
||||
start_time = time.time()
|
||||
this_token_hop_len = self.token_hop_len
|
||||
async for generated_ids in self.forward_llm_async(
|
||||
target_text=target_text,
|
||||
reference_text=reference_text,
|
||||
prompt_speech_tokens=prompt_speech_tokens,
|
||||
):
|
||||
if not generated_ids:
|
||||
break
|
||||
semantic_token_ids_arr.append(generated_ids)
|
||||
while True:
|
||||
pending_num = len(semantic_token_ids_arr) - token_offset
|
||||
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
|
||||
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
|
||||
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
sub_tts_speech = await self.forward_token2wav(
|
||||
chunk_index,
|
||||
this_tts_speech_token, request_id, wav, wav_len, False
|
||||
)
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
|
||||
token_offset += this_token_hop_len
|
||||
|
||||
if self.dynamic_chunk_strategy == "exponential":
|
||||
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
|
||||
elif self.dynamic_chunk_strategy == "equal":
|
||||
this_token_hop_len = self.token_hop_len
|
||||
elif self.dynamic_chunk_strategy == "time_based":
|
||||
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
|
||||
cost_time = time.time() - start_time
|
||||
duration = token_offset / self.token_frame_rate
|
||||
if chunk_index > 0 and cost_time > 0:
|
||||
avg_chunk_processing_time = cost_time / (chunk_index + 1)
|
||||
if avg_chunk_processing_time > 0:
|
||||
multiples = (duration - cost_time) / avg_chunk_processing_time
|
||||
next_pending_num = len(semantic_token_ids_arr) - token_offset
|
||||
if multiples > 4:
|
||||
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
|
||||
elif multiples > 2:
|
||||
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
|
||||
else:
|
||||
this_token_hop_len = self.token_hop_len
|
||||
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
|
||||
chunk_index += 1
|
||||
else:
|
||||
break
|
||||
|
||||
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
|
||||
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
||||
else:
|
||||
raise NotImplementedError("Offline TTS mode is not supported")
|
||||
|
||||
async def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated audio
|
||||
"""
|
||||
tasks = [
|
||||
asyncio.create_task(self._process_request(request))
|
||||
for request in requests
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
return None
|
||||
|
||||
def finalize(self):
|
||||
self.logger.log_info("Finalizing CosyVoice DIT model")
|
||||
if hasattr(self, "http_client"):
|
||||
asyncio.run(self.http_client.aclose())
|
||||
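forward_llm_async() above consumes an OpenAI-style chat-completions stream from the LLM server: it appends each delta.content fragment to a rolling buffer and extracts every complete "<|s_N|>" token, so a token split across two chunks is still recovered. The standalone sketch below replays that parse over canned "data:" lines (the payloads are made up for the example); it needs no HTTP client and can be run directly.

import json
import re

TOKEN_RE = re.compile(r"<\|s_(\d+)\|>")

def stream_speech_ids(sse_lines, vocab_offset=151663):
    """Yield speech-token IDs from 'data: {...}' lines the way forward_llm_async() drains its buffer."""
    buffer = ""
    for line in sse_lines:
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):].strip()
        if payload == "[DONE]":
            break
        delta = json.loads(payload).get("choices", [{}])[0].get("delta", {}).get("content") or ""
        buffer += delta
        while (m := TOKEN_RE.search(buffer)):
            yield int(m.group(1)) + vocab_offset
            buffer = buffer[m.end():]

# Canned example: a token split across two deltas still comes out whole.
lines = [
    'data: {"choices": [{"delta": {"content": "<|s_12|><|s_"}}]}',
    'data: {"choices": [{"delta": {"content": "34|>"}}]}',
    "data: [DONE]",
]
print(list(stream_speech_ids(lines)))  # [151675, 151697]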
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "cosyvoice2_dit"
|
||||
backend: "python"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
}
|
||||
model_transaction_policy {
|
||||
decoupled: ${decoupled_mode}
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
key: "llm_tokenizer_dir",
|
||||
value: {string_value:"${llm_tokenizer_dir}"}
|
||||
},
|
||||
{
|
||||
key: "model_dir",
|
||||
value: {string_value:"${model_dir}"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "reference_wav"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "reference_wav_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "reference_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "target_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "waveform"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: ${bls_instance_num}
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
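Per the config above, a cosyvoice2_dit request carries reference_wav (TYPE_FP32, variable length), reference_wav_len (TYPE_INT32, dims [1]) and the two TYPE_STRING texts, each gaining a leading batch dimension because max_batch_size is set. The sketch below only shapes those payloads as numpy arrays, the way a client would before wrapping them in its Triton client library's input objects; the [1, ...] batch layout is an assumption based on the batching config.

import numpy as np

def make_request_arrays(waveform_16k, reference_text, target_text):
    """Lay out one request's inputs to match the cosyvoice2_dit config (batch dim first)."""
    wav = np.asarray(waveform_16k, dtype=np.float32).reshape(1, -1)        # reference_wav
    wav_len = np.array([[wav.shape[1]]], dtype=np.int32)                   # reference_wav_len
    ref_text = np.array([[reference_text.encode("utf-8")]], dtype=object)  # reference_text (TYPE_STRING)
    tgt_text = np.array([[target_text.encode("utf-8")]], dtype=object)     # target_text (TYPE_STRING)
    return wav, wav_len, ref_text, tgt_text

wav, wav_len, ref_text, tgt_text = make_request_arrays(
    np.zeros(16000), "a reference transcript", "text to synthesize")
print(wav.shape, wav_len[0, 0], ref_text.dtype, tgt_text.shape)  # (1, 16000) 16000 object (1, 1)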
runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py (new file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.dlpack import to_dlpack
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
import onnxruntime
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for audio tokenization.
|
||||
|
||||
This model takes reference audio input and extracts semantic tokens
|
||||
using s3tokenizer.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
# Parse model parameters
|
||||
parameters = json.loads(args['model_config'])['parameters']
|
||||
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
|
||||
model_dir = model_params["model_dir"]
|
||||
gpu = "l20"
|
||||
enable_trt = True
|
||||
if enable_trt:
|
||||
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||
f'{model_dir}/campplus.onnx',
|
||||
1,
|
||||
False)
|
||||
else:
|
||||
campplus_model = f'{model_dir}/campplus.onnx'
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
||||
|
||||
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||
trt_kwargs = self.get_spk_trt_kwargs()
|
||||
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
|
||||
import tensorrt as trt
|
||||
with open(spk_model, 'rb') as f:
|
||||
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_spk_trt_kwargs(self):
|
||||
min_shape = [(1, 4, 80)]
|
||||
opt_shape = [(1, 500, 80)]
|
||||
max_shape = [(1, 3000, 80)]
|
||||
input_names = ["input"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def _extract_spk_embedding(self, speech):
|
||||
feat = kaldi.fbank(speech,
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
spk_feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
|
||||
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||
embedding = self.spk_model.run(
|
||||
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||
)[0].flatten().tolist()
|
||||
embedding = torch.tensor([embedding]).to(self.device)
|
||||
else:
|
||||
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||
# NOTE need to synchronize when switching stream
|
||||
with torch.cuda.device(self.device):
|
||||
torch.cuda.current_stream().synchronize()
|
||||
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||
batch_size = spk_feat.size(0)
|
||||
|
||||
with stream:
|
||||
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||
embedding = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||
|
||||
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||
embedding.contiguous().data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
|
||||
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.spk_model.release_estimator(spk_model, stream)
|
||||
|
||||
return embedding.half()
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing tokenized outputs
|
||||
"""
|
||||
responses = []
|
||||
# Process each request in batch
|
||||
for request in requests:
|
||||
# Extract input tensors
|
||||
wav_array = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav").as_numpy()
|
||||
wav_array = torch.from_numpy(wav_array).to(self.device)
|
||||
|
||||
embedding = self._extract_spk_embedding(wav_array)
|
||||
|
||||
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
|
||||
"prompt_spk_embedding", to_dlpack(embedding))
|
||||
inference_response = pb_utils.InferenceResponse(
|
||||
output_tensors=[prompt_spk_embedding_tensor])
|
||||
|
||||
responses.append(inference_response)
|
||||
|
||||
return responses
|
||||
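_extract_spk_embedding() above converts the reference waveform into 80-dimensional Kaldi fbank features with per-utterance mean subtraction before running the campplus speaker model (via ONNX Runtime or a TensorRT engine). The sketch below isolates just that feature step, assuming a 16 kHz mono waveform shaped (1, num_samples) as torchaudio's kaldi.fbank expects; it is illustrative and not part of the model repository.

import torch
import torchaudio.compliance.kaldi as kaldi

def reference_fbank(waveform_16k: torch.Tensor) -> torch.Tensor:
    """80-bin fbank with mean removal, as fed to the campplus speaker model above."""
    feat = kaldi.fbank(waveform_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
    return feat - feat.mean(dim=0, keepdim=True)   # per-utterance mean subtraction

# One second of audio yields roughly 100 frames of 80-dim features.
feats = reference_fbank(torch.randn(1, 16000))
print(feats.shape)   # torch.Size([98, 80])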
@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "speaker_embedding"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
  max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
  {
    key: "model_dir",
    value: {string_value:"${model_dir}"}
  }
]

input [
  {
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
  }
]
output [
  {
    name: "prompt_spk_embedding"
    data_type: TYPE_FP16
    dims: [-1]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
|
||||
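The tensorrt_llm config that follows declares a large set of optional inputs, but the BLS models above only need a small subset: forward_llm() is shown sending input_ids and input_lengths and reading back output_ids plus sequence_length, while the remaining fields it sets (request length, end id, sampling knobs, streaming flag) sit above the excerpted portion of that method, so the exact list here is an assumption. A numpy sketch of such a minimal tensor set, with field names taken from the config below and purely illustrative values:

import numpy as np

def minimal_trtllm_inputs(prompt_ids, end_id, max_new_tokens=750, streaming=True):
    """Assemble a minimal set of tensorrt_llm request tensors (illustrative subset only)."""
    input_ids = np.array([prompt_ids], dtype=np.int32)
    return {
        "input_ids": input_ids,
        "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        "request_output_len": np.array([[max_new_tokens]], dtype=np.int32),
        "end_id": np.array([[end_id]], dtype=np.int32),   # ID of the "<|eos1|>" token in the tokenizer
        "streaming": np.array([[streaming]], dtype=np.bool_),
    }

tensors = minimal_trtllm_inputs([1, 2, 3], end_id=0)      # end_id=0 is a placeholder value
print({k: (v.dtype.name, v.shape) for k, v in tensors.items()})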
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt (new file, 857 lines)
@@ -0,0 +1,857 @@
|
||||
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
name: "tensorrt_llm"
|
||||
backend: "${triton_backend}"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
|
||||
model_transaction_policy {
|
||||
decoupled: ${decoupled_mode}
|
||||
}
|
||||
|
||||
dynamic_batching {
|
||||
preferred_batch_size: [ ${triton_max_batch_size} ]
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
default_queue_policy: { max_queue_size: ${max_queue_size} }
|
||||
}
|
||||
|
||||
input [
|
||||
{
|
||||
name: "input_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
allow_ragged_batch: true
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "encoder_input_features"
|
||||
data_type: ${encoder_input_features_data_type}
|
||||
dims: [ -1, -1 ]
|
||||
allow_ragged_batch: true
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "encoder_output_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "input_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "request_output_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "num_return_sequences"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "draft_input_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "decoder_input_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "decoder_input_lengths"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
reshape: { shape: [ ] }
|
||||
},
|
||||
{
|
||||
name: "draft_logits"
|
||||
data_type: ${logits_datatype}
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "draft_acceptance_threshold"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "end_id"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "pad_id"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "stop_words_list"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 2, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "bad_words_list"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 2, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "embedding_bias"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "beam_width"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "temperature"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_k"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p_min"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p_decay"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "runtime_top_p_reset_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "len_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "early_stopping"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "repetition_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "min_length"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "beam_search_diversity_rate"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "presence_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "frequency_penalty"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "random_seed"
|
||||
data_type: TYPE_UINT64
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_log_probs"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_context_logits"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_generation_logits"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "return_perf_metrics"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "exclude_input_in_output"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "stop"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "streaming"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "prompt_embedding_table"
|
||||
data_type: TYPE_FP16
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "prompt_table_extra_ids"
|
||||
data_type: TYPE_UINT64
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "prompt_vocab_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
|
||||
{
|
||||
name: "cross_attention_mask"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
# Mrope param when mrope is used
|
||||
{
|
||||
name: "mrope_rotary_cos_sin"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "mrope_position_deltas"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
},
|
||||
# the unique task ID for the given LoRA.
|
||||
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
|
||||
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
|
||||
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
|
||||
{
|
||||
name: "lora_task_id"
|
||||
data_type: TYPE_UINT64
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
|
||||
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
|
||||
# each of the in / out tensors are first flattened and then concatenated together in the format above.
|
||||
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
|
||||
{
|
||||
name: "lora_weights"
|
||||
data_type: TYPE_FP16
|
||||
dims: [ -1, -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
# module identifier (same size as the first dimension of lora_weights)
|
||||
# See LoraModule::ModuleType for model id mapping
|
||||
#
|
||||
# "attn_qkv": 0 # compbined qkv adapter
|
||||
# "attn_q": 1 # q adapter
|
||||
# "attn_k": 2 # k adapter
|
||||
# "attn_v": 3 # v adapter
|
||||
# "attn_dense": 4 # adapter for the dense layer in attention
|
||||
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
|
||||
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
|
||||
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
|
||||
#
|
||||
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
|
||||
{
|
||||
name: "lora_config"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1, 3 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "context_phase_params"
|
||||
data_type: TYPE_UINT8
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
|
||||
{
|
||||
name: "skip_cross_attn_blocks"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_starts"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_ends"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_priorities"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_token_range_durations_ms"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_decode_priority"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "retention_decode_duration_ms"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "guided_decoding_guide_type"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "guided_decoding_guide"
|
||||
data_type: TYPE_STRING
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "lookahead_window_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "lookahead_ngram_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
},
|
||||
{
|
||||
name: "lookahead_verification_set_size"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
optional: true
|
||||
allow_ragged_batch: true
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "output_ids"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "sequence_length"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ -1 ]
|
||||
},
|
||||
{
|
||||
name: "cum_log_probs"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
},
|
||||
{
|
||||
name: "output_log_probs"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "context_logits"
|
||||
data_type: ${logits_datatype}
|
||||
dims: [ -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "generation_logits"
|
||||
data_type: ${logits_datatype}
|
||||
dims: [ -1, -1, -1 ]
|
||||
},
|
||||
{
|
||||
name: "batch_index"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "sequence_index"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "context_phase_params"
|
||||
data_type: TYPE_UINT8
|
||||
dims: [ -1 ]
|
||||
},
|
||||
{
|
||||
name: "kv_cache_alloc_new_blocks"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "kv_cache_reused_blocks"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "kv_cache_alloc_total_blocks"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "arrival_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "first_scheduled_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "first_token_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "last_token_time_ns"
|
||||
data_type: TYPE_INT64
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "acceptance_rate"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "total_accepted_draft_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
},
|
||||
{
|
||||
name: "total_draft_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
}
|
||||
]
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind : KIND_CPU
|
||||
}
|
||||
]
|
||||
parameters: {
|
||||
key: "max_beam_width"
|
||||
value: {
|
||||
string_value: "${max_beam_width}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
|
||||
value: {
|
||||
string_value: "no"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpt_model_type"
|
||||
value: {
|
||||
string_value: "${batching_strategy}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpt_model_path"
|
||||
value: {
|
||||
string_value: "${engine_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "encoder_model_path"
|
||||
value: {
|
||||
string_value: "${encoder_engine_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "max_tokens_in_paged_kv_cache"
|
||||
value: {
|
||||
string_value: "${max_tokens_in_paged_kv_cache}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "max_attention_window_size"
|
||||
value: {
|
||||
string_value: "${max_attention_window_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "sink_token_length"
|
||||
value: {
|
||||
string_value: "${sink_token_length}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "batch_scheduler_policy"
|
||||
value: {
|
||||
string_value: "${batch_scheduler_policy}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "kv_cache_free_gpu_mem_fraction"
|
||||
value: {
|
||||
string_value: "${kv_cache_free_gpu_mem_fraction}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cross_kv_cache_fraction"
|
||||
value: {
|
||||
string_value: "${cross_kv_cache_fraction}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "kv_cache_host_memory_bytes"
|
||||
value: {
|
||||
string_value: "${kv_cache_host_memory_bytes}"
|
||||
}
|
||||
}
|
||||
# kv_cache_onboard_blocks is for internal implementation.
|
||||
parameters: {
|
||||
key: "kv_cache_onboard_blocks"
|
||||
value: {
|
||||
string_value: "${kv_cache_onboard_blocks}"
|
||||
}
|
||||
}
|
||||
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
|
||||
# parameters: {
|
||||
# key: "enable_trt_overlap"
|
||||
# value: {
|
||||
# string_value: "${enable_trt_overlap}"
|
||||
# }
|
||||
# }
|
||||
parameters: {
|
||||
key: "exclude_input_in_output"
|
||||
value: {
|
||||
string_value: "${exclude_input_in_output}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cancellation_check_period_ms"
|
||||
value: {
|
||||
string_value: "${cancellation_check_period_ms}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "stats_check_period_ms"
|
||||
value: {
|
||||
string_value: "${stats_check_period_ms}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "iter_stats_max_iterations"
|
||||
value: {
|
||||
string_value: "${iter_stats_max_iterations}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "request_stats_max_iterations"
|
||||
value: {
|
||||
string_value: "${request_stats_max_iterations}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "enable_kv_cache_reuse"
|
||||
value: {
|
||||
string_value: "${enable_kv_cache_reuse}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "normalize_log_probs"
|
||||
value: {
|
||||
string_value: "${normalize_log_probs}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "enable_chunked_context"
|
||||
value: {
|
||||
string_value: "${enable_chunked_context}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpu_device_ids"
|
||||
value: {
|
||||
string_value: "${gpu_device_ids}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "participant_ids"
|
||||
value: {
|
||||
string_value: "${participant_ids}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_optimal_adapter_size"
|
||||
value: {
|
||||
string_value: "${lora_cache_optimal_adapter_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_max_adapter_size"
|
||||
value: {
|
||||
string_value: "${lora_cache_max_adapter_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_gpu_memory_fraction"
|
||||
value: {
|
||||
string_value: "${lora_cache_gpu_memory_fraction}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_cache_host_memory_bytes"
|
||||
value: {
|
||||
string_value: "${lora_cache_host_memory_bytes}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lora_prefetch_dir"
|
||||
value: {
|
||||
string_value: "${lora_prefetch_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "decoding_mode"
|
||||
value: {
|
||||
string_value: "${decoding_mode}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "executor_worker_path"
|
||||
value: {
|
||||
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lookahead_window_size"
|
||||
value: {
|
||||
string_value: "${lookahead_window_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lookahead_ngram_size"
|
||||
value: {
|
||||
string_value: "${lookahead_ngram_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "lookahead_verification_set_size"
|
||||
value: {
|
||||
string_value: "${lookahead_verification_set_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "medusa_choices"
|
||||
value: {
|
||||
string_value: "${medusa_choices}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "eagle_choices"
|
||||
value: {
|
||||
string_value: "${eagle_choices}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "gpu_weights_percent"
|
||||
value: {
|
||||
string_value: "${gpu_weights_percent}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "enable_context_fmha_fp32_acc"
|
||||
value: {
|
||||
string_value: "${enable_context_fmha_fp32_acc}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "multi_block_mode"
|
||||
value: {
|
||||
string_value: "${multi_block_mode}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cuda_graph_mode"
|
||||
value: {
|
||||
string_value: "${cuda_graph_mode}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "cuda_graph_cache_size"
|
||||
value: {
|
||||
string_value: "${cuda_graph_cache_size}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "speculative_decoding_fast_logits"
|
||||
value: {
|
||||
string_value: "${speculative_decoding_fast_logits}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "tokenizer_dir"
|
||||
value: {
|
||||
string_value: "${tokenizer_dir}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "guided_decoding_backend"
|
||||
value: {
|
||||
string_value: "${guided_decoding_backend}"
|
||||
}
|
||||
}
|
||||
parameters: {
|
||||
key: "xgrammar_tokenizer_info_path"
|
||||
value: {
|
||||
string_value: "${xgrammar_tokenizer_info_path}"
|
||||
}
|
||||
}
|
||||
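Every value in the config above is a ${...} placeholder that must be substituted with a concrete setting before Triton can load the model; deployment scripts normally handle this, but as a minimal sketch the same substitution can be done with Python's string.Template (the config path and the values below are illustrative assumptions, not recommendations, and safe_substitute leaves any placeholder it does not know about untouched):

import string

CONFIG = "model_repo/tensorrt_llm/config.pbtxt"  # assumed location of the template above

with open(CONFIG) as f:
    template = string.Template(f.read())

# Example values only; each deployment picks its own settings.
filled = template.safe_substitute(
    max_beam_width="1",
    batching_strategy="inflight_fused_batching",
    engine_dir="/workspace/trt_engines/llm",
    kv_cache_free_gpu_mem_fraction="0.6",
    decoding_mode="top_k_top_p",
)

with open(CONFIG, "w") as f:
    f.write(filled)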
277
runtime/triton_trtllm/model_repo/token2wav/1/model.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.utils.dlpack import to_dlpack
|
||||
from torch.nn import functional as F
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
class CosyVoice2:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
|
||||
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
|
||||
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
|
||||
if not os.path.exists(hyper_yaml_path):
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
|
||||
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
|
||||
|
||||
class CosyVoice2Model:
|
||||
|
||||
def __init__(self,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False,
|
||||
device: str = 'cuda'):
|
||||
self.device = device
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.flow.half()
|
||||
|
||||
# streaming tts config
|
||||
self.token_hop_len = 25
|
||||
self.mel_cache_len = 8
|
||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||
self.hift_cache_dict = defaultdict(lambda: None)
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load(self, flow_model, hift_model):
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||
input_names = ["x", "mask", "mu", "cond"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
streaming=stream,
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
else:
|
||||
hift_cache_source = torch.zeros(1, 1, 0)
|
||||
# keep overlap mel and hift cache
|
||||
if finalize is False:
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-streaming inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for vocoder.
|
||||
|
||||
This model takes CosyVoice2 speech tokens as input and generates audio waveforms
|
||||
using the flow matching model and HiFT vocoder.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
# Parse model parameters
|
||||
parameters = json.loads(args['model_config'])['parameters']
|
||||
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
||||
model_dir = model_params["model_dir"]
|
||||
|
||||
# Initialize device and vocoder
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||
|
||||
self.token2wav_model = CosyVoice2(
|
||||
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
|
||||
)
|
||||
|
||||
spk_info_path = os.path.join(model_dir, "spk2info.pt")
|
||||
if not os.path.exists(spk_info_path):
|
||||
raise ValueError(f"spk2info.pt not found in {model_dir}")
|
||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||
self.default_spk_info = spk_info["001"]
|
||||
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated waveforms
|
||||
"""
|
||||
responses = []
|
||||
# Process each request in batch
|
||||
for request in requests:
|
||||
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
|
||||
|
||||
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
|
||||
if prompt_speech_tokens_tensor is not None:
|
||||
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
|
||||
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
||||
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
||||
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
||||
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
else:
|
||||
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
|
||||
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
|
||||
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
|
||||
|
||||
# shift the speech tokens according to the original vocab size
|
||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
|
||||
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
|
||||
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
|
||||
if token_offset is not None:
|
||||
token_offset = token_offset.as_numpy().item()
|
||||
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||
if not finalize:
|
||||
stream = True
|
||||
else:
|
||||
stream = False
|
||||
request_id = request.request_id()
|
||||
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
|
||||
prompt_token=prompt_speech_tokens,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=prompt_spk_embedding,
|
||||
token_offset=token_offset,
|
||||
uuid=request_id,
|
||||
stream=stream,
|
||||
finalize=finalize)
|
||||
if finalize:
|
||||
self.token2wav_model.model.hift_cache_dict.pop(request_id)
|
||||
|
||||
else:
|
||||
tts_mel, _ = self.token2wav_model.model.flow.inference(
|
||||
token=target_speech_tokens,
|
||||
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
|
||||
self.device
|
||||
),
|
||||
prompt_token=prompt_speech_tokens,
|
||||
prompt_token_len=torch.tensor(
|
||||
[prompt_speech_tokens.shape[1]], dtype=torch.int32
|
||||
).to(self.device),
|
||||
prompt_feat=prompt_speech_feat,
|
||||
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=prompt_spk_embedding,
|
||||
streaming=False,
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
audio_hat, _ = self.token2wav_model.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||
|
||||
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
|
||||
responses.append(inference_response)
|
||||
|
||||
return responses
|
||||
80
runtime/triton_trtllm/model_repo/token2wav/config.pbtxt
Normal file
@@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "token2wav"
|
||||
backend: "python"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
}
|
||||
parameters [
|
||||
{
|
||||
key: "model_dir",
|
||||
value: {string_value:"${model_dir}"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "target_speech_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "prompt_speech_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "prompt_speech_feat"
|
||||
data_type: TYPE_FP16
|
||||
dims: [-1, 80]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "prompt_spk_embedding"
|
||||
data_type: TYPE_FP16
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "token_offset"
|
||||
data_type: TYPE_INT32
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "finalize"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "waveform"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
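The model.py and config.pbtxt above together define a Triton model named token2wav. A minimal offline-mode client sketch follows (assumptions: a Triton gRPC endpoint at localhost:8001, tritonclient installed, placeholder token values; tokens are sent in the LLM vocabulary space because the server subtracts ORIGINAL_VOCAB_SIZE itself, and the optional prompt inputs are omitted so the bundled default speaker is used):

import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")

# Placeholder speech tokens in the LLM vocabulary space (>= 151663).
tokens = np.array([[151663 + 10, 151663 + 42, 151663 + 7]], dtype=np.int32)
finalize = np.array([[True]], dtype=bool)  # offline mode: finalize=True, token_offset omitted

inputs = [
    grpcclient.InferInput("target_speech_tokens", tokens.shape, "INT32"),
    grpcclient.InferInput("finalize", finalize.shape, "BOOL"),
]
inputs[0].set_data_from_numpy(tokens)
inputs[1].set_data_from_numpy(finalize)

result = client.infer(model_name="token2wav", inputs=inputs, request_id="demo-0")
waveform = result.as_numpy("waveform")  # float32, shape [1, num_samples], 24 kHz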
142
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions
|
||||
# are met:
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
|
||||
import torch
|
||||
from torch.utils.dlpack import to_dlpack
|
||||
from torch.nn import functional as F
|
||||
|
||||
import triton_python_backend_utils as pb_utils
|
||||
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from .token2wav_dit import CosyVoice2_Token2Wav
|
||||
import hashlib
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
||||
"""
|
||||
Generates a unique ID for a torch.Tensor.
|
||||
Tensors with the same element values will have the same ID.
|
||||
"""
|
||||
# Convert tensor to a byte string
|
||||
tensor_bytes = tensor.numpy().tobytes()
|
||||
|
||||
# Create a SHA-256 hash of the byte string
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(tensor_bytes)
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
"""Triton Python model for vocoder.
|
||||
|
||||
This model takes CosyVoice2 speech tokens as input and generates audio waveforms
|
||||
using the flow matching model and HiFT vocoder.
|
||||
"""
|
||||
|
||||
def initialize(self, args):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing model configuration
|
||||
"""
|
||||
# Parse model parameters
|
||||
parameters = json.loads(args['model_config'])['parameters']
|
||||
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
||||
model_dir = model_params["model_dir"]
|
||||
|
||||
# Initialize device and vocoder
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||
|
||||
# FIXME: device id settings
|
||||
self.token2wav_model = CosyVoice2_Token2Wav(
|
||||
model_dir, enable_trt=True, streaming=True
|
||||
)
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
|
||||
def execute(self, requests):
|
||||
"""Execute inference on the batched requests.
|
||||
|
||||
Args:
|
||||
requests: List of inference requests
|
||||
|
||||
Returns:
|
||||
List of inference responses containing generated waveforms
|
||||
"""
|
||||
responses = []
|
||||
# Process each request in batch
|
||||
for request in requests:
|
||||
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
|
||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
target_speech_tokens = target_speech_tokens.squeeze().tolist()
|
||||
|
||||
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||
|
||||
request_id = request.request_id()
|
||||
|
||||
wav_array = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav").as_numpy()
|
||||
wav_len = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav_len").as_numpy().item()
|
||||
|
||||
wav_array = torch.from_numpy(wav_array)
|
||||
wav = wav_array[:, :wav_len].squeeze(0)
|
||||
|
||||
spk_id = get_spk_id_from_prompt_audio(wav)
|
||||
|
||||
audio_hat = self.token2wav_model.forward_streaming(
|
||||
target_speech_tokens, finalize, request_id=request_id,
|
||||
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
|
||||
outputs.append(wav_tensor)
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
|
||||
responses.append(inference_response)
|
||||
|
||||
return responses
|
||||
510
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 token2wav.py --enable-trt || exit 1
|
||||
"""
|
||||
import torch
|
||||
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
|
||||
from flashcosyvoice.modules.hifigan import HiFTGenerator
|
||||
from flashcosyvoice.utils.audio import mel_spectrogram
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import onnxruntime
|
||||
import s3tokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
import torchaudio
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import queue
|
||||
import time
|
||||
import numpy as np
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
|
||||
|
||||
def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
|
||||
"""perform fade_in_out in tensor style
|
||||
"""
|
||||
mel_overlap_len = int(window.shape[0] / 2)
|
||||
fade_in_mel = fade_in_mel.clone()
|
||||
fade_in_mel[..., :mel_overlap_len] = \
|
||||
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
return fade_in_mel
|
||||
|
||||
|
||||
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
|
||||
import tensorrt as trt
|
||||
logging.info("Converting onnx to trt...")
|
||||
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if dtype == torch.float16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
profile = builder.create_optimization_profile()
|
||||
# load onnx model
|
||||
with open(onnx_model, "rb") as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||
# set input shapes
|
||||
for i in range(len(trt_kwargs['input_names'])):
|
||||
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||
if dtype == torch.float16:
|
||||
tensor_dtype = trt.DataType.HALF
|
||||
elif dtype == torch.bfloat16:
|
||||
tensor_dtype = trt.DataType.BF16
|
||||
elif dtype == torch.float32:
|
||||
tensor_dtype = trt.DataType.FLOAT
|
||||
else:
|
||||
raise ValueError('invalid dtype {}'.format(dtype))
|
||||
# set input and output data type
|
||||
for i in range(network.num_inputs):
|
||||
input_tensor = network.get_input(i)
|
||||
input_tensor.dtype = tensor_dtype
|
||||
for i in range(network.num_outputs):
|
||||
output_tensor = network.get_output(i)
|
||||
output_tensor.dtype = tensor_dtype
|
||||
config.add_optimization_profile(profile)
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
# save trt engine
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
self.device = device
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
|
||||
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||
self.trt_context_pool.put([trt_context, trt_stream])
|
||||
assert self.trt_context_pool.empty() is False, 'no available estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
|
||||
|
||||
class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
|
||||
super().__init__()
|
||||
self.device_id = device_id
|
||||
self.device = f"cuda:{device_id}"
|
||||
with open(f"{model_dir}/flow.yaml", "r") as f:
|
||||
configs = load_hyperpyyaml(f)
|
||||
self.flow = configs['flow']
|
||||
|
||||
self.dtype = dtype
|
||||
self.flow.to(self.dtype)
|
||||
|
||||
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
|
||||
self.hift = HiFTGenerator()
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.spk_model = onnxruntime.InferenceSession(
|
||||
f"{model_dir}/campplus.onnx", sess_options=option,
|
||||
providers=["CPUExecutionProvider"])
|
||||
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
|
||||
|
||||
gpu = "l20"
|
||||
if enable_trt:
|
||||
if streaming:
|
||||
self.load_trt(
|
||||
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
|
||||
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
|
||||
1,
|
||||
self.dtype, streaming
|
||||
)
|
||||
else:
|
||||
self.load_trt(
|
||||
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
|
||||
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
|
||||
1,
|
||||
self.dtype
|
||||
)
|
||||
self.load_spk_trt(
|
||||
f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||
f'{model_dir}/campplus.onnx',
|
||||
1,
|
||||
False
|
||||
)
|
||||
|
||||
self.streaming_flow_cache = {}
|
||||
self.speaker_cache = {}
|
||||
|
||||
self.mel_cache_len = 8 # hard-coded, 160ms
|
||||
self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave
|
||||
self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
|
||||
|
||||
# hifigan cache for streaming tts
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
def forward_spk_embedding(self, spk_feat):
|
||||
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||
return self.spk_model.run(
|
||||
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||
)[0].flatten().tolist()
|
||||
else:
|
||||
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||
# NOTE need to synchronize when switching stream
|
||||
with torch.cuda.device(self.device_id):
|
||||
torch.cuda.current_stream().synchronize()
|
||||
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||
batch_size = spk_feat.size(0)
|
||||
|
||||
with stream:
|
||||
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||
|
||||
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||
output_tensor.contiguous().data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
|
||||
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.spk_model.release_estimator(spk_model, stream)
|
||||
|
||||
return output_tensor.cpu().numpy().flatten().tolist()
|
||||
|
||||
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||
trt_kwargs = self.get_spk_trt_kwargs()
|
||||
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
|
||||
import tensorrt as trt
|
||||
with open(spk_model, 'rb') as f:
|
||||
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_spk_trt_kwargs(self):
|
||||
min_shape = [(1, 4, 80)]
|
||||
opt_shape = [(1, 500, 80)]
|
||||
max_shape = [(1, 3000, 80)]
|
||||
input_names = ["input"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
opt_batch_size = 2
|
||||
max_batch_size = 16
|
||||
if streaming:
|
||||
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
|
||||
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
|
||||
if streaming:
|
||||
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
|
||||
opt_shape = [
|
||||
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
|
||||
(opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
|
||||
(16, opt_batch_size * 2, 8, 100, 128)
|
||||
]
|
||||
max_shape = [
|
||||
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
|
||||
(max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
|
||||
(16, max_batch_size * 2, 8, 1000, 128)
|
||||
]
|
||||
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
|
||||
else:
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
|
||||
opt_shape = [
|
||||
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
|
||||
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
|
||||
]
|
||||
max_shape = [
|
||||
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
|
||||
(max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
|
||||
]
|
||||
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
|
||||
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
|
||||
for audio in prompt_audios_list:
|
||||
assert len(audio.shape) == 1
|
||||
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
|
||||
prompt_speech_mels_list.append(log_mel)
|
||||
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
|
||||
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
|
||||
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
|
||||
)
|
||||
for i in range(len(prompt_speech_tokens)):
|
||||
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
|
||||
prompt_speech_tokens_list.append(speech_tokens_i)
|
||||
return prompt_speech_tokens_list
|
||||
|
||||
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
|
||||
spk_emb_for_flow = []
|
||||
for audio in prompt_audios_list:
|
||||
assert len(audio.shape) == 1
|
||||
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
|
||||
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
|
||||
spk_emb = self.forward_spk_embedding(spk_feat)
|
||||
|
||||
spk_emb_for_flow.append(spk_emb)
|
||||
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
||||
if self.dtype != torch.float32:
|
||||
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
|
||||
return spk_emb_for_flow
|
||||
|
||||
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
|
||||
prompt_mels_for_flow = []
|
||||
prompt_mels_lens_for_flow = []
|
||||
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
|
||||
assert len(audio.shape) == 1
|
||||
audio = audio.unsqueeze(0)
|
||||
if sample_rate != 24000:
|
||||
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
|
||||
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
|
||||
mel_len = mel.shape[0]
|
||||
prompt_mels_for_flow.append(mel)
|
||||
prompt_mels_lens_for_flow.append(mel_len)
|
||||
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
|
||||
prompt_mels_for_flow, batch_first=True, padding_value=0
|
||||
) # [B, T', num_mels=80]
|
||||
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
|
||||
return prompt_mels_for_flow, prompt_mels_lens_for_flow
|
||||
|
||||
def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
|
||||
generated_speech_tokens_list: list[list[int]],
|
||||
prompt_mels_for_flow: torch.Tensor,
|
||||
prompt_mels_lens_for_flow: torch.Tensor,
|
||||
spk_emb_for_flow: torch.Tensor):
|
||||
batch_size = prompt_mels_for_flow.shape[0]
|
||||
flow_inputs = []
|
||||
flow_inputs_lens = []
|
||||
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
|
||||
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
|
||||
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
|
||||
|
||||
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
|
||||
flow_inputs_lens = torch.tensor(flow_inputs_lens)
|
||||
|
||||
with torch.amp.autocast(self.device, dtype=torch.float16):
|
||||
generated_mels, generated_mels_lens = self.flow.inference(
|
||||
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
|
||||
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
|
||||
)
|
||||
|
||||
return generated_mels, generated_mels_lens
|
||||
|
||||
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
|
||||
batch_size = generated_mels.shape[0]
|
||||
generated_wavs = []
|
||||
for i in range(batch_size):
|
||||
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
|
||||
wav, _ = self.hift(speech_feat=mel)
|
||||
generated_wavs.append(wav)
|
||||
return generated_wavs
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||
):
|
||||
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||
|
||||
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
generated_mels, generated_mels_lens = self.forward_flow(
|
||||
prompt_speech_tokens_list, generated_speech_tokens_list,
|
||||
prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||
)
|
||||
|
||||
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
|
||||
return generated_wavs
|
||||
|
||||
def prepare_prompt_audio(
|
||||
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||
):
|
||||
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||
|
||||
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
|
||||
|
||||
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
|
||||
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||
|
||||
def get_prompt_audio_cache_for_streaming_tts(
|
||||
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||
):
|
||||
assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
|
||||
for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
|
||||
prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
|
||||
prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
|
||||
|
||||
cache = self.flow.setup_cache(
|
||||
prompt_speech_tokens_tensor.to(self.device),
|
||||
prompt_mels_for_flow.to(self.device),
|
||||
spk_emb_for_flow.to(self.device),
|
||||
n_timesteps=10
|
||||
)
|
||||
new_cache = {k: v.clone() for k, v in cache.items()}
|
||||
# NOTE: clone each tensor to avoid in-place changes to cache['estimator_att_cache'] and cache['estimator_cnn_cache']
|
||||
return new_cache
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_streaming(
|
||||
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
|
||||
):
|
||||
if speaker_id not in self.speaker_cache:
|
||||
assert prompt_audio is not None, "prompt_audio is required for new speaker"
|
||||
assert prompt_audio_sample_rate == 16000
|
||||
|
||||
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
|
||||
|
||||
token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
|
||||
prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
|
||||
prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
|
||||
|
||||
prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
|
||||
|
||||
cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
|
||||
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
|
||||
|
||||
if request_id not in self.streaming_flow_cache:
|
||||
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
|
||||
self.hift_cache_dict[request_id] = dict(
|
||||
mel=torch.zeros(1, 80, 0, device='cuda'),
|
||||
source=torch.zeros(1, 1, 0, device='cuda'),
|
||||
speech=torch.zeros(1, 0, device='cuda'),
|
||||
)
|
||||
|
||||
current_request_cache = self.streaming_flow_cache[request_id]
|
||||
|
||||
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
|
||||
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
|
||||
|
||||
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
|
||||
token=generated_speech_tokens,
|
||||
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
|
||||
cache=current_request_cache,
|
||||
last_chunk=last_chunk,
|
||||
n_timesteps=10,
|
||||
)
|
||||
|
||||
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
|
||||
|
||||
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
|
||||
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
|
||||
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
|
||||
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
|
||||
], dim=4)
|
||||
|
||||
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
|
||||
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
|
||||
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
|
||||
mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
|
||||
|
||||
speech, source = self.hift(mel, hift_cache_source)
|
||||
|
||||
# overlap speech smooth
|
||||
if hift_cache_speech.shape[-1] > 0:
|
||||
speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
|
||||
|
||||
# update vocoder cache
|
||||
self.hift_cache_dict[request_id] = dict(
|
||||
mel=mel[..., -self.mel_cache_len:].clone().detach(),
|
||||
source=source[:, :, -self.source_cache_len:].clone().detach(),
|
||||
speech=speech[:, -self.source_cache_len:].clone().detach(),
|
||||
)
|
||||
if not last_chunk:
|
||||
speech = speech[:, :-self.source_cache_len]
|
||||
|
||||
if last_chunk:
|
||||
assert request_id in self.streaming_flow_cache
|
||||
self.streaming_flow_cache.pop(request_id)
|
||||
self.hift_cache_dict.pop(request_id)
|
||||
|
||||
return speech
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
|
||||
for item in batch:
|
||||
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
|
||||
audio = torch.from_numpy(item['prompt_audio']['array']).float()
|
||||
prompt_audios_list.append(audio)
|
||||
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
|
||||
ids.append(item['id'])
|
||||
|
||||
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--enable-trt", action="store_true")
|
||||
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
|
||||
parser.add_argument("--batch-size", type=int, default=1)
|
||||
parser.add_argument("--output-dir", type=str, default="generated_wavs")
|
||||
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
|
||||
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
dataset_name = "yuekai/seed_tts_cosy2"
|
||||
|
||||
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
|
||||
|
||||
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
|
||||
|
||||
for _ in range(args.warmup):
|
||||
start_time = time.time()
|
||||
for batch in data_loader:
|
||||
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
|
||||
|
||||
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
for id, wav in zip(ids, generated_wavs):
|
||||
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
|
||||
end_time = time.time()
|
||||
epoch_time = end_time - start_time
|
||||
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
|
||||
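The __main__ block above only exercises the offline forward path. A rough sketch of the streaming entry point forward_streaming is shown below (assumptions: the model directory and TensorRT plans exist, the prompt is 16 kHz mono, and the token chunks are placeholders; all chunks of one utterance must reuse the same request_id and speaker_id so that the per-request flow/HiFT caches and the per-speaker prompt cache are reused):

import torch

model = CosyVoice2_Token2Wav(model_dir="./Step-Audio-2-mini/token2wav", enable_trt=True, streaming=True)
prompt_audio = torch.zeros(16000)                 # stand-in 1 s prompt, 16 kHz mono
token_chunks = [[1, 2, 3] * 10, [4, 5, 6] * 10]   # placeholder speech-token chunks

pieces = []
for i, chunk in enumerate(token_chunks):
    last = (i == len(token_chunks) - 1)
    speech = model.forward_streaming(
        chunk, last_chunk=last, request_id="req-0", speaker_id="spk-0",
        prompt_audio=prompt_audio, prompt_audio_sample_rate=16000,
    )
    pieces.append(speech.cpu())
waveform = torch.cat(pieces, dim=1)               # 24 kHz waveform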
69
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: "token2wav_dit"
|
||||
backend: "python"
|
||||
max_batch_size: ${triton_max_batch_size}
|
||||
|
||||
dynamic_batching {
|
||||
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
||||
priority_levels: 10
|
||||
default_priority_level: 10
|
||||
}
|
||||
|
||||
parameters [
|
||||
{
|
||||
key: "model_dir",
|
||||
value: {string_value:"${model_dir}"}
|
||||
}
|
||||
]
|
||||
|
||||
input [
|
||||
{
|
||||
name: "target_speech_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "reference_wav"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
},
|
||||
{
|
||||
name: "reference_wav_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
},
|
||||
{
|
||||
name: "finalize"
|
||||
data_type: TYPE_BOOL
|
||||
dims: [ 1 ]
|
||||
reshape: { shape: [ ] }
|
||||
optional: true
|
||||
}
|
||||
]
|
||||
output [
|
||||
{
|
||||
name: "waveform"
|
||||
data_type: TYPE_FP32
|
||||
dims: [ -1 ]
|
||||
}
|
||||
]
|
||||
|
||||
instance_group [
|
||||
{
|
||||
count: 1
|
||||
kind: KIND_CPU
|
||||
}
|
||||
]
|
||||
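For completeness, a hedged client-side sketch against the token2wav_dit model defined above (assumptions: a Triton gRPC endpoint at localhost:8001, a 16 kHz mono prompt at prompt.wav, placeholder token chunks; every chunk of one utterance reuses the same request_id, which is the key of the server-side streaming caches):

import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")
wav, sr = sf.read("prompt.wav", dtype="float32")          # assumed 16 kHz mono prompt
assert sr == 16000                                        # server-side tokenizer expects 16 kHz
wav = wav[np.newaxis, :]                                  # [1, num_samples]
wav_len = np.array([[wav.shape[1]]], dtype=np.int32)

token_chunks = [[151663 + t for t in (1, 2, 3)], [151663 + t for t in (4, 5, 6)]]  # placeholders
pieces = []
for i, chunk in enumerate(token_chunks):
    tokens = np.array([chunk], dtype=np.int32)
    finalize = np.array([[i == len(token_chunks) - 1]], dtype=bool)
    inputs = [
        grpcclient.InferInput("target_speech_tokens", tokens.shape, "INT32"),
        grpcclient.InferInput("reference_wav", wav.shape, "FP32"),
        grpcclient.InferInput("reference_wav_len", wav_len.shape, "INT32"),
        grpcclient.InferInput("finalize", finalize.shape, "BOOL"),
    ]
    for inp, arr in zip(inputs, (tokens, wav, wav_len, finalize)):
        inp.set_data_from_numpy(arr)
    result = client.infer(model_name="token2wav_dit", inputs=inputs, request_id="utt-0")
    pieces.append(result.as_numpy("waveform"))
waveform = np.concatenate(pieces, axis=-1)                # 24 kHz float32 audio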
652
runtime/triton_trtllm/offline_inference.py
Normal file
@@ -0,0 +1,652 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 offline_inference.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||
--token2wav-path $model_scope_model_local_dir \
|
||||
--backend $backend \
|
||||
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||
--engine-dir $trt_engines_dir \
|
||||
--split-name ${dataset} || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
import soundfile as sf
|
||||
import s3tokenizer
|
||||
from functools import partial
|
||||
import time
|
||||
import requests
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def send_request_async(client, url, payload):
|
||||
response = await client.post(url, json=payload, timeout=None)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
return response_json['choices'][0]['message']['content']
|
||||
|
||||
|
||||
async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
|
||||
async with httpx.AsyncClient() as client:
|
||||
tasks = []
|
||||
for chat in chats:
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"messages": chat,
|
||||
"max_tokens": 2048,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"repetition_penalty": 1.1,
|
||||
"stop": ["<|eos1|>", "<|eos|>"],
|
||||
"stream": False,
|
||||
}
|
||||
tasks.append(send_request_async(client, api_base, payload))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>"""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Convert CosyVoice2 tokens to a speech ID string like <|s_23456|>"""
    speech_id_str = ""
    for token in cosy2_tokens:
        speech_id_str += f"<|s_{token}|>"
    return speech_id_str

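# Round-trip example for the two helpers above (token values are illustrative):
#
#   >>> convert_cosy2_tokens_to_speech_id_str([23456, 17])
#   '<|s_23456|><|s_17|>'
#   >>> extract_speech_ids(['<|s_23456|>', '<|s_17|>'])
#   [23456, 17]
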
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="dir to save result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-batch-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=0, help="workers for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefetch", type=int, default=None, help="prefetch for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model-name-or-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="LLM model path (includes both model and tokenizer)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="CosyVoice2 token2wav model path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-text",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt text for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-speech-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the prompt speech for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="top p for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="temperature for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=50,
|
||||
help="top k for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["hf", "trtllm", "vllm", "trtllm-serve"],
|
||||
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--engine-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-free-gpu-memory-fraction",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--openai-api-base",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1/chat/completions",
|
||||
help="OpenAI API base URL (for trtllm-serve backend)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--openai-model-name",
|
||||
type=str,
|
||||
default="trt_engines_bfloat16",
|
||||
help="Model name to use with OpenAI API (for trtllm-serve backend)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
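# Hedged CLI sketch (paths follow run.sh stage 6; adjust them to your local
# checkouts; --engine-dir is only needed for the trtllm backend):
#
#   python3 offline_inference.py \
#       --output-dir ./wenetspeech4tts_hf_bs1 \
#       --llm-model-name-or-path ./cosyvoice2_llm \
#       --token2wav-path ./CosyVoice2-0.5B \
#       --backend hf --batch-size 1 --token2wav-batch-size 1 \
#       --split-name wenetspeech4tts
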
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||
"""Simplified data collator for batch_size=1 processing"""
|
||||
collator_start_time = time.time()
|
||||
total_audio_processing_time = 0
|
||||
total_speech_tokenization_time = 0
|
||||
total_text_tokenization_time = 0
|
||||
|
||||
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||
prompt_text_after_apply_template_list = []
|
||||
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
|
||||
chat_list = []
|
||||
for _, item in enumerate(batch):
|
||||
audio_processing_start_time = time.time()
|
||||
prompt_text, target_text = (
|
||||
item["prompt_text"],
|
||||
item["target_text"],
|
||||
)
|
||||
prompt_text_list.append(prompt_text)
|
||||
full_text = prompt_text + target_text
|
||||
full_text_list.append(full_text)
|
||||
# remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
|
||||
puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
|
||||
for p in puncts:
|
||||
if p in full_text:
|
||||
full_text = full_text.replace(p, '')
|
||||
print(f"removed {p} from {full_text}")
|
||||
|
||||
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||
ref_audio_org, ref_sr = (
|
||||
item["prompt_audio"]["array"],
|
||||
item["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||
print(ref_audio_org.shape)
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
prompt_audio_list.append(ref_audio)
|
||||
audio_processing_end_time = time.time()
|
||||
total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
|
||||
|
||||
speech_tokenization_start_time = time.time()
|
||||
if "prompt_audio_cosy2_tokens" in item:
|
||||
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||
else:
|
||||
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||
|
||||
if len(mels) > 0:
|
||||
mels, mels_lens = s3tokenizer.padding(mels)
|
||||
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||
for i in range(len(codes)):
|
||||
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||
speech_tokenization_end_time = time.time()
|
||||
total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
|
||||
|
||||
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
|
||||
text_tokenization_start_time = time.time()
|
||||
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||
# Create chat template for LLM generation
|
||||
chat = [
|
||||
{"role": "user", "content": full_text_list[i]},
|
||||
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||
]
|
||||
chat_list.append(chat)
|
||||
|
||||
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
tokenize=True,
|
||||
return_tensors='pt',
|
||||
continue_final_message=True
|
||||
)
|
||||
input_ids_list.append(input_ids.squeeze(0))
|
||||
|
||||
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
|
||||
|
||||
prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
|
||||
text_tokenization_end_time = time.time()
|
||||
total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
|
||||
|
||||
ids = [item["id"] for item in batch]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids_list,
|
||||
"ids": ids,
|
||||
"prompt_text": prompt_text_list,
|
||||
"prompt_audio_list": prompt_audio_list,
|
||||
"prompt_text_after_apply_template": prompt_text_after_apply_template_list,
|
||||
"audio_processing_time": total_audio_processing_time,
|
||||
"speech_tokenization_time": total_speech_tokenization_time,
|
||||
"text_tokenization_time": total_text_tokenization_time,
|
||||
"chat_list": chat_list
|
||||
}
|
||||
|
||||
|
||||
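# Sketch of the prompt the collator builds for each sample (token ids are
# illustrative only): both the chat-template path and the plain-string path
# reduce to
#   <|sos|>{prompt_text + target_text}<|task_id|><|s_123|><|s_456|>...
# i.e. the full text followed by the prompt audio's CosyVoice2 speech tokens,
# which the LLM is expected to continue with speech tokens for the target text.
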
def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank

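# Note: main() below hardcodes single-GPU values (local_rank, world_size, rank
# = 0, 1, 0) instead of calling init_distributed(). A hypothetical multi-GPU
# run would launch with
#   torchrun --nproc_per_node=<num_gpus> offline_inference.py ...
# and call world_size, local_rank, rank = init_distributed() inside main().
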
def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
local_rank, world_size, rank = 0, 1, 0
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||
|
||||
if args.backend == "hf":
|
||||
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
runner = None
|
||||
elif args.backend == "trtllm":
|
||||
if args.engine_dir is None:
|
||||
raise ValueError("--engine-dir is required when backend is 'trtllm'")
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
model = None
|
||||
|
||||
runner_kwargs = dict(
|
||||
engine_dir=args.engine_dir,
|
||||
rank=runtime_rank,
|
||||
max_output_len=2048,
|
||||
enable_context_fmha_fp32_acc=False,
|
||||
max_batch_size=args.batch_size,
|
||||
max_input_len=512,
|
||||
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
|
||||
cuda_graph_mode=False,
|
||||
gather_generation_logits=False,
|
||||
)
|
||||
|
||||
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
|
||||
elif args.backend == "vllm":
|
||||
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
|
||||
runner = None
|
||||
elif args.backend == "trtllm-serve":
|
||||
model = None
|
||||
runner = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
if 'Step-Audio-2-mini' in args.token2wav_path:
|
||||
from token2wav_dit import CosyVoice2_Token2Wav
|
||||
else:
|
||||
assert 'CosyVoice2-0.5B' in args.token2wav_path
|
||||
from token2wav import CosyVoice2_Token2Wav
|
||||
token2wav_model = CosyVoice2_Token2Wav(
|
||||
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
|
||||
)
|
||||
if args.prompt_speech_path:
|
||||
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||
else:
|
||||
prompt_speech_16k = None
|
||||
s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
|
||||
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
sampler = None
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||
)
|
||||
for run_idx in range(3):
print(f"Benchmark run {run_idx + 1} of 3")
|
||||
total_llm_time = 0
|
||||
total_token2wav_time = 0
|
||||
total_data_load_time = 0
|
||||
total_llm_post_processing_time = 0
|
||||
total_audio_save_time = 0
|
||||
total_audio_processing_time_in_collator = 0
|
||||
total_speech_tokenization_time_in_collator = 0
|
||||
total_text_tokenization_time_in_collator = 0
|
||||
total_audio_samples = 0
|
||||
start_time = time.time()
|
||||
total_steps = len(dataset)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
last_batch_end_time = time.time()
|
||||
for batch in dataloader:
|
||||
data_loaded_time = time.time()
|
||||
total_data_load_time += data_loaded_time - last_batch_end_time
|
||||
total_audio_processing_time_in_collator += batch["audio_processing_time"]
|
||||
total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
|
||||
total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
|
||||
with torch.no_grad():
|
||||
llm_start_time = time.time()
|
||||
if args.backend == "hf":
|
||||
input_ids_list = batch["input_ids"]
|
||||
if len(input_ids_list) == 1:
|
||||
input_ids = input_ids_list[0].unsqueeze(0)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||
input_ids_list_new = [
|
||||
torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
|
||||
for input_ids in input_ids_list
|
||||
]
|
||||
input_ids = torch.stack(input_ids_list_new)
|
||||
attention_mask = torch.zeros_like(input_ids)
|
||||
for i in range(len(input_ids_list)):
|
||||
attention_mask[i, :len(input_ids_list[i])] = 1
|
||||
|
||||
input_ids = input_ids.to(device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
max_new_tokens=2048,
|
||||
do_sample=True,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=1.1,
|
||||
top_k=args.top_k,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
elif args.backend == "trtllm":
|
||||
batch_input_ids = list(batch["input_ids"])
|
||||
input_lengths = [x.size(0) for x in batch_input_ids]
|
||||
|
||||
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
|
||||
print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
|
||||
outputs = runner.generate(
|
||||
batch_input_ids=batch_input_ids,
|
||||
max_new_tokens=2048,
|
||||
end_id=end_id,
|
||||
pad_id=end_id,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1,
|
||||
streaming=False,
|
||||
output_sequence_lengths=True,
|
||||
output_generation_logits=False,
|
||||
return_dict=True,
|
||||
return_all_generated_tokens=False
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
|
||||
num_output_sents, num_beams, _ = output_ids.size()
|
||||
assert num_beams == 1
|
||||
beam = 0
|
||||
batch_size = len(batch["input_ids"])
|
||||
num_return_sequences = num_output_sents // batch_size
|
||||
assert num_return_sequences == 1
|
||||
outputs = []
|
||||
for i in range(batch_size * num_return_sequences):
|
||||
batch_idx = i // num_return_sequences
|
||||
seq_idx = i % num_return_sequences
|
||||
output_begin = input_lengths[batch_idx]
|
||||
output_end = sequence_lengths[i][beam]
|
||||
outputs_i = output_ids[i][beam][:output_end].tolist()
|
||||
outputs.append(outputs_i)
|
||||
elif args.backend == "vllm":
|
||||
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
|
||||
sampling_params = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
repetition_penalty=1.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
|
||||
print(outputs)
|
||||
for j, output in enumerate(outputs):
|
||||
outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
|
||||
elif args.backend == "trtllm-serve":
|
||||
if args.batch_size > 1:
|
||||
outputs = asyncio.run(send_batch_requests_async(
|
||||
args.openai_api_base,
|
||||
args.openai_model_name,
|
||||
batch["chat_list"],
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.top_k,
|
||||
))
|
||||
else:
|
||||
outputs = []
|
||||
for chat in batch["chat_list"]:
|
||||
payload = {
|
||||
"model": args.openai_model_name,
|
||||
"messages": chat,
|
||||
"max_tokens": 2048,
|
||||
"temperature": args.temperature,
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"repetition_penalty": 1.1,
|
||||
"stop": ["<|eos1|>", "<|eos|>"],
|
||||
"stream": False,
|
||||
}
|
||||
response = requests.post(args.openai_api_base, json=payload)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
generated_content = response_json['choices'][0]['message']['content']
|
||||
outputs.append(generated_content)
|
||||
|
||||
llm_end_time = time.time()
|
||||
total_llm_time += (llm_end_time - llm_start_time)
|
||||
|
||||
items_for_token_2wav = []
|
||||
for i in range(len(batch["ids"])):
|
||||
llm_post_processing_start_time = time.time()
|
||||
if args.backend == "trtllm-serve":
|
||||
speech_tokens_str = outputs[i].strip().split('><')
|
||||
if len(speech_tokens_str) > 1:
|
||||
speech_tokens_str = [
|
||||
t if t.startswith('<') else '<' + t for t in speech_tokens_str
|
||||
]
|
||||
speech_tokens_str = [
|
||||
t if t.endswith('>') else t + '>' for t in speech_tokens_str
|
||||
]
|
||||
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||
else:
|
||||
input_length = len(batch["input_ids"][i])
|
||||
generated_ids = outputs[i][input_length:]
|
||||
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||
print(i, speech_ids)
|
||||
if len(speech_ids) == 0:
|
||||
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||
continue
|
||||
|
||||
if args.prompt_text is not None:
|
||||
current_prompt_text = args.prompt_text
|
||||
current_prompt_audio = prompt_speech_16k
|
||||
else:
|
||||
current_prompt_text = batch["prompt_text"][i]
|
||||
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||
|
||||
llm_post_processing_end_time = time.time()
|
||||
total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
|
||||
if current_prompt_audio is not None:
|
||||
items_for_token_2wav.append({
|
||||
"speech_ids": speech_ids,
|
||||
"prompt_audio": current_prompt_audio.squeeze(0),
|
||||
"id": batch["ids"][i]
|
||||
})
|
||||
else:
|
||||
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||
|
||||
for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
|
||||
t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
|
||||
if not t2w_batch:
|
||||
continue
|
||||
|
||||
t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
|
||||
t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
|
||||
t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
|
||||
t2w_ids = [item["id"] for item in t2w_batch]
|
||||
|
||||
token2wav_start_time = time.time()
|
||||
generated_wavs = token2wav_model(
|
||||
t2w_generated_speech_tokens_list,
|
||||
t2w_prompt_audios_list,
|
||||
t2w_prompt_audios_sample_rate,
|
||||
)
|
||||
token2wav_end_time = time.time()
|
||||
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
|
||||
|
||||
audio_save_start_time = time.time()
|
||||
for j, audio_hat in enumerate(generated_wavs):
|
||||
generated_wave = audio_hat.squeeze().cpu().numpy()
|
||||
total_audio_samples += len(generated_wave)
|
||||
target_sample_rate = 24000
|
||||
|
||||
utt = t2w_ids[j]
|
||||
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||
print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
|
||||
audio_save_end_time = time.time()
|
||||
total_audio_save_time += audio_save_end_time - audio_save_start_time
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * len(batch["ids"]))
|
||||
|
||||
last_batch_end_time = time.time()
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
end_time = time.time()
|
||||
target_sample_rate = 24000
|
||||
total_audio_duration_seconds = total_audio_samples / target_sample_rate
|
||||
|
||||
log_file_path = os.path.join(args.output_dir, "log.txt")
|
||||
with open(log_file_path, 'w') as f:
|
||||
args_dict = vars(args)
|
||||
log_data = {
|
||||
"args": args_dict,
|
||||
"data_load_time_seconds": total_data_load_time,
|
||||
"audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
|
||||
"speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
|
||||
"text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
|
||||
"llm_time_seconds": total_llm_time,
|
||||
"llm_post_processing_time_seconds": total_llm_post_processing_time,
|
||||
"token2wav_time_seconds": total_token2wav_time,
|
||||
"audio_save_time_seconds": total_audio_save_time,
|
||||
"total_audio_duration_seconds": total_audio_duration_seconds,
|
||||
"pipeline_time_seconds": end_time - start_time,
|
||||
}
|
||||
print(log_data)
|
||||
f.write(json.dumps(log_data, indent=4))
|
||||
print(f"Metrics logged to {log_file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
if args.backend == "vllm":
|
||||
from vllm import LLM, SamplingParams
|
||||
elif args.backend == "trtllm":
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.runtime import ModelRunnerCpp
|
||||
elif args.backend == "hf":
|
||||
from transformers import AutoModelForCausalLM
|
||||
elif args.backend == "trtllm-serve":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
main(args)
|
||||
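# Hedged sketch for the trtllm-serve backend (mirrors stage 3 of
# run_stepaudio2_dit_token2wav.sh; paths and ports are illustrative):
#
#   trtllm-serve serve --tokenizer ./cosyvoice2_llm ./trt_engines_bfloat16 \
#       --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
#   python3 offline_inference.py --backend trtllm-serve \
#       --openai-api-base http://localhost:8000/v1/chat/completions \
#       --output-dir ./out --llm-model-name-or-path ./cosyvoice2_llm \
#       --token2wav-path ./CosyVoice2-0.5B --split-name wenetspeech4tts
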
runtime/triton_trtllm/requirements.txt (new file, 14 lines)
@@ -0,0 +1,14 @@
hyperpyyaml
s3tokenizer
onnxruntime-gpu
omegaconf
conformer
hydra-core
lightning
gdown
wget
librosa
pyworld
openai-whisper
tritonclient
modelscope
runtime/triton_trtllm/run.sh (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
cosyvoice_path=/workspace/CosyVoice
|
||||
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
|
||||
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
|
||||
huggingface_model_local_dir=./cosyvoice2_llm
|
||||
model_scope_model_local_dir=./CosyVoice2-0.5B
|
||||
trt_dtype=bfloat16
|
||||
trt_weights_dir=./trt_weights_${trt_dtype}
|
||||
trt_engines_dir=./trt_engines_${trt_dtype}
|
||||
|
||||
model_repo=./model_repo_cosyvoice2
|
||||
|
||||
use_spk2info_cache=False
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
echo "Cloning CosyVoice"
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||
cd $cosyvoice_path
|
||||
git submodule update --init --recursive
|
||||
cd runtime/triton_trtllm
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading CosyVoice2-0.5B"
|
||||
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
||||
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
|
||||
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Converting checkpoint to TensorRT weights"
|
||||
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
|
||||
--output_dir $trt_weights_dir \
|
||||
--dtype $trt_dtype || exit 1
|
||||
|
||||
echo "Building TensorRT engines"
|
||||
trtllm-build --checkpoint_dir $trt_weights_dir \
|
||||
--output_dir $trt_engines_dir \
|
||||
--max_batch_size 16 \
|
||||
--max_num_tokens 32768 \
|
||||
--gemm_plugin $trt_dtype || exit 1
|
||||
|
||||
echo "Testing TensorRT engines"
|
||||
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
|
||||
--tokenizer_dir $huggingface_model_local_dir \
|
||||
--top_k 50 --top_p 0.95 --temperature 0.8 \
|
||||
--engine_dir=$trt_engines_dir || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
echo "Creating model repository"
|
||||
rm -rf $model_repo
|
||||
mkdir -p $model_repo
|
||||
cosyvoice2_dir="cosyvoice2"
|
||||
|
||||
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||
cp -r ./model_repo/tensorrt_llm $model_repo
|
||||
cp -r ./model_repo/token2wav $model_repo
|
||||
if [ $use_spk2info_cache == "False" ]; then
|
||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||
cp -r ./model_repo/speaker_embedding $model_repo
|
||||
fi
|
||||
|
||||
ENGINE_PATH=$trt_engines_dir
|
||||
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||
MODEL_DIR=$model_scope_model_local_dir
|
||||
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
||||
BLS_INSTANCE_NUM=4
|
||||
TRITON_MAX_BATCH_SIZE=16
|
||||
DECOUPLED_MODE=True # True for streaming, False for offline
|
||||
|
||||
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
||||
if [ $use_spk2info_cache == "False" ]; then
|
||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
echo "Starting Triton server"
|
||||
tritonserver --model-repository $model_repo
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
echo "Single request test http, only work for offline TTS mode"
|
||||
python3 client_http.py \
|
||||
--reference-audio ./assets/prompt_audio.wav \
|
||||
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
||||
--model-name cosyvoice2
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Running benchmark client grpc"
|
||||
num_task=4
|
||||
|
||||
mode=streaming
|
||||
BLS_INSTANCE_NUM=4
|
||||
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name cosyvoice2 \
|
||||
--num-tasks $num_task \
|
||||
--mode $mode \
|
||||
--use-spk2info-cache $use_spk2info_cache \
|
||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
echo "stage 6: Offline inference benchmark"
|
||||
n_gpus=1
|
||||
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
||||
backend=trtllm # hf, trtllm, vllm
|
||||
|
||||
batch_sizes=(16 8 4 2 1)
|
||||
token2wav_batch_size=1
|
||||
for batch_size in ${batch_sizes[@]}; do
|
||||
for dataset in ${datasets[@]}; do
|
||||
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 offline_inference.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||
--token2wav-path $model_scope_model_local_dir \
|
||||
--backend $backend \
|
||||
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||
--engine-dir $trt_engines_dir \
|
||||
--split-name ${dataset} || exit 1
|
||||
done
|
||||
done
|
||||
fi
|
||||
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh (new file, 225 lines)
@@ -0,0 +1,225 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
cosyvoice_path=/workspace/CosyVoice
|
||||
stepaudio2_path=/workspace/Step-Audio2
|
||||
|
||||
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
|
||||
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
|
||||
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
|
||||
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
|
||||
huggingface_model_local_dir=./cosyvoice2_llm
|
||||
model_scope_model_local_dir=./CosyVoice2-0.5B
|
||||
step_audio_model_dir=./Step-Audio-2-mini
|
||||
|
||||
trt_dtype=bfloat16
|
||||
trt_weights_dir=./trt_weights_${trt_dtype}
|
||||
trt_engines_dir=./trt_engines_${trt_dtype}
|
||||
|
||||
model_repo=./model_repo_cosyvoice2_dit
|
||||
bls_instance_num=10
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
|
||||
echo "Cloning Step-Audio2-mini"
|
||||
git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
|
||||
|
||||
echo "Cloning CosyVoice"
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||
cd $cosyvoice_path
|
||||
git submodule update --init --recursive
|
||||
cd runtime/triton_trtllm
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading CosyVoice2-0.5B"
|
||||
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
||||
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||
|
||||
echo "Step-Audio2-mini"
|
||||
huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
|
||||
cd $step_audio_model_dir/token2wav
|
||||
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
|
||||
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
|
||||
cd -
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Converting checkpoint to TensorRT weights"
|
||||
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
|
||||
--output_dir $trt_weights_dir \
|
||||
--dtype $trt_dtype || exit 1
|
||||
|
||||
echo "Building TensorRT engines"
|
||||
trtllm-build --checkpoint_dir $trt_weights_dir \
|
||||
--output_dir $trt_engines_dir \
|
||||
--max_batch_size 64 \
|
||||
--max_num_tokens 32768 \
|
||||
--gemm_plugin $trt_dtype || exit 1
|
||||
|
||||
echo "Testing TensorRT engines"
|
||||
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
|
||||
--tokenizer_dir $huggingface_model_local_dir \
|
||||
--top_k 50 --top_p 0.95 --temperature 0.8 \
|
||||
--engine_dir=$trt_engines_dir || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
echo "Creating model repository async mode"
|
||||
rm -rf $model_repo
|
||||
mkdir -p $model_repo
|
||||
cosyvoice2_dir="cosyvoice2_dit"
|
||||
token2wav_dir="token2wav_dit"
|
||||
|
||||
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||
cp -r ./model_repo/${token2wav_dir} $model_repo
|
||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||
cp -r ./model_repo/speaker_embedding $model_repo
|
||||
|
||||
|
||||
ENGINE_PATH=$trt_engines_dir
|
||||
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||
MODEL_DIR=$model_scope_model_local_dir
|
||||
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
|
||||
BLS_INSTANCE_NUM=$bls_instance_num
|
||||
TRITON_MAX_BATCH_SIZE=1
|
||||
DECOUPLED_MODE=True # Only streaming TTS mode is supported using Nvidia Triton for now
|
||||
STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
|
||||
|
||||
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
|
||||
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
|
||||
tritonserver --model-repository $model_repo --http-port 18000 &
|
||||
wait
|
||||
# Test using curl
|
||||
# curl http://localhost:8000/v1/chat/completions \
|
||||
# -H "Content-Type: application/json" \
|
||||
# -d '{
|
||||
# "model": "",
|
||||
# "messages":[{"role": "user", "content": "Where is New York?"},
|
||||
# {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
|
||||
# "max_tokens": 512,
|
||||
# "temperature": 0.8,
|
||||
# "top_p": 0.95,
|
||||
# "top_k": 50,
|
||||
# "stop": ["<|eos1|>"],
|
||||
# "repetition_penalty": 1.2,
|
||||
# "stream": false
|
||||
# }'
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
echo "Running benchmark client"
|
||||
num_task=4
|
||||
mode=streaming
|
||||
BLS_INSTANCE_NUM=$bls_instance_num
|
||||
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--server-port 8001 \
|
||||
--model-name cosyvoice2_dit \
|
||||
--num-tasks $num_task \
|
||||
--mode $mode \
|
||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||
--log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
|
||||
|
||||
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
||||
backend=trtllm # hf, trtllm, vllm, trtllm-serve
|
||||
|
||||
batch_sizes=(16)
|
||||
token2wav_batch_size=1
|
||||
|
||||
for batch_size in ${batch_sizes[@]}; do
|
||||
for dataset in ${datasets[@]}; do
|
||||
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
||||
CUDA_VISIBLE_DEVICES=1 \
|
||||
python3 offline_inference.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||
--token2wav-path $step_audio_model_dir/token2wav \
|
||||
--backend $backend \
|
||||
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||
--engine-dir $trt_engines_dir \
|
||||
--split-name ${dataset} || exit 1
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
# Note: Using pre-computed cosyvoice2 tokens
|
||||
python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
|
||||
# Offline Token2wav inference
|
||||
python3 token2wav_dit.py --enable-trt
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "Disaggregated Server: LLM and Token2wav on different GPUs"
|
||||
echo "Starting LLM server on GPU 0"
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
|
||||
echo "Starting Token2wav server on GPUs 1-3"
|
||||
Token2wav_num_gpus=3
|
||||
http_port=17000
|
||||
grpc_port=18000
|
||||
metrics_port=16000
|
||||
for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
|
||||
echo "Starting server on GPU $i"
|
||||
http_port=$((http_port + 1))
|
||||
grpc_port=$((grpc_port + 1))
|
||||
metrics_port=$((metrics_port + 1))
|
||||
# Two instances of Token2wav server on the same GPU
|
||||
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
|
||||
http_port=$((http_port + 1))
|
||||
grpc_port=$((grpc_port + 1))
|
||||
metrics_port=$((metrics_port + 1))
|
||||
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
|
||||
done
|
||||
wait
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Running benchmark client for Disaggregated Server"
|
||||
per_gpu_instances=2
|
||||
mode=streaming
|
||||
BLS_INSTANCE_NUM=$bls_instance_num
|
||||
Token2wav_num_gpus=(1 2 3)
|
||||
concurrent_tasks=(1 2 3 4 5 6)
|
||||
for n_gpu in ${Token2wav_num_gpus[@]}; do
|
||||
echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
|
||||
for concurrent_task in ${concurrent_tasks[@]}; do
|
||||
num_instances=$((per_gpu_instances * n_gpu))
|
||||
for i in $(seq 1 $num_instances); do
|
||||
port=$(($i + 18000))
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--server-port $port \
|
||||
--model-name cosyvoice2_dit \
|
||||
--num-tasks $concurrent_task \
|
||||
--mode $mode \
|
||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||
--log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
|
||||
done
|
||||
wait
|
||||
done
|
||||
done
|
||||
fi
|
||||
runtime/triton_trtllm/scripts/convert_checkpoint.py (new file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import release_gc
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models import QWenForCausalLM
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_dir', type=str, default=None, required=True)
|
||||
parser.add_argument('--tp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='N-way tensor parallelism size')
|
||||
parser.add_argument('--pp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='N-way pipeline parallelism size')
|
||||
parser.add_argument('--cp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='N-way context parallelism size')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'float16', 'bfloat16', 'float32'],
|
||||
help="The data type for the model weights and activations if not quantized. "
|
||||
"If 'auto', the data type is automatically inferred from the source model; "
|
||||
"however, if the source dtype is float32, it is converted to float16.")
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
parser.add_argument(
|
||||
'--disable_weight_only_quant_plugin',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4', 'int4_gptq'],
|
||||
help='Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--calib_dataset',
|
||||
type=str,
|
||||
default='ccdv/cnn_dailymail',
|
||||
help="The huggingface dataset name or the local directory of the dataset for calibration."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smoothquant",
|
||||
"-sq",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
||||
" to Smoothquant the model, and output int8 weights."
|
||||
" A good first try is 0.5. Must be in [0, 1]")
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The flag is built for GPTQ/AWQ quantization.')
|
||||
|
||||
parser.add_argument('--group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size used in GPTQ quantization.')
|
||||
|
||||
parser.add_argument("--load_model_on_cpu", action="store_true")
|
||||
parser.add_argument(
|
||||
'--use_parallel_embedding',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--embedding_sharding_dim',
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
||||
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
||||
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
||||
)
|
||||
parser.add_argument('--output_dir',
|
||||
type=str,
|
||||
default='tllm_checkpoint',
|
||||
help='The path to save the TensorRT-LLM checkpoint')
|
||||
parser.add_argument(
|
||||
'--workers',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of workers for converting checkpoint in parallel')
|
||||
parser.add_argument(
|
||||
'--moe_tp_size',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--moe_ep_size',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
    '''return config dict with quantization info based on the command line args
    '''
    quant_config = QuantConfig()
    if args.use_weight_only:
        if args.weight_only_precision == 'int8':
            quant_config.quant_algo = QuantAlgo.W8A16
        elif args.weight_only_precision == 'int4':
            quant_config.quant_algo = QuantAlgo.W4A16
    elif args.smoothquant:
        quant_config.smoothquant_val = args.smoothquant
        if args.per_channel:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        else:
            if args.per_token:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
            else:
                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN

    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8

    if args.weight_only_precision == 'int4_gptq':
        quant_config.group_size = args.group_size
        quant_config.has_zero_point = True
        quant_config.pre_quant_scale = False
        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ

    return quant_config

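# Hedged example: with `--use_weight_only --weight_only_precision int8` (and no
# smoothquant / int8 KV cache), args_to_quant_config() yields INT8 weight-only:
#
#   ns = argparse.Namespace(use_weight_only=True, weight_only_precision='int8',
#                           smoothquant=None, per_channel=False, per_token=False,
#                           int8_kv_cache=False, group_size=128)
#   args_to_quant_config(ns).quant_algo  # -> QuantAlgo.W8A16
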
def update_quant_config_from_hf(quant_config, hf_config,
|
||||
override_fields) -> tuple[QuantConfig, dict]:
|
||||
hf_config_dict = hf_config.to_dict()
|
||||
if hf_config_dict.get('quantization_config'):
|
||||
# update the quant_algo, and clamp_val.
|
||||
if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
|
||||
logger.info(
|
||||
"Load quantization configs from huggingface model_config.")
|
||||
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
||||
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
||||
'group_size', 128)
|
||||
quant_config.has_zero_point = hf_config_dict[
|
||||
'quantization_config'].get('zero_point', False)
|
||||
override_fields.update({"use_autoawq": True})
|
||||
elif hf_config_dict['quantization_config'].get(
|
||||
'quant_method') == 'gptq':
|
||||
logger.info(
|
||||
"Load quantization configs from huggingface model_config.")
|
||||
desc_act = hf_config_dict['quantization_config'].get(
|
||||
'desc_act', False)
|
||||
if desc_act:
|
||||
raise ValueError("GPTQ with desc_act=True is not implemented!")
|
||||
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
||||
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
||||
'group_size', 128)
|
||||
quant_config.has_zero_point = hf_config_dict[
|
||||
'quantization_config'].get('sym', False)
|
||||
return quant_config, override_fields
|
||||
|
||||
|
||||
def args_to_build_options(args):
|
||||
return {
|
||||
'use_parallel_embedding': args.use_parallel_embedding,
|
||||
'embedding_sharding_dim': args.embedding_sharding_dim,
|
||||
'disable_weight_only_quant_plugin':
|
||||
args.disable_weight_only_quant_plugin
|
||||
}
|
||||
|
||||
|
||||
def convert_and_save_hf(args):
|
||||
model_dir = args.model_dir
|
||||
world_size = args.tp_size * args.pp_size
|
||||
# Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
|
||||
# Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
|
||||
# before the refactor is done.
|
||||
override_fields = {}
|
||||
override_fields.update(args_to_build_options(args))
|
||||
quant_config = args_to_quant_config(args)
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_dir,
|
||||
trust_remote_code=True)
|
||||
quant_config, override_fields = update_quant_config_from_hf(
|
||||
quant_config, hf_config, override_fields)
|
||||
except BaseException:
|
||||
logger.warning("AutoConfig cannot load the huggingface config.")
|
||||
|
||||
if args.smoothquant is not None or args.int8_kv_cache:
|
||||
mapping = Mapping(world_size=world_size,
|
||||
tp_size=args.tp_size,
|
||||
pp_size=args.pp_size,
|
||||
moe_tp_size=args.moe_tp_size,
|
||||
moe_ep_size=args.moe_ep_size,
|
||||
cp_size=args.cp_size)
|
||||
QWenForCausalLM.quantize(args.model_dir,
|
||||
args.output_dir,
|
||||
dtype=args.dtype,
|
||||
mapping=mapping,
|
||||
quant_config=quant_config,
|
||||
calib_dataset=args.calib_dataset,
|
||||
**override_fields)
|
||||
else:
|
||||
|
||||
def convert_and_save_rank(args, rank):
|
||||
mapping = Mapping(world_size=world_size,
|
||||
rank=rank,
|
||||
tp_size=args.tp_size,
|
||||
pp_size=args.pp_size,
|
||||
moe_tp_size=args.moe_tp_size,
|
||||
moe_ep_size=args.moe_ep_size)
|
||||
qwen = QWenForCausalLM.from_hugging_face(model_dir,
|
||||
args.dtype,
|
||||
mapping=mapping,
|
||||
quant_config=quant_config,
|
||||
**override_fields)
|
||||
qwen.config.mapping.cp_size = args.cp_size
|
||||
qwen.config.mapping.attn_tp_size = -1
|
||||
qwen.config.mapping.attn_cp_size = -1
|
||||
qwen.config.mapping.world_size *= args.cp_size
|
||||
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
||||
del qwen
|
||||
|
||||
execute(args.workers, [convert_and_save_rank] * world_size, args)
|
||||
release_gc()
|
||||
|
||||
|
||||
def execute(workers, func, args):
|
||||
if workers == 1:
|
||||
for rank, f in enumerate(func):
|
||||
f(args, rank)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as p:
|
||||
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
|
||||
exceptions = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
exceptions.append(e)
|
||||
assert len(
|
||||
exceptions
|
||||
) == 0, "Checkpoint conversion failed, please check error log."
|
||||
|
||||
|
||||
def main():
|
||||
print(tensorrt_llm.__version__)
|
||||
args = parse_arguments()
|
||||
|
||||
if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
|
||||
# moe default to tp-only
|
||||
args.moe_tp_size = args.tp_size
|
||||
args.moe_ep_size = 1
|
||||
elif (args.moe_tp_size == -1):
|
||||
args.moe_tp_size = args.tp_size // args.moe_ep_size
|
||||
elif (args.moe_ep_size == -1):
|
||||
args.moe_ep_size = args.tp_size // args.moe_tp_size
|
||||
assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
|
||||
), "moe_tp_size * moe_ep_size must equal to tp_size"
|
||||
|
||||
tik = time.time()
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
assert args.model_dir is not None
|
||||
convert_and_save_hf(args)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
print(f'Total time of converting checkpoints: {t}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
runtime/triton_trtllm/scripts/fill_template.py (new file, 69 lines)
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
from argparse import ArgumentParser
|
||||
from string import Template
|
||||
|
||||
|
||||
def split(string, delimiter):
    """Split a string using delimiter. Supports escaping.

    Args:
        string (str): The string to split.
        delimiter (str): The delimiter to split the string with.

    Returns:
        list: A list of strings.
    """
    result = []
    current = ""
    escape = False
    for char in string:
        if escape:
            current += char
            escape = False
        elif char == delimiter:
            result.append(current)
            current = ""
        elif char == "\\":
            escape = True
        else:
            current += char
    result.append(current)
    return result

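# Doctest-style examples for split(); a backslash escapes the delimiter:
#
#   >>> split("model_dir:./CosyVoice2-0.5B,bls_instance_num:4", ",")
#   ['model_dir:./CosyVoice2-0.5B', 'bls_instance_num:4']
#   >>> split("a\\,b,c", ",")
#   ['a,b', 'c']
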
def main(file_path, substitutions, in_place):
|
||||
with open(file_path) as f:
|
||||
pbtxt = Template(f.read())
|
||||
|
||||
sub_dict = {
|
||||
"max_queue_size": 0,
|
||||
'max_queue_delay_microseconds': 0,
|
||||
}
|
||||
for sub in split(substitutions, ","):
|
||||
key, value = split(sub, ":")
|
||||
sub_dict[key] = value
|
||||
|
||||
assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
|
||||
|
||||
pbtxt = pbtxt.safe_substitute(sub_dict)
|
||||
|
||||
if in_place:
|
||||
with open(file_path, "w") as f:
|
||||
f.write(pbtxt)
|
||||
else:
|
||||
print(pbtxt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("file_path", help="path of the .pbtxt to modify")
|
||||
parser.add_argument(
|
||||
"substitutions",
|
||||
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
|
||||
)
|
||||
parser.add_argument("--in_place",
|
||||
"-i",
|
||||
action="store_true",
|
||||
help="do the operation in-place")
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
runtime/triton_trtllm/scripts/test_llm.py (new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from tensorrt_llm.runtime import ModelRunnerCpp
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--input_text',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=["Born in north-east France, Soyer trained as a"])
|
||||
parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
parser.add_argument('--log_level', type=str, default="debug")
|
||||
parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6)
|
||||
parser.add_argument('--temperature', type=float, default=0.8)
|
||||
parser.add_argument('--top_k', type=int, default=50)
|
||||
parser.add_argument('--top_p', type=float, default=0.95)
|
||||
|
||||
return parser.parse_args(args=args)
|
||||
|
||||
|
||||
def parse_input(tokenizer,
                input_text=None,
                prompt_template=None):
    batch_input_ids = []
    for curr_text in input_text:
        if prompt_template is not None:
            curr_text = prompt_template.format(input_text=curr_text)
        input_ids = tokenizer.encode(curr_text)
        batch_input_ids.append(input_ids)

    batch_input_ids = [
        torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
    ]

    logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):")
    for i, input_ids in enumerate(batch_input_ids):
        logger.debug(f"Request {i}: {input_ids.tolist()}")

    return batch_input_ids

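# Hedged sketch of what parse_input() produces with the prompt template used in
# main() below (the token ids shown are placeholders, not real vocab ids):
#
#   "你好,请问你叫什么?" -> "<|sos|>你好,请问你叫什么?<|task_id|>"
#   -> tokenizer.encode(...) -> [id_sos, id_1, ..., id_task]
# and each request ends up as a torch.int32 tensor in batch_input_ids.
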
def main(args):
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
|
||||
prompt_template = "<|sos|>{input_text}<|task_id|>"
|
||||
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>")
|
||||
|
||||
batch_input_ids = parse_input(tokenizer=tokenizer,
|
||||
input_text=args.input_text,
|
||||
prompt_template=prompt_template)
|
||||
|
||||
input_lengths = [x.size(0) for x in batch_input_ids]
|
||||
|
||||
runner_kwargs = dict(
|
||||
engine_dir=args.engine_dir,
|
||||
rank=runtime_rank,
|
||||
max_output_len=1024,
|
||||
enable_context_fmha_fp32_acc=False,
|
||||
max_batch_size=len(batch_input_ids),
|
||||
max_input_len=max(input_lengths),
|
||||
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
|
||||
cuda_graph_mode=False,
|
||||
gather_generation_logits=False,
|
||||
)
|
||||
|
||||
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = runner.generate(
|
||||
batch_input_ids=batch_input_ids,
|
||||
max_new_tokens=1024,
|
||||
end_id=end_id,
|
||||
pad_id=end_id,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
num_return_sequences=1,
|
||||
repetition_penalty=1.1,
|
||||
random_seed=42,
|
||||
streaming=False,
|
||||
output_sequence_lengths=True,
|
||||
output_generation_logits=False,
|
||||
return_dict=True,
|
||||
return_all_generated_tokens=False)
|
||||
torch.cuda.synchronize()
|
||||
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
|
||||
num_output_sents, num_beams, _ = output_ids.size()
|
||||
assert num_beams == 1
|
||||
beam = 0
|
||||
batch_size = len(input_lengths)
|
||||
num_return_sequences = num_output_sents // batch_size
|
||||
assert num_return_sequences == 1
|
||||
for i in range(batch_size * num_return_sequences):
|
||||
batch_idx = i // num_return_sequences
|
||||
seq_idx = i % num_return_sequences
|
||||
inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
|
||||
input_text = tokenizer.decode(inputs)
|
||||
print(f'Input [Text {batch_idx}]: \"{input_text}\"')
|
||||
output_begin = input_lengths[batch_idx]
|
||||
output_end = sequence_lengths[i][beam]
|
||||
outputs = output_ids[i][beam][output_begin:output_end].tolist()
|
||||
output_text = tokenizer.decode(outputs)
|
||||
print(f'Output [Text {batch_idx}]: \"{output_text}\"')
|
||||
logger.debug(str(outputs))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
main(args)
|
||||
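A hedged sketch of driving the script above programmatically; the module path, engine directory, tokenizer directory and input text are placeholders, not values taken from this diff.

# Assumed invocation sketch for scripts/test_llm.py; all paths are placeholders.
# Roughly equivalent to: python3 scripts/test_llm.py --engine_dir <trt_llm_engine_dir> ...
from scripts.test_llm import parse_arguments, main  # assumes the file is importable from runtime/triton_trtllm

args = parse_arguments([
    "--tokenizer_dir", "./cosyvoice2_llm",      # placeholder tokenizer directory
    "--engine_dir", "./trt_engines_bfloat16",   # placeholder TRT-LLM engine directory
    "--input_text", "Hello, this is a test.",
])
main(args)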
122 runtime/triton_trtllm/streaming_inference.py Normal file
@@ -0,0 +1,122 @@
import torch
import os
import argparse
from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
import torchaudio
import time
from token2wav_dit import CosyVoice2_Token2Wav
import soundfile as sf


def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    prompt_speech_tokens_list, prompt_text_list = [], []
    for item in batch:
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
        prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
        prompt_text_list.append(item['prompt_text'])

    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
    parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    dataset_name = args.dataset_name
    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)

    CHUNK_SIZE = 25
    token_frame_rate = 25
    OVERLAP_SIZE = 0

    warmup_times = 3
    for _ in range(warmup_times):
        start_time = time.time()
        total_forward_count = 0
        for batch in data_loader:
            tts_speech_list = []
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch

            id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]

            assert prompt_audio_sample_rate == 16000

            prompt_text = prompt_text_list[0]
            prompt_speech_tokens = prompt_speech_tokens_list[0]

            semantic_token_ids_arr, token_offset = [], 0
            flow_prompt_speech_token_len = len(prompt_speech_tokens)

            buffer = generated_speech_tokens
            output_wavs = []
            chunk_index = 0
            while True:
                if args.strategy == "equal":
                    this_chunk_size = CHUNK_SIZE
                elif args.strategy == "exponential":
                    this_chunk_size = token_frame_rate * (2 ** chunk_index)

                if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
                    wavs = token2wav_model.forward_streaming(
                        buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
                        False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
                        prompt_audio_sample_rate=prompt_audio_sample_rate
                    )
                    buffer = buffer[this_chunk_size - OVERLAP_SIZE:]

                    output_wavs.append(wavs)
                    total_forward_count += 1
                    chunk_index += 1

                else:
                    wavs = token2wav_model.forward_streaming(
                        buffer, True, request_id=id, speaker_id=f"{id}",
                        prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
                    )
                    output_wavs.append(wavs)
                    total_forward_count += 1
                    # chunk_index += 1
                    break

            for i, wav in enumerate(output_wavs):
                output_wavs[i] = wav.cpu().numpy().squeeze()

            audios = output_wavs
            reconstructed_audio = np.concatenate(audios)
            sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")

        end_time = time.time()

        if _ == 0:
            token2wav_model.speaker_cache = {}
            print(f"Warmup time: {end_time - start_time} seconds")
            print("clear speaker cache")
        elif _ == 1:
            print(f"Cost time without speaker cache: {end_time - start_time} seconds")
        else:
            print(f"Cost time with speaker cache: {end_time - start_time} seconds")
        print(f"Total flow matching forward calls: {total_forward_count}")
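To make the chunking policy above easier to follow, here is a small self-contained sketch of the chunk schedule it produces: the "equal" strategy emits fixed 25-token chunks, while "exponential" doubles the chunk size each step. The pre_lookahead default is a placeholder standing in for token2wav_model.flow.pre_lookahead_len.

# Standalone sketch of the chunk schedule used in the streaming loop above (no model required).
def chunk_sizes(num_tokens, strategy="equal", chunk_size=25, token_frame_rate=25, pre_lookahead=3):
    remaining, chunk_index, sizes = num_tokens, 0, []
    while True:
        this_chunk = chunk_size if strategy == "equal" else token_frame_rate * (2 ** chunk_index)
        if remaining >= this_chunk + pre_lookahead:
            sizes.append(this_chunk + pre_lookahead)  # non-final chunk includes the lookahead tokens
            remaining -= this_chunk                   # OVERLAP_SIZE is 0, as in the loop above
            chunk_index += 1
        else:
            sizes.append(remaining)                   # final chunk flushes whatever is left
            return sizes

# Example: chunk_sizes(80, "equal") -> [28, 28, 28, 5]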
335 runtime/triton_trtllm/token2wav.py Normal file
@@ -0,0 +1,335 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
    python3 token2wav.py --enable-trt || exit 1
"""
import torch
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
    import tensorrt as trt
    logging.info("Converting onnx to trt...")
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    # load onnx model
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError('failed to parse {}'.format(onnx_model))
    # set input shapes
    for i in range(len(trt_kwargs['input_names'])):
        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
    # set input and output data type
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.dtype = tensor_dtype
    for i in range(network.num_outputs):
        output_tensor = network.get_output(i)
        output_tensor.dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    # save trt engine
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted onnx to trt...")


class TrtContextWrapper:
    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        self.trt_engine = trt_engine
        self.device = device
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent {}'.format(trt_concurrent)
            self.trt_context_pool.put([trt_context, trt_stream])
        assert self.trt_context_pool.empty() is False, 'no available estimator context'

    def acquire_estimator(self):
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context, stream):
        self.trt_context_pool.put([context, stream])


class CosyVoice2_Token2Wav(torch.nn.Module):
    def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
        super().__init__()
        self.device_id = device_id
        self.device = f"cuda:{device_id}"

        self.flow = CausalMaskedDiffWithXvec()
        self.flow.half()
        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.to(self.device).eval()

        self.hift = HiFTGenerator()
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])

        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()

        gpu = "l20"
        if enable_trt:
            self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
                          f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
                          1,
                          True)
            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
                              f'{model_dir}/campplus.onnx',
                              1,
                              False)

    def forward_spk_embedding(self, spk_feat):
        if isinstance(self.spk_model, onnxruntime.InferenceSession):
            return self.spk_model.run(
                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()
        else:
            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
            # NOTE need to synchronize when switching stream
            with torch.cuda.device(self.device_id):
                torch.cuda.current_stream().synchronize()
                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
                batch_size = spk_feat.size(0)

                with stream:
                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)

                    data_ptrs = [spk_feat.contiguous().data_ptr(),
                                 output_tensor.contiguous().data_ptr()]
                    for i, j in enumerate(data_ptrs):
                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
                    # run trt engine
                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                    torch.cuda.current_stream().synchronize()
            self.spk_model.release_estimator(spk_model, stream)

            return output_tensor.cpu().numpy().flatten().tolist()

    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
            trt_kwargs = self.get_spk_trt_kwargs()
            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
        import tensorrt as trt
        with open(spk_model, 'rb') as f:
            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_spk_trt_kwargs(self):
        min_shape = [(1, 4, 80)]
        opt_shape = [(1, 500, 80)]
        max_shape = [(1, 3000, 80)]
        input_names = ["input"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
        opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
        max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
                     (max_batch_size * 2, 80)]
        input_names = ["x", "mask", "mu", "cond", "t", "spks"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
            prompt_speech_mels_list.append(log_mel)
        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
        )
        for i in range(len(prompt_speech_tokens)):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list

    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.forward_spk_embedding(spk_feat)

            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        return spk_emb_for_flow

    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
            assert len(audio.shape) == 1
            audio = audio.unsqueeze(0)
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
            mel_len = mel.shape[0]
            prompt_mels_for_flow.append(mel)
            prompt_mels_lens_for_flow.append(mel_len)
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow

    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
                     prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
        batch_size = prompt_mels_for_flow.shape[0]
        flow_inputs = []
        flow_inputs_lens = []
        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))

        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
        flow_inputs_lens = torch.tensor(flow_inputs_lens)

        with torch.amp.autocast(self.device, dtype=torch.float16):
            generated_mels, generated_mels_lens = self.flow(
                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
                streaming=False, finalize=True
            )

        return generated_mels, generated_mels_lens

    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
        batch_size = generated_mels.shape[0]
        generated_wavs = []
        for i in range(batch_size):
            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
            wav, _ = self.hift(speech_feat=mel)
            generated_wavs.append(wav)
        return generated_wavs

    @torch.inference_mode()
    def forward(
        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        # all prompt audios are expected to be sampled at 16 kHz
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)

        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)

        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)

        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)

        generated_mels, generated_mels_lens = self.forward_flow(
            prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)

        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)

        return generated_wavs


def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for _, item in enumerate(batch):
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])

    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
    parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
    # mkdir output_dir if not exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    dataset_name = "yuekai/seed_tts_cosy2"

    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)

    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    for _ in range(args.warmup):
        start_time = time.time()

        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch

            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)

            for id, wav in zip(ids, generated_wavs):
                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)

        end_time = time.time()
        epoch_time = end_time - start_time
        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
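A hedged sketch of calling the offline CosyVoice2_Token2Wav module above on a single prompt; the wav path and the token list are placeholders, and the prompt is assumed to already be 16 kHz mono, as forward() asserts.

# Assumed usage sketch; paths and tokens are placeholders, not values from this diff.
import torch
import torchaudio
from token2wav import CosyVoice2_Token2Wav  # assumes token2wav.py is importable

model = CosyVoice2_Token2Wav(model_dir="./CosyVoice2-0.5B", enable_trt=False)

wav, sr = torchaudio.load("prompt.wav")      # placeholder 16 kHz mono prompt recording
assert sr == 16000 and wav.shape[0] == 1
generated_tokens = [[100, 200, 300]]         # placeholder CosyVoice2 speech token ids

wavs = model(generated_tokens, [wav.squeeze(0)], [sr])
torchaudio.save("out.wav", wavs[0].cpu(), 24000)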
1 runtime/triton_trtllm/token2wav_dit.py Symbolic link
@@ -0,0 +1 @@
model_repo/token2wav_dit/1/token2wav_dit.py
@@ -29,27 +29,24 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
    for utt in tqdm(utt_list):
        data = open(utt2wav[utt], 'rb').read()
        data_list.append(data)
    wav_list = [utt2wav[utt] for utt in utt_list]
    text_list = [utt2text[utt] for utt in utt_list]
    spk_list = [utt2spk[utt] for utt in utt_list]
    uttembedding_list = [utt2embedding[utt] for utt in utt_list]
    spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
    speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
    if args.dpo:
        reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]

    # save to parquet, write utt2parquet_file and spk2parquet_file
    df = pd.DataFrame()
    df['utt'] = utt_list
    df['wav'] = wav_list
    df['audio_data'] = data_list
    df['text'] = text_list
    df['spk'] = spk_list
    df['utt_embedding'] = uttembedding_list
    df['spk_embedding'] = spkembedding_list
    df['speech_token'] = speech_token_list
    df['wav'] = [utt2wav[utt] for utt in utt_list]
    df['text'] = [utt2text[utt] for utt in utt_list]
    df['spk'] = [utt2spk[utt] for utt in utt_list]
    if utt2embedding is not None:
        df['utt_embedding'] = [utt2embedding[utt] for utt in utt_list]
    if spk2embedding is not None:
        df['spk_embedding'] = [spk2embedding[utt2spk[utt]] for utt in utt_list]
    if utt2speech_token is not None:
        df['speech_token'] = [utt2speech_token[utt] for utt in utt_list]
    if utt2instruct is not None:
        df['instruct'] = [utt2instruct[utt] for utt in utt_list]
    if args.dpo:
        df['reject_speech_token'] = reject_speech_token_list
        df['reject_speech_token'] = [utt2reject_speech_token.get(utt, None) for utt in utt_list]
    df.to_parquet(parquet_file)
    with open(utt2parquet_file, 'w') as f:
        json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
@@ -91,11 +88,19 @@ if __name__ == "__main__":
        for l in f:
            l = l.replace('\n', '').split()
            utt2spk[l[0]] = l[1]
    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
    if os.path.exists('{}/instruct'.format(args.src_dir)):
        utt2instruct = {}
        with open('{}/instruct'.format(args.src_dir)) as f:
            for l in f:
                l = l.replace('\n', '').split()
                utt2instruct[l[0]] = ' '.join(l[1:])
    else:
        utt2instruct = None
    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/utt2embedding.pt'.format(args.src_dir)) else None
    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/spk2embedding.pt'.format(args.src_dir)) else None
    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}/utt2speech_token.pt'.format(args.src_dir)) else None
    if args.dpo:
        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}_reject/utt2speech_token.pt'.format(args.src_dir)) else {}
    utts = list(utt2wav.keys())

    # Using process pool to speedup
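The hunks above make the embedding, speech-token and instruct columns optional in the generated parquet files; below is a hedged sketch of reading such a file back, where the path is a placeholder and the column names follow the assignments above.

# Hedged sketch: inspecting a parquet produced by the updated job() above.
import pandas as pd

df = pd.read_parquet("parquet/000000.parquet")   # placeholder path
print(df.columns.tolist())                        # e.g. ['utt', 'wav', 'audio_data', 'text', 'spk', ...]
has_instruct = 'instruct' in df.columns           # only present when an instruct file existed in src_dir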
@@ -4,20 +4,36 @@ from vllm import ModelRegistry
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)

from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.common import set_all_random_seed
from tqdm import tqdm


def main():
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
    prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
def cosyvoice2_example():
    """ CosyVoice2 vllm usage
    """
    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
    for i in tqdm(range(100)):
        set_all_random_seed(i)
        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
            continue


def cosyvoice3_example():
    """ CosyVoice3 vllm usage
    """
    cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B', load_trt=True, load_vllm=True, fp16=False)
    for i in tqdm(range(100)):
        set_all_random_seed(i)
        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
                                                            './asset/zero_shot_prompt.wav', stream=False)):
            continue


def main():
    # cosyvoice2_example()
    cosyvoice3_example()


if __name__ == '__main__':
    main()
Some files were not shown because too many files have changed in this diff.