142 Commits
v2.0 ... main

Author SHA1 Message Date
Xiang Lyu
f08872a82f Merge pull request #1814 from hexisyztem/main
[BUG FIX] 使用 float64 避免精度误差问题,弃用 CPU 计算,避免拖累性能
2026-02-04 13:10:40 +08:00
lyuxiang.lx
f2ddcbe7f9 fix ras 2026-02-04 11:44:39 +08:00
禾息
0d990d6074 [BUG FIX] 使用 float64 避免精度误差问题,弃用 CPU 计算,避免拖累性能 2026-01-30 18:10:36 +08:00
lyuxiang.lx
c93d3dda01 update 2026-01-29 17:31:04 +00:00
lyuxiang.lx
84e41729ea update 2026-01-29 10:29:22 +00:00
lyuxiang.lx
f26cde56df update 2026-01-29 06:13:36 +00:00
lyuxiang.lx
66b80dbccb online feature 2026-01-28 15:19:07 +00:00
lyuxiang.lx
1822c5c908 fix fm train bug 2026-01-19 10:48:24 +08:00
lyuxiang.lx
1dcc59676f add japanese 2026-01-12 07:06:45 +00:00
lyuxiang.lx
7fdd80dc64 fix padding 2026-01-07 07:12:52 +00:00
lyuxiang.lx
f97d50d559 fix sequence logic 2026-01-07 07:06:04 +00:00
lyuxiang.lx
652132ebaa Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2026-01-05 11:02:01 +08:00
lyuxiang.lx
1ceb2b7b1e fix bug 2026-01-05 11:01:19 +08:00
Xiang Lyu
55e3e370a0 Merge pull request #1758 from orbisai0security/fix/V-005-pickle-deserialization
[Security] Fix CRITICAL vulnerability: V-005
2025-12-31 11:45:44 +08:00
lyuxiang.lx
46788c7379 update readme 2025-12-31 11:28:47 +08:00
Xiang Lyu
881177287c Merge pull request #1640 from Jzz1943/main
support vLLM >=0.11.0 (V1 engine) for better performance
2025-12-31 10:40:02 +08:00
Xiang Lyu
f88a14e41d Merge branch 'main' into main 2025-12-31 10:37:18 +08:00
lyuxiang.lx
dac6566fc3 update tokens 2025-12-30 16:14:36 +00:00
lyuxiang.lx
cc91e40db8 more silent_token 2025-12-30 15:53:33 +00:00
lyuxiang.lx
ab7f1f4a86 update silent token 2025-12-30 15:12:18 +00:00
lyuxiang.lx
e15222b17c refine code 2025-12-30 09:24:52 +00:00
lyuxiang.lx
cfa1c115b2 add silent_token 2025-12-30 09:18:17 +00:00
lyuxiang.lx
dd5cdb6ebf Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-30 14:03:52 +08:00
lyuxiang.lx
2d7ef0b719 remove instruct warning 2025-12-30 12:36:18 +08:00
Xiang Lyu
ba5db602a9 Merge pull request #1622 from GoyoUijin/Fix/token2wav-cache-thread-unsafe
fix triton token2wav model cache thread unsafety
2025-12-30 11:24:25 +08:00
orbisai0security
5b94675f62 fix: resolve critical vulnerability V-005
Automatically generated security fix
2025-12-29 13:25:05 +00:00
lyuxiang.lx
4c19646b9a update dataset 2025-12-29 12:46:34 +00:00
lyuxiang.lx
63a06227d1 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-29 18:32:54 +08:00
lyuxiang.lx
3b44913782 fix bug 2025-12-29 10:30:54 +00:00
Xiang Lyu
055f64d002 Merge pull request #1728 from yigitcatak/main
fixed a typo in Dockerfile regarding environment variable
2025-12-29 18:24:39 +08:00
lyuxiang.lx
4d7295a9a7 fix cosyvoice3 training 2025-12-29 10:03:14 +00:00
lyuxiang.lx
8524c81acd fix cv3 train 2025-12-29 10:03:14 +00:00
Xiang Lyu
a14e063ead Merge pull request #1722 from majiayu000/fix/issue-1683-ja-language-tag
docs: fix Japanese language tag in example.py comment
2025-12-29 16:12:54 +08:00
lyuxiang.lx
2db78e7058 fix export_jit.py 2025-12-23 17:23:23 +08:00
lyuxiang.lx
7538c6a73d update export_jit 2025-12-23 15:23:29 +08:00
yigitcatak
823ae2c60d fix a typo in Dockerfile 2025-12-23 16:22:48 +09:00
lyuxiang.lx
59cb2bf16c fix jit_export 2025-12-23 14:06:40 +08:00
majiayu000
80bebb1978 docs: fix Japanese language tag in example.py comment
Changed <|jp|> to <|ja|> to match the actual tokenizer implementation.

The LANGUAGES dict in cosyvoice/tokenizer/tokenizer.py defines 'ja' for Japanese,
not 'jp'. This fixes the misleading comment that could cause issues like #621.

Fixes #1683
2025-12-22 19:11:47 +08:00
Xiang Lyu
bc34459bb8 Merge pull request #1693 from whiteshirt0429/main
Fix CosyVoice3 config error
2025-12-22 13:56:19 +08:00
lyuxiang.lx
9f27b42cd9 update 2025-12-17 18:58:50 +08:00
lyuxiang.lx
a7d6e2251a update libritts cosyvoice3.yaml 2025-12-17 17:15:10 +08:00
lyuxiang.lx
7baefaf0f2 update libritts cosyvoice3.yaml 2025-12-17 17:14:17 +08:00
di.wu
ff0d05c380 Fix CosyVoice3 config error 2025-12-17 14:57:17 +08:00
lyuxiang.lx
f5816b4e51 update readme 2025-12-17 03:16:00 +00:00
lyuxiang.lx
8b54619760 update 2025-12-16 15:00:03 +08:00
lyuxiang.lx
2abd42220e add x-transformer version 2025-12-16 14:45:02 +08:00
lyuxiang.lx
2d6bb9bd80 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-15 16:16:35 +08:00
lyuxiang.lx
0b80c0746a update 2025-12-15 16:13:53 +08:00
Xiangang Li
e98b828f33 Update README.md 2025-12-15 15:54:32 +08:00
lyuxiang.lx
4d4c787be0 update 2025-12-15 15:33:52 +08:00
lyuxiang.lx
781a49acb4 update metric 2025-12-15 14:52:32 +08:00
lyuxiang.lx
9476a063b3 update metric 2025-12-15 14:48:17 +08:00
Xiang Lyu
3426ceb70f Merge pull request #1671 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-12-15 14:00:40 +08:00
hengwu.zty
a460960ade [mod]readme 2025-12-15 13:53:11 +08:00
hengwu.zty
f51f5c5c6a [mod]readme 2025-12-15 13:39:56 +08:00
hengwu.zty
f11ba4024c [mod]readme 2025-12-15 13:24:44 +08:00
lyuxiang.lx
089343ab0a update 2025-12-15 12:44:06 +08:00
Xiang Lyu
0c50894d49 Merge pull request #1670 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-12-15 12:17:39 +08:00
lyuxiang.lx
95d56cba64 update 2025-12-15 03:56:42 +00:00
lyuxiang.lx
095f7bad55 update readme 2025-12-15 03:53:10 +00:00
lyuxiang.lx
a6eb2c56da update dingding 2025-12-14 14:00:40 +00:00
lyuxiang.lx
ca3b054a52 fix bistream bug 2025-12-12 14:28:27 +00:00
lyuxiang.lx
b02d7e61f7 update prompt wav 2025-12-12 14:28:26 +00:00
hengwu.zty
6b6a5a7bd1 Merge branch 'dev/lyuxiang.lx' of http://gitlab.alibaba-inc.com/NLS/CosyVoice into dev/lyuxiang.lx 2025-12-12 18:39:18 +08:00
hengwu.zty
5640545406 update readme 2025-12-12 18:38:30 +08:00
lyuxiang.lx
5bc4b23f02 use amp in flow 2025-12-12 07:44:36 +00:00
lyuxiang.lx
ebef63066f add instruct 2025-12-12 07:44:36 +00:00
quantu
3298d6f3e3 readme 2025-12-11 15:55:26 +08:00
lyuxiang.lx
f21c4764ec fix bug 2025-12-11 06:04:36 +00:00
lyuxiang.lx
927addadd8 fix lint 2025-12-10 02:17:00 +00:00
lyuxiang.lx
a051a09ba4 remove unncessary code 2025-12-09 15:41:02 +00:00
lyuxiang.lx
0c65d3c7ab use automodel 2025-12-09 15:15:05 +00:00
lyuxiang.lx
56d9876037 fix bug 2025-12-09 07:57:10 +00:00
lyuxiang.lx
b35ece675b dit need higher trt version 2025-12-08 11:28:45 +00:00
lyuxiang.lx
59f02cb85d add vllm example 2025-12-08 11:14:12 +00:00
lyuxiang.lx
b4dd67a8af add cosyvoice3 vllm example 2025-12-08 10:55:53 +00:00
lyuxiang.lx
bfa835a74b add cosyvoice3 inference code 2025-12-08 10:04:11 +00:00
lyuxiang.lx
622a3a19b0 use wav file rather than tensor 2025-12-08 08:43:09 +00:00
lyuxiang.lx
d985100326 Merge branch 'main' into dev/lyuxiang.lx 2025-12-04 18:00:17 +08:00
zhongze.jiang
6816fc6a6f support vLLM >=0.11.0 (V1 engine only) 2025-11-10 16:30:42 +08:00
김의진
e8bf717333 Fix: generate token2wav_request_id from cosyvoice2
- Since all token2wav requests within a single cosyvoice2 request must share the same request_id, modify the logic so that a new request_id is generated only if it does not already exist, and ensure that the same request_id is sent consistently.
2025-10-27 18:12:17 +09:00
김의진
fa2781405f Revert "fix triton token2wav model cache thread unsafety"
This reverts commit cd26dd1932.
2025-10-27 18:07:30 +09:00
김의진
cd26dd1932 fix triton token2wav model cache thread unsafety 2025-10-27 17:20:14 +09:00
Xiang Lyu
6e01309e01 Merge pull request #1598 from yuekaizhang/streaming
[Runtime] StepAudio2 Streaming DiT Token2Wav Integration
2025-10-21 15:30:45 +08:00
yuekaiz
1fc8435146 add disaggregated deployment 2025-10-16 15:58:22 +08:00
yuekaiz
a224be6117 fix lint 2025-10-09 15:18:09 +08:00
yuekaiz
33aee03ed5 fix lint 2025-10-09 15:13:43 +08:00
root
8811e9f33a fix white space 2025-10-09 14:49:22 +08:00
root
807bb6ee0b add dit results 2025-10-08 22:01:24 +08:00
root
aceede59ba fix bug 2025-10-08 18:13:09 +08:00
root
7cbd490253 add docker compose for streaming tts 2025-10-08 17:20:04 +08:00
root
a019a2504e clean code 2025-10-08 16:48:00 +08:00
root
f186ec3338 clean code 2025-10-08 15:21:52 +08:00
root
988d395162 mark multi client 2025-10-08 14:06:19 +08:00
lyuxiang.lx
4d60ff6abc Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-09-28 15:27:41 +08:00
lyuxiang.lx
be005c825f fix vllm transformer version bug 2025-09-28 15:25:42 +08:00
root
79116ac32e remove cache router 2025-09-26 15:14:31 +08:00
root
31a0adc73d mark stateless token2wav 2025-09-26 14:51:41 +08:00
yuekaiz
482464ea27 add streaming dit 2025-09-24 15:18:01 +08:00
root
444b7ff5df fix cache shallow copy 2025-09-19 13:48:32 +08:00
yuekaiz
b207c60885 init step-audio2 token2wav 2025-09-18 19:07:23 +08:00
Xiang Lyu
0b357ba25d Merge pull request #1583 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-09-17 10:57:18 +08:00
Xiang Lyu
0867ebcb8c Merge pull request #1566 from yuekaizhang/streaming
[runtime: TRT-LLM] support prompt audio cache & offline inference mode
2025-09-12 11:52:28 +08:00
root
52556a6de9 fix lint 2025-09-08 09:59:58 +00:00
root
66ef5a097b fix lint 2025-09-08 09:55:33 +00:00
root
cc1991870b add cosyvoice2 offline inference 2025-09-08 17:37:33 +08:00
yuekaiz
8ded65e611 set use_spk2info_cache=False 2025-09-05 14:03:23 +08:00
yuekaiz
6971536358 add prompt audio cache 2025-09-05 13:55:41 +08:00
Xiang Lyu
86e7c2d731 Merge pull request #1561 from yuekaizhang/streaming
[Runtime] Support Streaming TTS for Triton + TensorRT-LLM runtime
2025-09-04 09:48:43 +08:00
Yuekai Zhang
8a4309d89c fix space 2025-09-03 03:51:38 -07:00
Yuekai Zhang
ad257b06e3 fix lint 2025-09-03 03:45:17 -07:00
yuekaiz
633b991290 update readme 2025-09-03 17:42:14 +08:00
yuekaiz
e04699c6da add spk trt 2025-09-03 11:44:36 +08:00
root
73d261dd48 support streaming tts 2025-09-02 18:32:12 +08:00
Xiang Lyu
b7ec6c4678 Merge pull request #1459 from huiwq1990/feat-fixpip
pip install disable cache
2025-09-01 17:34:19 +08:00
lyuxiang.lx
f76f5abcc1 update 2025-08-22 14:42:34 +08:00
lyuxiang.lx
6b5eef62cc update 2025-08-22 14:24:27 +08:00
lyuxiang.lx
dc96e4c984 update 2025-08-21 21:03:58 +08:00
lyuxiang.lx
70991d7327 update 2025-08-21 20:08:08 +08:00
lyuxiang.lx
8c96081f94 update 2025-08-21 11:45:36 +08:00
lyuxiang.lx
dd2d926147 update 2025-08-20 16:55:03 +08:00
lyuxiang.lx
da41f6175b update 2025-08-19 18:53:18 +08:00
lyuxiang.lx
e3c2400abb update 2025-08-19 16:23:59 +08:00
lyuxiang.lx
a976519ada fix transformer version 2025-08-14 16:24:24 +08:00
lyuxiang.lx
cf615011ce update readme 2025-08-07 16:55:46 +08:00
Xiang Lyu
9ddb9e4a83 Merge pull request #1491 from yuekaizhang/rl
[recipe] Add GRPO training recipe for cosyvoice2 llm
2025-08-07 16:45:49 +08:00
Xiang Lyu
0a496c18f7 Merge pull request #1507 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-08-05 11:18:46 +08:00
lyuxiang.lx
05bdf4c769 add contributor info 2025-08-05 11:15:42 +08:00
Xiang Lyu
1850e2a56e Merge pull request #1489 from yuekaizhang/triton
[runtime] Support Cosyvoice2 Nvidia TensorRT-LLM Inference Solution
2025-08-05 10:21:40 +08:00
root
47e4137651 update test set 2025-07-30 11:07:50 +00:00
root
0bc48c1180 update readme 2025-07-30 11:05:49 +00:00
root
62d082634e fix lint 2025-07-29 08:40:51 +00:00
root
07cbc51cd1 fix lint 2025-07-29 08:39:41 +00:00
root
d1c354eac7 add huggingface to pretrained 2025-07-29 07:54:42 +00:00
yuekaiz
1b8d194b67 fix commit 2025-07-29 12:01:55 +08:00
yuekaiz
b44f121102 update readme 2025-07-29 11:58:23 +08:00
yuekaiz
dc196df940 fix decoupled mode 2025-07-29 11:13:07 +08:00
Yuekai Zhang
178da09993 clean code 2025-07-27 23:33:10 -07:00
lyuxiang.lx
11515d0d5a use bf16 for amp 2025-07-28 11:55:38 +08:00
Yuekai Zhang
5427c274e3 add triton solution 2025-07-22 06:50:13 -07:00
huiwq1990
3387f07266 pip install disable cache
Signed-off-by: huiwq1990 <huiwq1990@163.com>
2025-07-17 10:35:44 +08:00
lyuxiang.lx
b048a2d6db fix dpo bug 2025-07-16 20:09:10 +08:00
101 changed files with 11660 additions and 1067 deletions

View File

@@ -52,5 +52,5 @@ jobs:
set -eux
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
flake8 --version
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
if [ $? != 0 ]; then exit 1; fi

198
README.md
View File

@@ -1,50 +1,52 @@
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)
## 👉🏻 CosyVoice 👈🏻
**CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)
## Highlight🔥
**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
### Multilingual
- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
- **Crosslingual & Mixlingual**Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
### Ultra-Low Latency
- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
### High Accuracy
- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
### Strong Stability
- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
### Natural Experience
- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
### Key Features
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian), 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.) and meanwhile supports both multi-lingual/cross-lingual zero-shot voice cloning.
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
- **Bi-Streaming**: Support both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.
## Roadmap
- [x] 2025/12
- [x] release Fun-CosyVoice3-0.5B-2512 base model, rl model and its training/inference script
- [x] release Fun-CosyVoice3-0.5B modelscope gradio space
- [x] 2025/08
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
- [x] 2025/07
- [x] release cosyvoice 3.0 eval set
- [x] release Fun-CosyVoice 3.0 eval set
- [x] 2025/05
- [x] add cosyvoice 2.0 vllm support
- [x] add CosyVoice2-0.5B vllm support
- [x] 2024/12
- [x] 25hz cosyvoice 2.0 released
- [x] 25hz CosyVoice2-0.5B released
- [x] 2024/09
- [x] 25hz cosyvoice base model
- [x] 25hz cosyvoice voice conversion model
- [x] 25hz CosyVoice-300M base model
- [x] 25hz CosyVoice-300M voice conversion function
- [x] 2024/08
@@ -57,6 +59,27 @@
- [x] WeTextProcessing support when ttsfrd is not available
- [x] Fastapi server and client
## Evaluation
| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |
## Install
@@ -87,26 +110,26 @@
### Model download
We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
``` python
# SDK模型下载
# modelscope SDK model download
from modelscope import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```
``` sh
# git模型下载请确保已安装git lfs
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
# for oversea users, huggingface SDK model download
from huggingface_hub import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```
Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
@@ -122,94 +145,28 @@ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
### Basic Usage
We strongly recommend using `CosyVoice2-0.5B` for better performance.
Follow the code below for detailed usage of each model.
``` python
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio
We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
Follow the code in `example.py` for detailed usage of each model.
```sh
python example.py
```
#### CosyVoice2 Usage
```python
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=False, fp16=False)
#### vLLM Usage
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
Older vllm version(<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# save zero_shot spk for future usage
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice.save_spkinfo()
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# bistream usage, you can use generator as input, this is useful when using text llm model as input
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
def text_generator():
yield '收到好友从远方寄来的生日礼物,'
yield '那份意外的惊喜与深深的祝福'
yield '让我心中充满了甜蜜的快乐,'
yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```
#### CosyVoice2 vllm Usage
If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vllm version do not support CosyVoice2 inference.
Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torch==2.7.0`. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
Notice that `vllm` has a lot of specific requirements. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
``` sh
conda create -n cosyvoice_vllm --clone cosyvoice
conda activate cosyvoice_vllm
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm==0.9.0
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm>=0.11.0
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
python vllm_example.py
```
#### CosyVoice Usage
```python
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
# sft usage
print(cosyvoice.list_available_spks())
# change stream=True for chunk stream inference
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# cross_lingual usage
prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# vc usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```
#### Start web demo
You can use our web demo page to get familiar with CosyVoice quickly.
@@ -223,7 +180,7 @@ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
#### Advanced Usage
For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
For advanced users, we have provided training and inference scripts in `examples/libritts`.
#### Build for deployment
@@ -242,6 +199,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```
#### Using Nvidia TensorRT-LLM for deployment
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
To quick start:
``` sh
cd runtime/triton_trtllm
docker compose up -d
```
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
## Discussion & Communication
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

View File

@@ -23,7 +23,7 @@ import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
@@ -57,15 +57,9 @@ def main():
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
if not isinstance(model, CosyVoice2):
if model.__class__.__name__ == 'CosyVoice':
# 1. export llm text_encoder
llm_text_encoder = model.model.llm.text_encoder
script = get_optimized_script(llm_text_encoder)
@@ -89,14 +83,16 @@ def main():
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
else:
# 3. export flow encoder
elif model.__class__.__name__ == 'CosyVoice2':
# 1. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
else:
raise ValueError('unsupported model type')
if __name__ == '__main__':

View File

@@ -27,7 +27,7 @@ from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
@@ -58,13 +58,7 @@ def main():
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
# 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator

View File

@@ -1,126 +0,0 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.dataset.dataset import Dataset
def get_args():
parser = argparse.ArgumentParser(description='inference with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot'],
help='inference mode')
parser.add_argument('--result_dir', required=True, help='asr result file')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
except Exception:
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
except Exception:
raise TypeError('no valid model_type!')
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
sample_rate = configs['sample_rate']
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
tts_index = batch["tts_index"]
tts_text_token = batch["tts_text_token"].to(device)
tts_text_token_len = batch["tts_text_token_len"].to(device)
speech_token = batch["speech_token"].to(device)
speech_token_len = batch["speech_token_len"].to(device)
speech_feat = batch["speech_feat"].to(device)
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
tts_speeches = []
for model_output in model.tts(**model_input):
tts_speeches.append(model_output['tts_speech'])
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
main()

View File

@@ -49,6 +49,7 @@ def get_args():
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
@@ -96,6 +97,7 @@ def get_args():
@record
def main():
args = get_args()
os.environ['onnx_path'] = args.onnx_path
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
# gan train has some special initialization logic
@@ -104,10 +106,8 @@ def main():
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
if gan is True:
override_dict.pop('hift')
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
except Exception:
if args.qwen_pretrain_path is not None:
override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if gan is True:

View File

@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
@@ -27,7 +27,6 @@ from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
@@ -37,7 +36,7 @@ class CosyVoice:
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f)
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
@@ -67,9 +66,9 @@ class CosyVoice:
spks = list(self.frontend.spk2info.keys())
return spks
def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
del model_input['text']
del model_input['text_len']
self.frontend.spk2info[zero_shot_spk_id] = model_input
@@ -89,12 +88,14 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -103,9 +104,9 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -115,9 +116,7 @@ class CosyVoice:
start_time = time.time()
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
if self.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
@@ -129,8 +128,8 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
start_time = time.time()
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
@@ -142,7 +141,6 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
@@ -160,9 +158,9 @@ class CosyVoice2(CosyVoice):
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
@@ -178,13 +176,9 @@ class CosyVoice2(CosyVoice):
self.fp16)
del configs
def inference_instruct(self, *args, **kwargs):
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -192,3 +186,55 @@ class CosyVoice2(CosyVoice):
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
class CosyVoice3(CosyVoice2):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v3.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
load_trt, fp16 = False, False
logging.warning('no cuda device, set load_trt/fp16 to False')
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir))
if load_trt:
if self.fp16 is True:
logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def AutoModel(**kwargs):
if not os.path.exists(kwargs['model_dir']):
kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
return CosyVoice(**kwargs)
elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
return CosyVoice2(**kwargs)
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
return CosyVoice3(**kwargs)
else:
raise TypeError('No valid model type found!')

View File

@@ -20,19 +20,10 @@ import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
import ttsfrd
use_ttsfrd = True
except ImportError:
print("failed to import ttsfrd, use wetext instead")
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
use_ttsfrd = False
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -56,21 +47,33 @@ class CosyVoiceFrontEnd:
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"])
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device)
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else:
self.spk2info = {}
self.allowed_special = allowed_special
self.use_ttsfrd = use_ttsfrd
if self.use_ttsfrd:
self.inflect_parser = inflect.engine()
# NOTE compatible when no text frontend tool is avaliable
try:
import ttsfrd
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg')
else:
self.text_frontend = 'ttsfrd'
logging.info('use ttsfrd frontend')
except:
try:
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
self.text_frontend = 'wetext'
logging.info('use wetext frontend')
except:
self.text_frontend = ''
logging.info('no frontend is avaliable')
def _extract_text_token(self, text):
if isinstance(text, Generator):
@@ -89,7 +92,8 @@ class CosyVoiceFrontEnd:
for i in range(text_token.shape[1]):
yield text_token[:, i: i + 1]
def _extract_speech_token(self, speech):
def _extract_speech_token(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None,
@@ -101,7 +105,8 @@ class CosyVoiceFrontEnd:
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
def _extract_spk_embedding(self, speech):
def _extract_spk_embedding(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
@@ -112,7 +117,8 @@ class CosyVoiceFrontEnd:
embedding = torch.tensor([embedding]).to(self.device)
return embedding
def _extract_speech_feat(self, speech):
def _extract_speech_feat(self, prompt_wav):
speech = load_wav(prompt_wav, 24000)
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
@@ -122,14 +128,18 @@ class CosyVoiceFrontEnd:
if isinstance(text, Generator):
logging.info('get tts_text generator, will skip text_normalize!')
return [text]
# NOTE skip text_frontend when ssml symbol in text
if '<|' in text and '|>' in text:
text_frontend = False
if text_frontend is False or text == '':
return [text] if split is True else text
text = text.strip()
if self.use_ttsfrd:
if self.text_frontend == 'ttsfrd':
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
text = ''.join(texts)
else:
if contains_chinese(text):
if self.text_frontend == 'wetext':
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "")
text = replace_blank(text)
@@ -141,6 +151,7 @@ class CosyVoiceFrontEnd:
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
else:
if self.text_frontend == 'wetext':
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
@@ -154,32 +165,31 @@ class CosyVoiceFrontEnd:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
if zero_shot_spk_id == '':
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
if resample_rate == 24000:
# cosyvoice2, force speech_feat % speech_token = 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
embedding = self._extract_spk_embedding(prompt_speech_16k)
embedding = self._extract_spk_embedding(prompt_wav)
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
else:
model_input = self.spk2info[zero_shot_spk_id]
model_input = {**self.spk2info[zero_shot_spk_id]}
model_input['text'] = tts_text_token
model_input['text_len'] = tts_text_token_len
return model_input
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
@@ -191,22 +201,21 @@ class CosyVoiceFrontEnd:
model_input = self.frontend_sft(tts_text, spk_id)
# in instruct mode, we remove spk_embedding in llm due to information leakage
del model_input['llm_embedding']
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
model_input['prompt_text'] = instruct_text_token
model_input['prompt_text_len'] = instruct_text_token_len
return model_input
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
embedding = self._extract_spk_embedding(prompt_speech_16k)
def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
embedding = self._extract_spk_embedding(prompt_wav)
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,

View File

@@ -38,9 +38,6 @@ class CosyVoiceModel:
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 20
@@ -63,14 +60,15 @@ class CosyVoiceModel:
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
self.llm.to(self.device).eval()
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
@@ -101,25 +99,32 @@ class CosyVoiceModel:
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
cur_silent_token_num, max_silent_token_num = 0, 5
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
if isinstance(text, Generator):
assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
for i in self.llm.inference_bistream(text=text,
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
token_generator = self.llm.inference_bistream(text=text,
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
self.tts_speech_token_dict[uuid].append(i)
embedding=llm_embedding.to(self.device))
else:
for i in self.llm.inference(text=text.to(self.device),
token_generator = self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device),
uuid=uuid):
uuid=uuid)
for i in token_generator:
if i in self.silent_tokens:
cur_silent_token_num += 1
if cur_silent_token_num > max_silent_token_num:
continue
else:
cur_silent_token_num = 0
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
@@ -129,7 +134,7 @@ class CosyVoiceModel:
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
@@ -249,11 +254,12 @@ class CosyVoice2Model(CosyVoiceModel):
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# NOTE increase token_hop_len incrementally to avoid duplicate inference
self.token_max_hop_len = 4 * self.token_hop_len
self.stream_scale_factor = 2
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
@@ -266,6 +272,7 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -284,7 +291,7 @@ class CosyVoice2Model(CosyVoiceModel):
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
@@ -350,6 +357,7 @@ class CosyVoice2Model(CosyVoiceModel):
stream=stream,
finalize=False)
token_offset += this_token_hop_len
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
break
@@ -384,3 +392,59 @@ class CosyVoice2Model(CosyVoiceModel):
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
class CosyVoice3Model(CosyVoice2Model):
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# NOTE increase token_hop_len incrementally to avoid duplicate inference
self.token_max_hop_len = 4 * self.token_hop_len
self.stream_scale_factor = 2
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
# FSQ silent and breath token
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append mel cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
self.hift_cache_dict[uuid]['mel'] = tts_mel
else:
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
if speed != 1.0:
assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
return tts_speech

View File

@@ -145,7 +145,11 @@ def Dataset(data_list_file,
shuffle=shuffle,
partition=partition)
# map partial arg to padding func
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
for i in range(1, len(data_pipeline)):
if data_pipeline[i].func.__name__ == 'compute_fbank' and gan is True:
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
if data_pipeline[i].func.__name__ == 'padding':
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset

View File

@@ -16,17 +16,19 @@ import random
import pyarrow.parquet as pq
from io import BytesIO
import numpy as np
import whisper
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
from cosyvoice.utils.onnx import embedding_extractor, online_feature
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def parquet_opener(data, mode='train', tts_data={}):
def parquet_opener(data, mode='train'):
""" Give url or local file, return file descriptor
Inplace operation.
@@ -44,12 +46,8 @@ def parquet_opener(data, mode='train', tts_data={}):
df = df.to_pandas()
for i in range(len(df)):
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
@@ -96,9 +94,9 @@ def filter(data,
continue
if len(sample['text_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
if online_feature is False and len(sample['speech_token']) == 0:
continue
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
@@ -159,7 +157,7 @@ def truncate(data, truncate_length=24576, mode='train'):
def compute_fbank(data,
feat_extractor,
token_mel_ratio=0,
num_frames=-1,
mode='train'):
""" Extract fbank
@@ -174,14 +172,28 @@ def compute_fbank(data,
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
if token_mel_ratio != 0:
# trim to align speech_token and speech_feat
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
feat = feat[:token_mel_ratio * token_len]
sample["speech_token"] = sample["speech_token"][:token_len]
sample['speech_feat'] = feat
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
if num_frames != -1:
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
yield sample
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
""" Extract whisper fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
if num_frames != -1:
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
yield sample
@@ -220,6 +232,11 @@ def parse_embedding(data, normalize, mode='train'):
Iterable[{key, feat, label}]
"""
for sample in data:
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
embedding = embedding_extractor.inference(sample['speech_16k'])
sample['spk_embedding'] = sample['utt_embedding'] = embedding
else:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize:
@@ -242,6 +259,8 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
yield sample
@@ -256,13 +275,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
Iterable[{key, feat, label}]
"""
buf = []
yield_size = int(shuffle_size / 2)
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
for x in buf[:yield_size]:
yield x
buf = []
buf = buf[yield_size:]
# The sample left over
random.shuffle(buf)
for x in buf:
@@ -368,65 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
speech = pad_sequence(speech, batch_first=True, padding_value=0)
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
batch = {}
batch['utts'] = [sample[i]['utt'] for i in order]
batch['text'] = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"utts": utts,
"speech": speech,
"speech_len": speech_len,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
whisper_feat = [sample[i]['whisper_feat'] for i in order]
batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
if torch.tensor(['speech_token' in sample[i] for i in order]).all():
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
if gan is True:
# in gan train, we need pitch_feat
# in gan train, we need speech/pitch_feat
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
pitch_feat = [sample[i]['pitch_feat'] for i in order]
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
pitch_feat = pad_sequence(pitch_feat,
batch_first=True,
padding_value=0)
batch["pitch_feat"] = pitch_feat
batch["pitch_feat_len"] = pitch_feat_len
else:
# only gan train needs speech, delete it to save memory
del batch["speech"]
del batch["speech_len"]
batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
if dpo is True:
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
reject_speech_token = pad_sequence(reject_speech_token,
batch_first=True,
padding_value=0)
batch['reject_speech_token'] = reject_speech_token
batch['reject_speech_token_len'] = reject_speech_token_len
batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"]
else:

176
cosyvoice/flow/DiT/dit.py Normal file
View File

@@ -0,0 +1,176 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from cosyvoice.utils.mask import add_optional_chunk_mask
from cosyvoice.flow.DiT.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
CausalConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
super().__init__()
spk_dim = 0 if spk_dim is None else spk_dim
self.spk_dim = spk_dim
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
def forward(
self,
x: float["b n d"],
cond: float["b n d"],
text_embed: float["b n d"],
spks: float["b d"],
):
to_cat = [x, cond, text_embed]
if self.spk_dim > 0:
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
to_cat.append(spks)
x = self.proj(torch.cat(to_cat, dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=80,
mu_dim=None,
long_skip_connection=False,
spk_dim=None,
out_channels=None,
static_chunk_size=50,
num_decoding_left_chunks=2
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if mu_dim is None:
mu_dim = mel_dim
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.out_channels = out_channels
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
x = x.transpose(1, 2)
mu = mu.transpose(1, 2)
cond = cond.transpose(1, 2)
spks = spks.unsqueeze(dim=1)
batch, seq_len = x.shape[0], x.shape[1]
if t.ndim == 0:
t = t.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x = self.input_embed(x, cond, mu, spks.squeeze(1))
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
else:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
for block in self.transformer_blocks:
x = block(x, t, mask=attn_mask.bool(), rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x).transpose(1, 2)
return output

View File

@@ -0,0 +1,616 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Optional
import math
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
class MelSpec(nn.Module):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
)
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
assert len(inp.shape) == 2
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min=1e-5).log()
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
class CausalConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.kernel_size = kernel_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
self.conv2 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv1(x)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv2(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(-1)
else:
mask = mask[:, 0, -1].unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1]:],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(ff_norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os, logging
import random
from typing import Dict, Optional
import torch
@@ -19,6 +19,7 @@ import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
class MaskedDiffWithXvec(torch.nn.Module):
@@ -37,14 +38,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
@@ -165,14 +163,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
@@ -185,12 +180,17 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
if 'speech_token' not in batch:
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
@@ -279,3 +279,165 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
class CausalMaskedDiffWithDiT(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
pre_lookahead_layer: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.pre_lookahead_len = pre_lookahead_len
self.pre_lookahead_layer = pre_lookahead_layer
self.decoder = decoder
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
if 'speech_token' not in batch:
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h = self.pre_lookahead_layer(token)
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
if finalize is True:
h = self.pre_lookahead_layer(token)
else:
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
model = configs['flow']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len = 10 * model.decoder.estimator.static_chunk_size
chunk_size = model.decoder.estimator.static_chunk_size
context_size = model.pre_lookahead_layer.pre_lookahead_len
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
token_len = torch.tensor([max_len]).to(device)
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
prompt_token_len = torch.tensor([chunk_size]).to(device)
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
prompt_embedding = torch.rand(1, 192).to(device)
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())

View File

@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
sol = []
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
@@ -173,8 +174,7 @@ class ConditionalCFM(BASECFM):
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0)
z = torch.randn_like(x1)

View File

@@ -17,6 +17,7 @@ try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError:
from torch.nn.utils import weight_norm
from cosyvoice.transformer.convolution import CausalConv1d
class ConvRNNF0Predictor(nn.Module):
@@ -56,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
class CausalConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
if finalize is True:
x = self.condnet[0](x)
else:
x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
for i in range(1, len(self.condnet)):
x = self.condnet[i](x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))

View File

@@ -28,7 +28,7 @@ try:
except ImportError:
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
from cosyvoice.transformer.activation import Snake
from cosyvoice.utils.common import get_padding
from cosyvoice.utils.common import init_weights
@@ -50,8 +50,10 @@ class ResBlock(torch.nn.Module):
channels: int = 512,
kernel_size: int = 3,
dilations: List[int] = [1, 3, 5],
causal: bool = False,
):
super(ResBlock, self).__init__()
self.causal = causal
self.convs1 = nn.ModuleList()
self.convs2 = nn.ModuleList()
@@ -64,7 +66,14 @@ class ResBlock(torch.nn.Module):
kernel_size,
1,
dilation=dilation,
padding=get_padding(kernel_size, dilation)
padding=get_padding(kernel_size, dilation)) if causal is False else
CausalConv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
causal_type='left'
)
)
)
@@ -76,7 +85,14 @@ class ResBlock(torch.nn.Module):
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)
padding=get_padding(kernel_size, 1)) if causal is False else
CausalConv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
causal_type='left'
)
)
)
@@ -139,11 +155,13 @@ class SineGen(torch.nn.Module):
@torch.no_grad()
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, dim=1, length)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
:param f0: [B, 1, sample_len], Hz
:return: [B, 1, sample_len]
"""
f0 = f0.transpose(1, 2)
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
for i in range(self.harmonic_num + 1):
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
@@ -168,59 +186,7 @@ class SineGen(torch.nn.Module):
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
class SineGen2(torch.nn.Module):
@@ -242,7 +208,8 @@ class SineGen2(torch.nn.Module):
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False):
flag_for_pulse=False,
causal=False):
super(SineGen2, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
@@ -252,6 +219,11 @@ class SineGen2(torch.nn.Module):
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.upsample_scale = upsample_scale
self.causal = causal
if causal is True:
self.rand_ini = torch.rand(1, 9)
self.rand_ini[:, 0] = 0
self.sine_waves = torch.rand(1, 300 * 24000, 9)
def _f02uv(self, f0):
# generate uv signal
@@ -267,6 +239,9 @@ class SineGen2(torch.nn.Module):
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
if self.training is False and self.causal is True:
rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
else:
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
@@ -279,7 +254,7 @@ class SineGen2(torch.nn.Module):
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
sines = torch.sin(phase)
else:
# If necessary, make sure that the first time step of every
@@ -331,6 +306,9 @@ class SineGen2(torch.nn.Module):
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
if self.training is False and self.causal is True:
noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
else:
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
@@ -339,7 +317,7 @@ class SineGen2(torch.nn.Module):
return sine_waves, uv, noise
class SourceModuleHnNSF2(torch.nn.Module):
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
@@ -358,19 +336,24 @@ class SourceModuleHnNSF2(torch.nn.Module):
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF2, self).__init__()
add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
if sinegen_type == '1':
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
else:
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
self.causal = causal
if causal is True:
self.uv = torch.rand(1, 300 * 24000, 1)
def forward(self, x):
"""
@@ -385,6 +368,9 @@ class SourceModuleHnNSF2(torch.nn.Module):
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
if self.training is False and self.causal is True:
noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
else:
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
@@ -425,15 +411,16 @@ class HiFTGenerator(nn.Module):
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
# NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
self.m_source = this_SourceModuleHnNSF(
# NOTE in CosyVoice2, we use the original SineGen implementation
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold)
voiced_threshod=nsf_voiced_threshold,
sinegen_type='1' if self.sampling_rate == 22050 else '2',
causal=False)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
@@ -580,3 +567,180 @@ class HiFTGenerator(nn.Module):
s[:, :, :cache_source.shape[2]] = cache_source
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, s
class CausalHiFTGenerator(HiFTGenerator):
"""
HiFTNet Generator: Neural Source Filter + ISTFTNet
https://arxiv.org/abs/2309.09493
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
nb_harmonics: int = 8,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: List[int] = [8, 8],
upsample_kernel_sizes: List[int] = [16, 16],
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: List[int] = [7, 11],
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
conv_pre_look_right: int = 4,
f0_predictor: torch.nn.Module = None,
):
torch.nn.Module.__init__(self)
self.out_channels = 1
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.istft_params = istft_params
self.lrelu_slope = lrelu_slope
self.audio_limit = audio_limit
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold,
sinegen_type='1' if self.sampling_rate == 22050 else '2',
causal=True)
self.upsample_rates = upsample_rates
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
CausalConv1dUpsample(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
)
)
)
# Down
self.source_downs = nn.ModuleList()
self.source_resblocks = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
if u == 1:
self.source_downs.append(
CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
)
else:
self.source_downs.append(
CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
)
self.source_resblocks.append(
ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d, causal=True))
self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.conv_pre_look_right = conv_pre_look_right
self.f0_predictor = f0_predictor
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
if finalize is True:
x = self.conv_pre(x)
else:
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, self.lrelu_slope)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
# fusion
si = self.source_downs[i](s_stft)
si = self.source_resblocks[i](si)
x = x + si
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
x = self._istft(magnitude, phase)
if finalize is False:
x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
@torch.inference_mode()
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
self.f0_predictor.to(torch.float64)
f0 = self.f0_predictor(speech_feat.to(torch.float64), finalize=finalize).to(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
if finalize is True:
generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
else:
generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
return generated_speech, s
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
model = configs['hift']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len, chunk_size, context_size = 300, 30, 8
mel = torch.rand(1, 80, max_len).to(device)
pred_gt, _ = model.inference(mel)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
pred_chunk = pred_chunk[:, i * 480:]
print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import os, queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
@@ -27,6 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
class TransformerLM(torch.nn.Module):
@@ -56,8 +58,9 @@ class TransformerLM(torch.nn.Module):
)
# 2. build speech token language model related modules
self.sos_eos = 0
self.sos = 0
self.task_id = 1
self.eos_token = self.speech_token_size
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
@@ -85,10 +88,10 @@ class TransformerLM(torch.nn.Module):
encoder_out = self.text_encoder_affine_layer(encoder_out)
return encoder_out, encoder_out_lens
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
@@ -126,15 +129,15 @@ class TransformerLM(torch.nn.Module):
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(1)
# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 4. encode speech_token
speech_token = self.speech_embedding(speech_token)
# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
task_id_emb, speech_token, speech_token_len)
# 6. run lm forward
@@ -154,7 +157,7 @@ class TransformerLM(torch.nn.Module):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
if (not ignore_eos) or (top_ids < self.speech_token_size):
break
num_trials += 1
if num_trials > max_trials:
@@ -193,13 +196,13 @@ class TransformerLM(torch.nn.Module):
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +218,8 @@ class TransformerLM(torch.nn.Module):
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
# force continue decode first token
if i == 0:
logp[:, self.speech_token_size] = -float('inf')
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
if top_ids == self.eos_token:
break
# in stream mode, yield token one by one
yield top_ids
@@ -276,9 +276,10 @@ class Qwen2LM(TransformerLM):
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.sos = 0
self.task_id = 1
self.fill_token = 2
self.eos_token = speech_token_size
self.fill_token = speech_token_size + 2
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
@@ -300,19 +301,29 @@ class Qwen2LM(TransformerLM):
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(3)]
self.vllm_output_queue = {}
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
lm_target, lm_input = [], []
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
# NOTE add instruct_token in CosyVoice3
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
else:
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
for i in range(len(text_token)):
# bistream sequence
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
this_lm_target, this_lm_input = [], []
this_lm_target.append(IGNORE_ID)
this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
this_lm_input.append(instruct_token_emb[i])
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,22 +331,21 @@ class Qwen2LM(TransformerLM):
assert len(this_speech_token) == self.mix_ratio[1]
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
this_lm_target += this_speech_token
this_lm_target.append(self.speech_token_size + 2)
this_lm_target.append(self.fill_token)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
else:
this_lm_target += [-1] * len(this_text_token)
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
this_lm_target.append(self.speech_token_size)
this_lm_target.append(self.eos_token)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
this_lm_input.append(task_id_emb.squeeze(dim=0))
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
# unistream sequence
else:
this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
lm_target.append(this_lm_target)
lm_input.append(this_lm_input)
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
@@ -357,17 +367,25 @@ class Qwen2LM(TransformerLM):
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
if 'speech_token' not in batch:
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
speech_token, speech_token_emb, speech_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
@@ -392,6 +410,10 @@ class Qwen2LM(TransformerLM):
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
@@ -401,8 +423,8 @@ class Qwen2LM(TransformerLM):
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
lm_target = lm_target.to(device)
# 4. run lm forward
@@ -420,8 +442,8 @@ class Qwen2LM(TransformerLM):
rejected_lm_mask = rejected_lm_target == IGNORE_ID
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
@torch.inference_mode()
@@ -445,13 +467,13 @@ class Qwen2LM(TransformerLM):
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -500,11 +522,9 @@ class Qwen2LM(TransformerLM):
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
if top_ids in self.stop_token_ids:
break
if top_ids > self.speech_token_size:
continue
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
@@ -526,20 +546,20 @@ class Qwen2LM(TransformerLM):
device = prompt_text.device
# 1. prepare input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb], dim=1)
lm_input = torch.concat([sos_emb], dim=1)
# 2. iterate text
out_tokens = []
cache = None
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
text_cache = self.llm.model.model.embed_tokens(prompt_text)
next_fill_index = -1
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
for this_text in text:
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
# prompt_speech_token_emb not empty, try append to lm_input
@@ -554,12 +574,12 @@ class Qwen2LM(TransformerLM):
break
# no prompt_speech_token_emb remain, can decode some speech token
if prompt_speech_token_emb.size(1) == 0:
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
logging.info('get fill token, need to append more text token')
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text = text_cache[:, :self.mix_ratio[0]]
logging.info('append {} text token'.format(lm_input_text.size(1)))
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
lm_input = lm_input_text
else:
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +594,16 @@ class Qwen2LM(TransformerLM):
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2
top_ids = self.fill_token
next_fill_index += (self.mix_ratio[1] + 1)
else:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
if top_ids == self.speech_token_size + 2:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
if top_ids == self.fill_token:
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size + 2:
if top_ids == self.fill_token:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
@@ -599,13 +619,142 @@ class Qwen2LM(TransformerLM):
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size:
if top_ids == self.eos_token:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
# in stream mode, yield token one by one
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
class CosyVoice3LM(Qwen2LM):
def __init__(
self,
llm_input_size: int,
llm_output_size: int,
speech_token_size: int,
llm: torch.nn.Module,
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
mix_ratio: List[int] = [5, 15],
):
torch.nn.Module.__init__(self)
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos = speech_token_size + 0
self.eos_token = speech_token_size + 1
self.task_id = speech_token_size + 2
self.fill_token = speech_token_size + 3
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 200,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(200)]
self.vllm_output_queue = {}
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
if 'speech_token' not in batch:
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# NOTE should append instruct_token to sequence, not implemented yet
instruct_token = batch['instruct_token'].to(device)
instruct_token_len = batch['instruct_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
# 3. sos and task_id
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target.to(device))
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
yield token

View File

@@ -238,7 +238,7 @@ def get_tokenizer(
)
class QwenTokenizer():
class CosyVoice2Tokenizer():
def __init__(self, token_path, skip_special_tokens=True):
super().__init__()
# NOTE: non-chat model, all these special tokens keep randomly initialized.
@@ -271,9 +271,57 @@ class QwenTokenizer():
return text
class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
def __init__(self, token_path, skip_special_tokens=True):
# NOTE: non-chat model, all these special tokens keep randomly initialized.
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
'additional_special_tokens': [
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]", "<|endofsystem|>",
"[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
"[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
"[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
"[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
"[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
"[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
"[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
"[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
"[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
"[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
"[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
"[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
"[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
"[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
"[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
"[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
"[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
"[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
"[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
"[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
]
}
self.special_tokens = special_tokens
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
self.tokenizer.add_special_tokens(special_tokens)
self.skip_special_tokens = skip_special_tokens
@lru_cache(maxsize=None)
def get_qwen_tokenizer(
token_path: str,
skip_special_tokens: bool
) -> QwenTokenizer:
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
skip_special_tokens: bool,
version: str = 'cosyvoice2'
):
if version == 'cosyvoice2':
return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
elif version == 'cosyvoice3':
return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
else:
raise ValueError

View File

@@ -19,6 +19,7 @@ from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
class ConvolutionModule(nn.Module):
@@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
causal_type: str = 'left',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride=1,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
assert causal_type in ['left', 'right']
self.causal_type = causal_type
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
input_timestep = x.shape[2]
if cache.size(2) == 0:
cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
assert cache.size(2) == self.causal_padding
if self.causal_type == 'left':
x = torch.concat([cache, x], dim=2)
else:
x = torch.concat([x, cache], dim=2)
x = super(CausalConv1d, self).forward(x)
assert x.shape[2] == input_timestep
return x
class CausalConv1dDownSample(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride != 1 and dilation == 1
assert kernel_size % stride == 0
self.causal_padding = stride - 1
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
if cache.size(2) == 0:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
else:
assert cache.size(2) == self.causal_padding
x = torch.concat([cache, x], dim=2)
x = super(CausalConv1dDownSample, self).forward(x)
return x
class CausalConv1dUpsample(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
kernel_size, 1,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert dilation == 1
self.causal_padding = kernel_size - 1
self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.upsample(x)
input_timestep = x.shape[2]
if cache.size(2) == 0:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
else:
assert cache.size(2) == self.causal_padding
x = torch.concat([cache, x], dim=2)
x = super(CausalConv1dUpsample, self).forward(x)
assert input_timestep == x.shape[2]
return x

View File

@@ -64,17 +64,18 @@ class Upsample1D(nn.Module):
class PreLookaheadLayer(nn.Module):
def __init__(self, channels: int, pre_lookahead_len: int = 1):
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
in_channels, channels,
kernel_size=pre_lookahead_len + 1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
channels, channels,
channels, in_channels,
kernel_size=3, stride=1, padding=0,
)
@@ -199,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,

View File

@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
from cosyvoice.hifigan.generator import HiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,4 +80,6 @@ def get_model_type(configs):
return CosyVoiceModel
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
return CosyVoice2Model
if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
return CosyVoice3Model
raise TypeError('No valid model type found!')

View File

@@ -25,6 +25,33 @@ import torch
IGNORE_ID = -1
instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
"You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
"You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
"You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
"You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
"You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
"You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
"You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
"You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
"You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
"You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
"You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
"You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
"You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
"You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
"You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
"You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
@@ -112,6 +139,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
if rep_num >= win_size * tau_r:
weighted_scores[top_ids] = -float('inf')
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
return top_ids
@@ -130,12 +158,12 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
break
prob = torch.tensor(prob).to(weighted_scores)
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
top_ids = indices[prob.multinomial(1, replacement=True)]
top_ids = indices[prob.multinomial(1, replacement=True)].item()
return top_ids
def random_sampling(weighted_scores, decoded_tokens, sampling):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
return top_ids

View File

@@ -166,7 +166,7 @@ class Executor:
for k, v in info_dict['loss_dict'].items():
if k not in total_loss_dict:
total_loss_dict[k] = []
total_loss_dict[k].append(v.item() * num_utts)
total_loss_dict[k].append(v.mean().item() * num_utts)
log_per_step(None, info_dict)
for k, v in total_loss_dict.items():
total_loss_dict[k] = sum(v) / total_num_utts

View File

@@ -41,11 +41,11 @@ def read_json_lists(list_file):
return results
def load_wav(wav, target_sr):
def load_wav(wav, target_sr, min_sr=16000):
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
speech = speech.mean(dim=0, keepdim=True)
if sample_rate != target_sr:
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
return speech
@@ -88,30 +88,18 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
logging.info("Succesfully convert onnx to trt...")
# NOTE do not support bistream inference as only speech token embedding/head is kept
def export_cosyvoice2_vllm(model, model_path, device):
if os.path.exists(model_path):
return
pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
vocab_size = model.speech_embedding.num_embeddings
feature_size = model.speech_embedding.embedding_dim
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
dtype = torch.bfloat16
# lm_head
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
with torch.no_grad():
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
new_lm_head.weight[vocab_size:] = 0
new_lm_head.bias[vocab_size:] = 0
model.llm.model.lm_head = new_lm_head
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
use_bias = True if model.llm_decoder.bias is not None else False
model.llm.model.lm_head = model.llm_decoder
# embed_tokens
embed_tokens = model.llm.model.model.embed_tokens
with torch.no_grad():
new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
new_codec_embed.weight[vocab_size:] = 0
model.llm.model.set_input_embeddings(new_codec_embed)
model.llm.model.set_input_embeddings(model.speech_embedding)
model.llm.model.to(device)
model.llm.model.to(dtype)
tmp_vocab_size = model.llm.model.config.vocab_size
@@ -119,10 +107,11 @@ def export_cosyvoice2_vllm(model, model_path, device):
del model.llm.model.generation_config.eos_token_id
del model.llm.model.config.bos_token_id
del model.llm.model.config.eos_token_id
model.llm.model.config.vocab_size = pad_vocab_size
model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
model.llm.model.config.tie_word_embeddings = False
model.llm.model.config.use_bias = True
model.llm.model.config.use_bias = use_bias
model.llm.model.save_pretrained(model_path)
if use_bias is True:
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
model.llm.model.config.vocab_size = tmp_vocab_size
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding

54
cosyvoice/utils/onnx.py Normal file
View File

@@ -0,0 +1,54 @@
import onnxruntime
import torch, random
import os
import torchaudio.compliance.kaldi as kaldi
class SpeechTokenExtractor():
def __init__(self, model_path):
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
sess_options=option,
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
def inference(self, feat, feat_lengths, device):
speech_token = self.speech_tokenizer_session.run(None,
{self.speech_tokenizer_session.get_inputs()[0].name:
feat.transpose(1, 2).detach().cpu().numpy(),
self.speech_tokenizer_session.get_inputs()[1].name:
feat_lengths.detach().cpu().numpy()})[0]
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
class EmbeddingExtractor():
def __init__(self, model_path):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.max_len = 10 * 16000
self.campplus_session = onnxruntime.InferenceSession(model_path,
sess_options=option,
providers=["CPUExecutionProvider"])
def inference(self, speech):
if speech.shape[1] > self.max_len:
start_index = random.randint(0, speech.shape[1] - self.max_len)
speech = speech[:, start_index: start_index + self.max_len]
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
return torch.tensor(embedding).to(speech.device)
# singleton mode, only initialized once
onnx_path = os.environ.get('onnx_path')
if onnx_path is not None:
embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
else:
embedding_extractor, online_feature = None, False

View File

@@ -53,7 +53,7 @@ def init_distributed(args):
def init_dataset_and_dataloader(args, configs, gan, dpo):
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
train_data_loader = DataLoader(train_dataset,
@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
def check_modify_and_save_config(args, configs):
if args.train_engine == "torch_ddp":
configs['train_conf']["dtype"] = 'fp32'
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
else:
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
@@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
raise ValueError("unknown scheduler: " + configs['train_conf'])
if configs['train_conf']['optim_d'] == 'adam':
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
elif configs['train_conf']['optim_d'] == 'adamw':
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler_d'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
elif configs['train_conf']['scheduler'] == 'constantlr':
scheduler_type = ConstantLR
scheduler_d = ConstantLR(optimizer_d)
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
dtype = torch.float32
if info_dict['train_engine'] == 'torch_ddp':
autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
else:
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)

View File

@@ -23,6 +23,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Optional
from packaging.version import parse as vparse
import vllm
# vLLM-0.11.0+ only support V1 engine
VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
if VLLM_V1_ENGINE_ONLY:
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.model_executor.models.qwen2 import *
@@ -87,8 +96,12 @@ class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
) -> Optional[torch.Tensor]:
if VLLM_V1_ENGINE_ONLY:
logits = self.logits_processor(self.lm_head, hidden_states,
self.lm_head.bias)
else:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits

View File

@@ -4,7 +4,7 @@ ARG VENV_NAME="cosyvoice"
ENV VENV=$VENV_NAME
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
ENV DEBIAN_FRONTEN=noninteractive
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
SHELL ["/bin/bash", "--login", "-c"]
@@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
RUN conda activate ${VENV} && cd CosyVoice && \
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
WORKDIR /workspace/CosyVoice

112
example.py Normal file
View File

@@ -0,0 +1,112 @@
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import AutoModel
import torchaudio
def cosyvoice_example():
""" CosyVoice Usage, check https://fun-audio-llm.github.io/ for more details
"""
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-SFT')
# sft usage
print(cosyvoice.list_available_spks())
# change stream=True for chunk stream inference
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
# zero_shot usage
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# cross_lingual usage, <|zh|><|en|><|ja|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
'./asset/cross_lingual_prompt.wav')):
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# vc usage
for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男',
'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.<|endofprompt|>')):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
def cosyvoice2_example():
""" CosyVoice2 Usage, check https://funaudiollm.github.io/cosyvoice2/ for more details
"""
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# save zero_shot spk for future usage
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', 'my_zero_shot_spk') is True
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk')):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice.save_spkinfo()
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', './asset/zero_shot_prompt.wav')):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话<|endofprompt|>', './asset/zero_shot_prompt.wav')):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# bistream usage, you can use generator as input, this is useful when using text llm model as input
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
def text_generator():
yield '收到好友从远方寄来的生日礼物,'
yield '那份意外的惊喜与深深的祝福'
yield '让我心中充满了甜蜜的快乐,'
yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
def cosyvoice3_example():
""" CosyVoice3 Usage, check https://funaudiollm.github.io/cosyvoice3/ for more details
"""
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
# zero_shot usage
for i, j in enumerate(cosyvoice.inference_zero_shot('八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L280
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>[breath]因为他们那一辈人[breath]在乡里面住的要习惯一点,[breath]邻居都很活络,[breath]嗯,都很熟悉。[breath]',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage, for supported control, check cosyvoice/utils/common.py#L28
for i, j in enumerate(cosyvoice.inference_instruct2('好少咯,一般系放嗰啲国庆啊,中秋嗰啲可能会咯。', 'You are a helpful assistant. 请用广东话表达。<|endofprompt|>',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# hotfix usage
for i, j in enumerate(cosyvoice.inference_zero_shot('高管也通过电话、短信、微信等方式对报道[j][ǐ]予好评。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# NOTE for Japanese usage, you must translate it to katakana.
# 歴史的世界においては、過去は単に過ぎ去ったものではない、プラトンのいう如く非有が有である。 -> レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン イウ ゴトク ヒ ユー ガ ユー デ アル。
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン イウ ゴトク ヒ ユー ガ ユー デ アル。',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('japanese_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
def main():
# cosyvoice_example()
# cosyvoice2_example()
cosyvoice3_example()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,6 @@
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
COPY requirements.txt /myworkspace/requirements.txt
RUN pip install -r /myworkspace/requirements.txt
RUN pip install -U nvidia-pytriton
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .

View File

@@ -0,0 +1,125 @@
# CosyVoice2 LLM Reinforcement Learning Recipe
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
## Table of Contents
- [Environment Setup](#environment-setup)
- [Data Preparation](#data-preparation)
- [Reward Function & ASR Server](#reward-function--asr-server)
- [Training](#training)
- [Evaluation](#evaluation)
- [Export Model](#export-model)
- [Results](#results)
- [Acknowledgement](#acknowledgement)
## Environment Setup
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
```bash
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
```
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
## Data Preparation
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
```jsonc
{
"text": "An example sentence to be synthesized."
}
```
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
```bash
bash run.sh 0 0
```
Create two JSONL files—`train.jsonl` and `test.jsonl`.
The script will then generate two Parquet files:
```
data/parquet_tiny/train.parquet
data/parquet_tiny/test.parquet
```
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
## Reward Function & ASR Server
To compute rewards, we run a lightweight server that:
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
2. Transcribes the waveform with **SenseVoice** ASR.
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
```bash
bash run.sh 1 1
# Triton server listens on ports 8000/8001/8002
```
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
## Training
Run stage `2` to start GRPO training:
```bash
bash run.sh 2 2
```
Key CLI arguments passed to `verl.trainer.main_ppo`:
* `algorithm.adv_estimator=grpo` use GRPO instead of PPO.
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
* `custom_reward_function.path=reward_tts.py` custom reward function described above.
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
> [!TIP]
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
## Evaluation
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
```bash
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
```
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
```bash
bash run.sh 4 4
```
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
> [!TIP]
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
## Export Model
To use the RL-trained model with the official CosyVoice repository:
```bash
bash run.sh 5 5
```
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
## Results
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|-------|------------------------|------------------------------|---------|
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
## Acknowledgement
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).

View File

@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
"""
from argparse import ArgumentParser
import torch
from safetensors import safe_open
from transformers import AutoTokenizer
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--hf-cosyvoice2-llm-path",
type=str,
default=None,
help="The RL trained CosyVoice2 model path in HuggingFace format",
)
parser.add_argument(
"--output-path",
type=str,
default="./llm.pt",
help="The path to save the llm.pt",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
cosyvoice2_token_size = 6561 + 3
llm_embedding_vocab_size = 2
hf_tensors = {}
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
for k in f.keys():
if k.startswith("lm_head.bias"):
# RL trained model disable bias for lm_head
continue
new_k = "llm.model." + k
hf_tensors[new_k] = f.get_tensor(k)
if k.startswith("lm_head"):
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
if k.startswith("model.embed_tokens"):
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
# use tie_word_embeddings=True
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
torch.save(hf_tensors, args.output_path)

View File

@@ -0,0 +1,397 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
dataset=zero_shot_zh
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
token2wav_path=/workspace/CosyVoice2-0.5B
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
infer_dataset.py \
--output-dir $output_dir \
--llm-model-name-or-path $llm_path/merged_hf_model \
--token2wav-path $token2wav_path \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def extract_speech_ids(speech_tokens_str):
"""Extract speech IDs from token strings like <|s_23456|>"""
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
for token in cosy2_tokens:
speech_id_str += f"<|s_{token}|>"
return speech_id_str
def get_args():
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=1, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="LLM model path (includes both model and tokenizer)",
)
parser.add_argument(
"--token2wav-path",
required=True,
type=str,
help="CosyVoice2 token2wav model path",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The prompt text for CosyVoice2",
)
parser.add_argument(
"--prompt-speech-path",
type=str,
default=None,
help="The path to the prompt speech for CosyVoice2",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
args = parser.parse_args()
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
mels, prompt_audio_cosy2tokens_list = [], []
for item in batch:
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
)
prompt_text_list.append(prompt_text)
# Combine prompt and target text
full_text = prompt_text + target_text
# get prompt audio for CosyVoice2 (convert to 16kHz)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
print(ref_audio_org.shape)
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
prompt_audio_list.append(ref_audio)
if "prompt_audio_cosy2_tokens" in item:
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
else:
# convert to float first
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
if len(mels) > 0:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
if 'system' in tokenizer.chat_template:
tokenizer.chat_template = TEMPLATE
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids_list.append(input_ids.squeeze(0))
# For batch_size=1, no need to pad
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
else:
# Handle batch > 1 if needed
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list)
ids = [item["id"] for item in batch]
return {
"input_ids": input_ids,
"ids": ids,
"prompt_text": prompt_text_list,
"prompt_audio_list": prompt_audio_list,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
# Load LLM model and tokenizer directly
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
cosyvoice_codec = CosyVoice2(
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
)
if args.prompt_speech_path:
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
else:
prompt_speech_16k = None
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
for batch in dataloader:
with torch.no_grad():
input_ids = batch["input_ids"].to(device)
# Generate speech tokens using LLM
outputs = model.generate(
input_ids,
max_new_tokens=2048, # Max length for generation
do_sample=True,
top_p=args.top_p,
temperature=args.temperature,
top_k=args.top_k,
)
# Process each sample in the batch
for i in range(len(batch["ids"])):
# Extract generated tokens (excluding input)
input_length = input_ids[i].shape[0]
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Extract speech IDs from token strings like <|s_23456|>
speech_ids = extract_speech_ids(speech_tokens_str)
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
# Convert to tensor for CosyVoice2
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
if current_prompt_audio is not None:
# Generate audio using CosyVoice2
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
current_prompt_text,
current_prompt_audio,
cosyvoice_codec,
)
# Convert to numpy and save
generated_wave = audio_hat.squeeze(0).cpu().numpy()
target_sample_rate = 24000
utt = batch["ids"][i]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
if rank == 0:
progress_bar.close()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,86 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Text to Speech dataset to parquet format
"""
import argparse
import os
import re
import datasets
from verl.utils.hdfs_io import copy, makedirs
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
parser.add_argument("--local_dir", default=None, required=True)
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
# Load datasets from local JSON files
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
text = example.pop("text")
# use cosyvoice2 official huggingface compatible checkpoint template
question = text
answer = ""
data = {
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
"prompt": [
{
"role": "user",
"content": question,
},
{
"role": "assistant",
"content": answer,
},
],
"ability": "text-to-speech",
"reward_model": {"style": "rule", "ground_truth": text},
"extra_info": {
"split": split,
"index": idx,
"text": text,
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
print(train_dataset)
print(test_dataset)
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)

View File

@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: Instruct TTS
python3 infer.py \
--token2wav-path /workspace/CosyVoice2-0.5B \
--prompt-text "吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。" \
--prompt-speech-path ./assets/prompt_audio.wav \
--model-path ./transformers_cosyvoice2_llm \
--input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
"""
from cosyvoice.cli.cosyvoice import CosyVoice2
import sys
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--pretrained-cosyvoice2-path",
type=str,
default="/workspace/CosyVoice2-0.5B",
help="Token2Wav path, default to %(default)r",
)
parser.add_argument(
"--save-path",
type=str,
default='./transformers_cosyvoice2_llm',
help="The path to save the model",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
cosy2_model = CosyVoice2(
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
)
llm = cosy2_model.model.llm.llm.model
speech_embedding = cosy2_model.model.llm.speech_embedding
llm_decoder = cosy2_model.model.llm.llm_decoder
llm_embedding = cosy2_model.model.llm.llm_embedding
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
'additional_special_tokens': [
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]"
]
}
tokenizer.add_special_tokens(special_tokens)
original_tokenizer_vocab_size = len(tokenizer)
cosyvoice2_token_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
vocab_size = llm.get_input_embeddings().weight.shape[0]
feature_size = speech_embedding.embedding_dim
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
with torch.no_grad():
# set the weight and bias of the new lm_head to 0
new_lm_head.weight.data.zero_()
# make bias value -inf
new_lm_head.bias.data.fill_(-float('inf'))
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
llm.lm_head = new_lm_head
input_embeddings = llm.get_input_embeddings()
with torch.no_grad():
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
llm.generation_config.eos_token_id = eos_token_ids
llm.generation_config.temperature = 1.0
llm.generation_config.top_p = 0.8
llm.generation_config.top_k = 25
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
llm.config.vocab_size = vocab_size
llm.config.tie_word_embeddings = False
llm.config.use_bias = True
llm.to(torch.bfloat16)
llm.save_pretrained(args.save_path)
TEMPLATE = (
"{%- for message in messages %}"
"{%- if message['role'] == 'user' %}"
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
"{%- elif message['role'] == 'assistant' %}"
"{{- message['content']}}"
"{%- endif %}"
"{%- endfor %}"
)
tokenizer.chat_template = TEMPLATE
tokenizer.save_pretrained(args.save_path)

View File

@@ -0,0 +1,31 @@
conformer==0.3.2
diffusers==0.29.0
gdown==5.1.0
gradio
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0
protobuf==4.25
pydantic==2.7.0
pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
wget==3.2
WeTextProcessing==1.0.3
s3tokenizer
tensorrt
sherpa_onnx
jiwer
zhon
numpy==1.25.2
pypinyin
openai-whisper

View File

@@ -0,0 +1,233 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reward calculation for CosyVoice2-0.5B.
"""
from __future__ import annotations
import re
import json
import time
import argparse
from typing import List
import numpy as np
import requests
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
def _parse_ids(token_str: str) -> List[int]:
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
"""Send token IDs and ground-truth text to the Triton server and get reward."""
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
payload = {
"inputs": [
{
"name": "TOKENS",
"shape": list(tokens_arr.shape),
"datatype": "INT32",
"data": tokens_arr.tolist(),
},
{
"name": "TOKEN_LENS",
"shape": list(lens_arr.shape),
"datatype": "INT32",
"data": lens_arr.tolist(),
},
{
"name": "GT_TEXT",
"shape": [1, 1],
"datatype": "BYTES",
"data": [ground_truth],
},
]
}
rsp = requests.post(
REWARD_SERVER_URL,
headers={"Content-Type": "application/json"},
json=payload,
timeout=timeout,
verify=False,
params={"request_id": "0"},
)
rsp.raise_for_status()
result = rsp.json()
try:
# Reward is returned as the first output
return float(result["outputs"][0]["data"][0])
except (KeyError, IndexError, TypeError):
return 0.0
def compute_score(
data_source: str,
solution_str: str,
ground_truth: str,
extra_info: dict | None = None,
*,
debug_dump: bool = False,
) -> float:
"""Return reward in [0, 1] using the Triton ASR service.
The reward is based on the pinyin-level WER between the ASR transcript
produced from *solution_str* and the provided *ground_truth* text.
"""
# Decode token IDs
ids = _parse_ids(solution_str)
# Query remote server for reward
try:
reward = _remote_reward(ids, ground_truth)
except Exception as e:
reward = 0.0
if debug_dump:
print(
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
)
return reward
# CLI quick test
if __name__ == "__main__":
import sys
def get_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Test TTS CER scoring with data from JSONL file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input", "-i",
type=str,
default="data/emilia_zh-cosy-tiny-test.jsonl",
help="Path to input JSONL file"
)
parser.add_argument(
"--max-samples", "-n",
type=int,
default=None,
help="Maximum number of samples to process (default: all)"
)
parser.add_argument(
"--no-interactive",
action="store_true",
help="Run in non-interactive mode (process all samples without prompts)"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode"
)
return parser.parse_args()
def load_jsonl(file_path: str):
"""Load data from jsonl file."""
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data.append(json.loads(line.strip()))
return data
def code_to_solution_str(code_list: List[int]) -> str:
"""Convert code list to solution string format."""
return ''.join([f"<|s_{code}|>" for code in code_list])
# Parse command line arguments
args = get_args()
try:
# Load data from jsonl file
print(f"Loading data from: {args.input}")
data_list = load_jsonl(args.input)
print(f"Loaded {len(data_list)} samples")
# Limit samples if specified
if args.max_samples is not None:
data_list = data_list[:args.max_samples]
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
# Process each sample
begin_time = time.time()
for i, sample in enumerate(data_list):
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
print(f"Index: {sample.get('index', 'unknown')}")
print(f"Text: {sample['text']}")
# Extract required fields
code_list = sample['code']
ground_truth = sample['text']
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
# Convert code list to solution string
solution_str = code_to_solution_str(code_list)
print(f"Solution tokens: {len(code_list)} tokens")
if args.debug:
print(f"Solution string: {solution_str}")
else:
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
# Call compute_score function
try:
score = compute_score(
data_source=data_source,
solution_str=solution_str,
ground_truth=ground_truth,
extra_info=None,
debug_dump=args.debug
)
print(f"Final Score: {score:.4f}")
except Exception as e:
print(f"Error computing score: {e}")
# Ask user if they want to continue (for interactive mode)
if not args.no_interactive and i < len(data_list) - 1:
try:
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
if response == 'q':
break
except KeyboardInterrupt:
print("\nStopped by user")
break
print(f"\nProcessed {min(i+1, len(data_list))} samples")
end_time = time.time()
print(f"Time taken: {end_time - begin_time} seconds")
except FileNotFoundError:
print(f"Error: File not found - {args.input}")
print("Please check the file path or use --input to specify correct path")
print("Run with --help for usage information")
except Exception as e:
print(f"Error: {e}")

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=4
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export PYTHONPATH=/workspace/CosyVoice
model_scope_model_path=./CosyVoice2-0.5B
sft_model_path=./transformers_cosyvoice2_llm
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
log "stage -2: install dependencies locally if pre-built docker image is not available"
conda create -n cosyvoice2 python=3.10 -y
conda activate cosyvoice2
# install verl
git clone https://github.com/yuekaizhang/verl.git -b thread
cd verl
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
pip install --no-deps -e .
cd -
# install requirements
pip install -r requirements.txt
pip install -U nvidia-pytriton
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
fi
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
python3 pretrained_to_huggingface.py \
--pretrained-cosyvoice2-path $model_scope_model_path \
--save-path $sft_model_path
# Or, you could use the following command to download the huggingface compatible checkpoint
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
fi
data_dir=data/parquet_aishell3
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: prepare data into verl format"
mkdir -p $data_dir
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
# total 88035 samples
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
python prepare_data.py \
--train_file data/train.jsonl \
--test_file data/test.jsonl \
--local_dir $data_dir
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: start token2wav asr server for reward function"
python3 token2wav_asr_server.py --number-of-devices 8
fi
exp_name=official_llm_aishell3_grpo
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: grpo train"
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export MKL_SERVICE_FORCE_INTEL=TRUE
n_gpus_per_node=8
micro_batch_size=4
train_batch_size=32
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$data_dir/train.parquet \
data.val_files=$data_dir/test.parquet \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.truncation='error' \
actor_rollout_ref.model.use_remove_padding=False \
actor_rollout_ref.model.path=$sft_model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.do_sample=true \
actor_rollout_ref.rollout.temperature=0.8 \
actor_rollout_ref.rollout.top_p=0.95 \
actor_rollout_ref.rollout.top_k=25 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
reward_model.reward_manager=prime \
custom_reward_function.path=reward_tts.py \
custom_reward_function.name=compute_score \
trainer.project_name='cosyvoice2_grpo' \
trainer.experiment_name=$exp_name \
trainer.logger=['console','wandb'] \
trainer.n_gpus_per_node=$n_gpus_per_node \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.resume_mode='auto' \
trainer.total_epochs=1 \
trainer.val_before_train=False
fi
steps=(100 200 300 400 500)
for step in ${steps[@]}; do
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: merge the model"
python -m verl.model_merger merge \
--backend fsdp \
--local_dir $llm_path/actor \
--target_dir $llm_path/merged_hf_model || exit 1
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Test the model"
dataset=zero_shot_zh # from CosyVoice3 test set
# dataset=test_zh # from seed_tts test set
output_dir=./outputs_${exp_name}_${step}_${dataset}
token2wav_path=/workspace/CosyVoice2-0.5B
model_path=$llm_path/merged_hf_model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
infer_dataset.py \
--output-dir $output_dir \
--llm-model-name-or-path $model_path \
--token2wav-path $token2wav_path \
--split-name ${dataset} || exit 1
bash scripts/compute_wer.sh $output_dir ${dataset}
fi
done
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: Convert the RL trained model to CosyVoice repo format"
python3 huggingface_to_pretrained.py \
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
# Please be careful or use the huggingface format inference code.
fi

View File

@@ -0,0 +1,33 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
exit 1
fi
split_name=$2
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
pip install sherpa-onnx
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
mkdir models
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
fi
python3 scripts/offline-decode-files.py \
--tokens=$model_path/tokens.txt \
--paraformer=$model_path/model.int8.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=24000 \
--log-dir $wav_dir \
--feature-dim=80 \
--split-name $split_name \
--name sherpa_onnx \
$wav_files
# python3 scripts/paraformer-pytriton-client.py \
# --log-dir $wav_dir \
# --split-name $split_name \
# $wav_files

View File

@@ -0,0 +1,754 @@
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
(4) For Whisper models
python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--whisper-task=transcribe \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For CTC models from WeNet
python3 ./python-api-examples/offline-decode-files.py \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
(6) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple, Dict, Iterable, TextIO, Union
import numpy as np
import sherpa_onnx
import soundfile as sf
from datasets import load_dataset
import logging
from collections import defaultdict
import kaldialign
from zhon.hanzi import punctuation
import string
punctuation_all = punctuation + string.punctuation
Pathlike = Union[str, Path]
def remove_punctuation(text: str) -> str:
for x in punctuation_all:
if x == '\'':
continue
text = text.replace(x, '')
return text
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
If it is a multi-talker ASR system, the ref and hyp may also be lists of
strings.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf8") as f:
for cut_id, ref, hyp in texts:
if char_level:
ref = list("".join(ref))
hyp = list("".join(hyp))
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
compute_CER: bool = False,
sclite_mode: bool = False,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for _cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one words/phrases per line, like
HELLO WORLD
你好世界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The hotword score of each token for biasing word/phrase. Used only if
--hotwords-file is given.
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="",
help="""
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
Used only when hotwords-file is given.
""",
)
parser.add_argument(
"--bpe-vocab",
type=str,
default="",
help="""
The path to the bpe vocabulary, the bpe vocabulary is generated by
sentencepiece, you can also export the bpe vocabulary through a bpe model
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
and modeling-unit is bpe or cjkchar+bpe.
""",
)
parser.add_argument(
"--encoder",
default="",
type=str,
help="Path to the encoder model",
)
parser.add_argument(
"--decoder",
default="",
type=str,
help="Path to the decoder model",
)
parser.add_argument(
"--joiner",
default="",
type=str,
help="Path to the joiner model",
)
parser.add_argument(
"--paraformer",
default="",
type=str,
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
parser.add_argument(
"--whisper-tail-paddings",
default=-1,
type=int,
help="""Number of tail padding frames.
We have removed the 30-second constraint from whisper, so you need to
choose the amount of tail padding frames by yourself.
Use -1 to use a default value for tail padding.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="""Sample rate of the feature extractor. Must match the one
expected by the model. Note: The input sound files can have a
different sample rate from this argument.""",
)
parser.add_argument(
"--feature-dim",
type=int,
default=80,
help="Feature dimension. Must match the one expected by the model",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
parser.add_argument(
"--name",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--log-dir",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--label",
type=str,
default=None,
help="wav_base_name label",
)
# Dataset related arguments for loading labels when label file is not provided
parser.add_argument(
"--dataset-name",
type=str,
default="yuekai/seed_tts_cosy2",
help="Huggingface dataset name for loading labels",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="Dataset split name for loading labels",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and can be of type
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples,
which are normalized to the range [-1, 1].
- Sample rate of the wave file.
"""
samples, sample_rate = sf.read(wave_filename, dtype="float32")
assert (
samples.ndim == 1
), f"Expected single channel, but got {samples.ndim} channels."
samples_float32 = samples.astype(np.float32)
return samples_float32, sample_rate
def normalize_text_alimeeting(text: str) -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
import re
text = text.replace('\u00A0', '') # test_hard
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = remove_punctuation(text)
return text
def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
print("Started!")
start_time = time.time()
streams, results = [], []
total_duration = 0
for i, wave_filename in enumerate(args.sound_files):
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
streams.append(s)
if i % 10 == 0:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
streams = []
print(f"Processed {i} files")
# process the last batch
if streams:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
end_time = time.time()
print("Done!")
results_dict = {}
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
wave_basename = Path(wave_filename).stem
results_dict[wave_basename] = result
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
# Load labels either from file or from dataset
labels_dict = {}
if args.label:
# Load labels from file (original functionality)
print(f"Loading labels from file: {args.label}")
with open(args.label, "r") as f:
for line in f:
# fields = line.strip().split(" ")
# fields = [item for item in fields if item]
# assert len(fields) == 4
# prompt_text, prompt_audio, text, audio_path = fields
fields = line.strip().split("|")
fields = [item for item in fields if item]
assert len(fields) == 4
audio_path, prompt_text, prompt_audio, text = fields
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
else:
# Load labels from dataset (new functionality)
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
if 'zero' in args.split_name:
dataset_name = "yuekai/CV3-Eval"
else:
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
for item in dataset:
audio_id = item["id"]
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
print(f"Loaded {len(labels_dict)} labels from dataset")
# Perform evaluation if labels are available
if labels_dict:
final_results = []
for key, value in results_dict.items():
if key in labels_dict:
final_results.append((key, labels_dict[key], value))
else:
print(f"Warning: No label found for {key}, skipping...")
if final_results:
store_transcripts(
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
)
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
write_error_stats(f, "test-set", final_results, enable_log=True)
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
print(f.readline()) # WER
print(f.readline()) # Detailed errors
else:
print("No matching labels found for evaluation")
else:
print("No labels available for evaluation")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,346 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytriton server for token2wav conversion and ASR"""
from datasets import load_dataset
from cosyvoice.cli.cosyvoice import CosyVoice2
from omnisense.models import OmniSenseVoiceSmall
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.decorators import batch
import argparse
import io
import logging
from typing import Any, List
import numpy as np
import torch
from scipy.signal import resample
import sys
import random
import re
from jiwer import wer
from pypinyin import lazy_pinyin, Style
from tn.chinese.normalizer import Normalizer as ZhNormalizer
# Chinese text normalizer (cached globally)
zh_tn_model = ZhNormalizer(
cache_dir="./cache",
remove_erhua=False,
remove_interjections=False,
remove_puncts=True,
overwrite_cache=True,
)
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
logger = logging.getLogger("token2wav_asr_server")
class _ASR_Server:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
def __init__(self, device_id: int):
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
@batch
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
results = self._model.transcribe_single_batch(
wavs,
language="zh",
textnorm="woitn",
)
texts = [result.text for result in results]
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
return {"TRANSCRIPTS": transcripts}
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def get_random_prompt_from_dataset(dataset):
"""
Get random prompt text and speech from the pre-loaded dataset.
Returns (prompt_text, prompt_speech_16k)
"""
random_idx = random.randint(0, len(dataset) - 1)
sample = dataset[random_idx]
# Extract audio data
audio_data = sample["audio"]
audio_array = audio_data["array"]
sample_rate = audio_data["sampling_rate"]
# Convert audio to 16kHz if needed
if sample_rate != 16000:
num_samples = int(len(audio_array) * (16000 / sample_rate))
audio_array = resample(audio_array, num_samples)
# Convert to torch tensor
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
prompt_text = sample["text"]
# remove space in prompt_text
prompt_text = prompt_text.replace(" ", "")
return prompt_text, prompt_speech_16k
class _Token2Wav_ASR:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
def __init__(self, device_id: int):
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
# current CUDA context to the desired card before the object is created.
# Afterwards, all parameters loaded with the generic "cuda" device will
# reside on this GPU. We keep the selected id in `self.device_id` and
# will set the context again for every forward call to avoid race
# conditions when several instances are used in the same process.
self.device_id = device_id
# Construct the TTS codec decoder under the correct CUDA device context
with torch.cuda.device(self.device_id):
self.codec_decoder = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
)
@batch
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
# Ensure the default CUDA device is set correctly for this invocation
torch.cuda.set_device(self.device_id)
if self.device_id == 0:
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
# Decode ground-truth text strings (BYTES → str)
if GT_TEXT.ndim == 2:
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
else:
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
wavs = []
for tokens in tokens_list:
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
prompt_text,
prompt_speech_16k,
self.codec_decoder,
)
# resample to 16000 using soundfile
audio_hat = audio_hat.squeeze(0).float().cpu()
audio_hat = audio_hat.numpy()
num_samples = int(len(audio_hat) * (16000 / 24000))
audio_hat = resample(audio_hat, num_samples)
wavs.append(audio_hat)
results = self.asr_model.transcribe_single_batch(
wavs,
language="zh",
textnorm="woitn",
)
texts = [result.text for result in results]
# ---------------- Reward computation ----------------
rewards = []
for gt_text, hyp_text in zip(gt_texts, texts):
gt_norm = zh_tn_model.normalize(gt_text).lower()
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
gt_pinyin = lazy_pinyin(
gt_norm,
style=Style.TONE3,
tone_sandhi=True,
neutral_tone_with_five=True,
)
hyp_pinyin = lazy_pinyin(
hyp_norm,
style=Style.TONE3,
tone_sandhi=True,
neutral_tone_with_five=True,
)
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
reward_val = 1.0 - np.tanh(3.0 * c)
reward_val = max(0.0, min(1.0, reward_val))
rewards.append(reward_val)
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
def _infer_function_factory(device_ids: List[int], model_name: str):
"""Creates a list of inference functions, one for each requested device ID."""
infer_funcs = []
for device_id in device_ids:
if model_name == "sensevoice":
infer_funcs.append(_ASR_Server(device_id=device_id))
else:
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
return infer_funcs
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--max-batch-size",
type=int,
default=32,
help="Batch size of request.",
required=False,
)
parser.add_argument(
"--verbose",
action="store_true",
default=False,
)
parser.add_argument(
"--number-of-instances-per-device",
type=int,
default=1,
help="Number of model instances to load.",
required=False,
)
parser.add_argument(
"--number-of-devices",
type=int,
default=8,
help="Number of devices to use.",
)
parser.add_argument(
"--model-name",
type=str,
default="token2wav_asr",
choices=["token2wav_asr", "sensevoice"],
help="Model name.",
)
args = parser.parse_args()
log_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
triton_config = TritonConfig(
http_port=8000,
grpc_port=8001,
metrics_port=8002,
)
device_ids = list(range(args.number_of_devices))
device_ids = device_ids * args.number_of_instances_per_device
with Triton(config=triton_config) as triton:
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
if args.model_name == "sensevoice":
triton.bind(
model_name="sensevoice",
infer_func=_infer_function_factory(device_ids, args.model_name),
inputs=[
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
],
outputs=[
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
],
config=ModelConfig(
max_batch_size=args.max_batch_size,
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
),
strict=True,
)
else:
triton.bind(
model_name="token2wav_asr",
infer_func=_infer_function_factory(device_ids, args.model_name),
inputs=[
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
],
outputs=[
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
],
config=ModelConfig(
max_batch_size=args.max_batch_size,
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
),
strict=True,
)
logger.info("Serving inference")
triton.serve()
if __name__ == "__main__":
main()

View File

@@ -1,257 +0,0 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 22050
text_encoder_input_size: 512
llm_input_size: 1024
llm_output_size: 1024
spk_embed_dim: 192
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm.TransformerLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
speech_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
spk_embed_dim: !ref <spk_embed_dim>
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
input_size: !ref <text_encoder_input_size>
output_size: 1024
attention_heads: 8
linear_units: 2048
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
use_cnn_module: False
macaron_style: False
use_dynamic_chunk: False
use_dynamic_left_chunk: False
static_chunk_size: 1
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
input_size: !ref <llm_input_size>
output_size: !ref <llm_output_size>
attention_heads: 8
linear_units: 2048
num_blocks: 7
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: 'linear_legacy'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
static_chunk_size: 1
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
only_mask_loss: True
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
channels: 80
sampling_ratios: [1, 1, 1, 1]
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256, 256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 8
num_heads: 8
act_fn: 'gelu'
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: null
center: False
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
truncate: !name:cosyvoice.dataset.processor.truncate
truncate_length: 24576 # must be a multiplier of hop_size
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 256
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 12000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
data_pipeline_gan: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <truncate>,
!ref <compute_fbank>,
!ref <compute_f0>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# llm flow train conf
train_conf:
optim: adam
optim_conf:
lr: 0.002 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1
# gan train conf
train_conf_gan:
optim: adam
optim_conf:
lr: 0.0002 # use small lr for gan training
scheduler: constantlr
optim_d: adam
optim_conf_d:
lr: 0.0002 # use small lr for gan training
scheduler_d: constantlr
max_epoch: 200
grad_clip: 5
accum_grad: 1 # in gan training, accum_grad must be 1
log_interval: 100
save_per_step: -1

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -40,6 +40,10 @@ def main():
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
for k, v in spk2utt.items():
f.write('{} {}\n'.format(k, ' '.join(v)))
if args.instruct != '':
with open('{}/instruct'.format(args.des_dir), 'w') as f:
for k, v in utt2text.items():
f.write('{} {}\n'.format(k, args.instruct))
return
@@ -49,7 +53,8 @@ if __name__ == "__main__":
type=str)
parser.add_argument('--des_dir',
type=str)
parser.add_argument('--ref_model',
type=str)
parser.add_argument('--instruct',
type=str,
default='')
args = parser.parse_args()
main()

View File

@@ -27,7 +27,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
../../../tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
@@ -35,7 +35,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
../../../tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
done
fi
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
@@ -60,7 +60,7 @@ num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
../../../cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice.yaml \
--train_data data/train.data.list \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -139,7 +139,7 @@ tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
max_length: 6000
min_length: 100
token_max_length: 200
token_min_length: 1
@@ -158,7 +158,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
token_mel_ratio: 2
num_frames: 960
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
num_frames: 960
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 480
@@ -183,6 +185,7 @@ data_pipeline: [
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <compute_whisper_fbank>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -24,27 +24,12 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
done
fi
# NOTE embedding/token extraction is not necessary now as we support online feature extraction
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
@@ -60,22 +45,22 @@ num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
# NOTE will update llm/hift training later
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
../../../cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice2.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--onnx_path $pretrained_model_dir \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \

View File

@@ -36,7 +36,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
../../../tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
@@ -44,7 +44,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
../../../tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
done
fi
@@ -53,7 +53,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--dpo \
--src_dir data/$x \
@@ -70,7 +70,7 @@ num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
@@ -80,11 +80,12 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
for model in llm; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
../../../cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice2.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--onnx_path $pretrained_model_dir \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -5,22 +5,28 @@ __set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 24000 # 16000 for llm, 24000 for cfm
sample_rate: 24000
llm_input_size: 896
llm_output_size: 896
spk_embed_dim: 192
qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN'
qwen_pretrain_path: ''
token_frame_rate: 25
token_mel_ratio: 2
# stream related params
chunk_size: 25 # streaming inference chunk size, in token
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
llm: !new:cosyvoice.llm.llm.CosyVoice3LM
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
speech_token_size: 6561
length_normalized_loss: True
lsm_weight: 0
dpo: True
mix_ratio: [5, 15]
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
pretrain_path: !ref <qwen_pretrain_path>
sampling: !name:cosyvoice.utils.common.ras_sampling
@@ -28,31 +34,21 @@ llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
input_size: 512
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithDiT
input_size: 80
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 6561
input_frame_rate: 25
input_frame_rate: !ref <token_frame_rate>
only_mask_loss: True
token_mel_ratio: 2
token_mel_ratio: !ref <token_mel_ratio>
pre_lookahead_len: 3
pre_lookahead_layer: !new:cosyvoice.transformer.upsample_encoder.PreLookaheadLayer
in_channels: 80
channels: 1024
pre_lookahead_len: 3
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
in_channels: 240
n_spks: 1
@@ -65,19 +61,20 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
estimator: !new:cosyvoice.flow.DiT.dit.DiT
dim: 1024
depth: 22
heads: 16
dim_head: 64
ff_mult: 2
mel_dim: 80
mu_dim: 80
spk_dim: 80
out_channels: 80
causal: True
channels: [256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
hift: !new:cosyvoice.hifigan.generator.CausalHiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
@@ -96,18 +93,19 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
conv_pre_look_right: 4
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.CausalConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
hop_size: 480
win_size: 1920
fmin: 0
fmax: null
center: False
@@ -115,45 +113,47 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
token_path: !ref <qwen_pretrain_path>
skip_special_tokens: True
version: cosyvoice3
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
max_length: 6000
min_length: 100
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
truncate: !name:cosyvoice.dataset.processor.truncate
truncate_length: 24576 # must be a multiplier of hop_size
truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
hop_size: 480
win_size: 1920
fmin: 0
fmax: 8000
fmax: null
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
num_frames: 960
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
num_frames: 960
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 256
hop_size: 480
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
@@ -162,10 +162,10 @@ sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: True # change to True during sft
dpo: True
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
@@ -175,6 +175,7 @@ data_pipeline: [
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <compute_whisper_fbank>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
@@ -199,10 +200,10 @@ data_pipeline_gan: [
train_conf:
optim: adam
optim_conf:
lr: 0.00001 # change to 1e-5 during sft
scheduler: warmuplr # change to constantlr during sft
lr: 1e-5 # change to 1e-5 during sft
scheduler: constantlr # change to constantlr during sft
scheduler_conf:
warmup_steps: 25000
warmup_steps: 2500
max_epoch: 200
grad_clip: 5
accum_grad: 2

View File

@@ -0,0 +1,42 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 256,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.001,
"weight_decay": 0.0001,
"torch_adam": true,
"adam_w_mode": true
}
}
}

View File

@@ -0,0 +1 @@
../cosyvoice/local

View File

@@ -0,0 +1 @@
../cosyvoice/path.sh

View File

@@ -0,0 +1,97 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
. ./path.sh || exit 1;
stage=-1
stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=../../../pretrained_models/Fun-CosyVoice3-0.5B
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
# NOTE in CosyVoice3, we add instruct in sequence
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct "You are a helpful assistant.<|endofprompt|>"
done
fi
# NOTE embedding/token extraction is not necessary now as we support online feature extraction
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1986
dist_backend="nccl"
num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
../../../cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice3.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--onnx_path $pretrained_model_dir \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--model_dir `pwd`/exp/cosyvoice3/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/cosyvoice3/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--use_amp \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi
# average model
average_num=5
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
for model in llm flow hifigan; do
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python cosyvoice/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
--num ${average_num} \
--val_best
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
fi

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -27,7 +27,7 @@ fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in dev test train; do
tools/extract_embedding.py --dir data/$x \
../../../tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
@@ -35,7 +35,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in dev test train; do
tools/extract_speech_token.py --dir data/$x \
../../../tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
done
fi
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in dev test train; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
cosyvoice/bin/train.py \
../../../cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice.yaml \
--train_data data/train.data.list \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -17,6 +17,7 @@ lightning==2.2.4
matplotlib==3.7.5
modelscope==1.20.0
networkx==3.1
numpy==1.26.4
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
@@ -29,12 +30,13 @@ pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
tensorrt-cu12==10.0.1; sys_platform == 'linux'
tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
torch==2.3.1
torchaudio==2.3.1
transformers==4.40.1
transformers==4.51.3
x-transformers==2.11.24
uvicorn==0.30.0
wetext==0.0.4
wget==3.2

View File

@@ -9,5 +9,5 @@ RUN apt-get -y install git unzip git-lfs g++
RUN git lfs install
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto

View File

@@ -24,7 +24,7 @@ import numpy as np
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import load_wav
app = FastAPI()
@@ -88,14 +88,8 @@ if __name__ == '__main__':
default=50000)
parser.add_argument('--model_dir',
type=str,
default='iic/CosyVoice-300M',
default='iic/CosyVoice2-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
try:
cosyvoice = CosyVoice(args.model_dir)
except Exception:
try:
cosyvoice = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
cosyvoice = AutoModel(model_dir=args.model_dir)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@@ -25,7 +25,7 @@ import numpy as np
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
@@ -33,13 +33,7 @@ logging.basicConfig(level=logging.DEBUG,
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
def __init__(self, args):
try:
self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc)
except Exception:
try:
self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc)
except Exception:
raise TypeError('no valid model_type!')
self.cosyvoice = AutoModel(model_dir=args.model_dir)
logging.info('grpc service initialized')
def Inference(self, request, context):
@@ -90,7 +84,7 @@ if __name__ == '__main__':
default=4)
parser.add_argument('--model_dir',
type=str,
default='iic/CosyVoice-300M',
default='iic/CosyVoice2-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
main()

View File

@@ -0,0 +1,8 @@
FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
LABEL maintainer="zhangyuekai@foxmail.com"
RUN apt-get update && apt-get install -y cmake
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
COPY ./requirements.txt /workspace/requirements.txt
RUN pip install -r /workspace/requirements.txt
WORKDIR /workspace

View File

@@ -0,0 +1,141 @@
## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
Contributed by Yuekai Zhang (NVIDIA).
This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose -f docker-compose.dit.yml up
```
### Build the Docker Image
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
### Run a Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
```
### Understanding `run_stepaudio2_dit_token2wav.sh`
The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
You can run a subset of stages with:
```sh
bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
```
- `<start_stage>`: The stage to start from.
- `<stop_stage>`: The stage to stop after.
**Stages:**
- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM.
- **Stage 4**: Runs the gRPC benchmark client for performance testing.
- **Stage 5**: Runs the offline TTS inference benchmark test.
- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
### Export Models and Launch Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# This command runs stages 0, 1, 2, and 3
bash run_stepaudio2_dit_token2wav.sh 0 3
```
### Benchmark with client-server mode
To benchmark the running Triton server, run stage 4:
```sh
bash run_stepaudio2_dit_token2wav.sh 4 4
# You can customize parameters such as the number of tasks inside the script.
```
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
#### Total Request Latency
| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 |
| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 |
| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 |
| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 |
| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 |
| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 |
#### First Chunk Latency
| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 |
| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 |
| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 |
| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 |
| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 |
| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 |
### Benchmark with offline inference mode
For offline inference mode benchmark, please run stage 5:
```sh
bash run_stepaudio2_dit_token2wav.sh 5 5
```
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|-----------------------|--|
| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 |
### Disaggregated Server
When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|
| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
The `concurrent_task_per_gpu` is calculated as:
`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
### Acknowledgements
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@@ -0,0 +1,146 @@
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
Contributed by Yuekai Zhang (NVIDIA).
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose up
```
### Build the Docker Image
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
### Run a Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
```
### Understanding `run.sh`
The `run.sh` script orchestrates the entire workflow through numbered stages.
You can run a subset of stages with:
```sh
bash run.sh <start_stage> <stop_stage> [service_type]
```
- `<start_stage>`: The stage to start from (0-5).
- `<stop_stage>`: The stage to stop after (0-5).
**Stages:**
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
- **Stage 3**: Launches the Triton Inference Server.
- **Stage 4**: Runs the single-utterance HTTP client for testing.
- **Stage 5**: Runs the gRPC benchmark client.
- **Stage 6**: Runs the offline inference benchmark test.
### Export Models and Launch Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# This command runs stages 0, 1, 2, and 3
bash run.sh 0 3
```
> [!TIP]
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
### Single-Utterance HTTP Client
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
```sh
bash run.sh 4 4
```
### Benchmark with client-server mode
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
```sh
bash run.sh 5 5 # [streaming|offline]
# You can also customize parameters such as the number of tasks and the dataset split:
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
```
> [!TIP]
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
### Benchmark with offline inference mode
For offline inference mode benchmark, please check the below command:
```sh
# install FlashCosyVoice for token2wav batching
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
# cd /workspace/FlashCosyVoice
# pip install -e .
# cd -
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
bash run.sh 6 6
# You can also switch to huggingface backend by setting backend=hf
```
### Benchmark Results
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
**Client-Server Mode: Offline TTS (Full Sentence Latency)**
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|---|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM**
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|-----------------------|--|
| HF | 1 | 39.26 | 44.31 | 0.2494 |
| HF | 2 | 30.54 | 35.62 | 0.2064 |
| HF | 4 | 18.63 | 23.90 | 0.1421 |
| HF | 8 | 11.22 | 16.45 | 0.0947 |
| HF | 16 | 8.42 | 13.78 | 0.0821 |
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
### OpenAI-Compatible Server
To launch an OpenAI-compatible API service, run the following commands:
```sh
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
cd Triton-OpenAI-Speech
pip install -r requirements.txt
# After the Triton service is running, start the FastAPI bridge:
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
# Test the service with curl:
bash test/test_cosyvoice.sh
```
> [!NOTE]
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
### Acknowledgements
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@@ -0,0 +1,922 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script supports to load dataset from huggingface and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import queue
import uuid
import functools
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio
import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
self._first_chunk_time = None
self._second_chunk_time = None
self._start_time = None
def record_start_time(self):
self._start_time = time.time()
def get_first_chunk_latency(self):
if self._first_chunk_time and self._start_time:
return self._first_chunk_time - self._start_time
return None
def get_second_chunk_latency(self):
if self._first_chunk_time and self._second_chunk_time:
return self._second_chunk_time - self._first_chunk_time
return None
def callback(user_data, result, error):
if not error:
if user_data._first_chunk_time is None:
user_data._first_chunk_time = time.time()
elif user_data._second_chunk_time is None:
user_data._second_chunk_time = time.time()
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)
def stream_callback(user_data_map, result, error):
request_id = None
if error:
print(f"An error occurred in the stream callback: {error}")
else:
request_id = result.get_response().id
if request_id:
user_data = user_data_map.get(request_id)
if user_data:
callback(user_data, result, error)
else:
print(f"Warning: Could not find user_data for request_id {request_id}")
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, "
f"compute infer time {total_infer_time_s:<5.2f} s, "
f"compute input time {total_input_time_s:<5.2f} s, "
f"compute output time {total_output_time_s:<5.2f} s \n"
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
if batch_count == 0:
continue
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
f"{compute_infer_time_ms / batch_count:.2f} ms, "
f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
)
def subtract_stats(stats_after, stats_before):
"""Subtracts two Triton inference statistics objects."""
stats_diff = json.loads(json.dumps(stats_after))
model_stats_before_map = {
s["name"]: {
"version": s["version"],
"last_inference": s.get("last_inference", 0),
"inference_count": s.get("inference_count", 0),
"execution_count": s.get("execution_count", 0),
"inference_stats": s.get("inference_stats", {}),
"batch_stats": s.get("batch_stats", []),
}
for s in stats_before["model_stats"]
}
for model_stat_after in stats_diff["model_stats"]:
model_name = model_stat_after["name"]
if model_name in model_stats_before_map:
model_stat_before = model_stats_before_map[model_name]
model_stat_after["inference_count"] = str(
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
)
model_stat_after["execution_count"] = str(
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
)
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
if "ns" in model_stat_after["inference_stats"][key]:
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
if "count" in model_stat_after["inference_stats"][key]:
count_after = int(model_stat_after["inference_stats"][key]["count"])
count_before = int(model_stat_before["inference_stats"][key]["count"])
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
for batch_stat_after in model_stat_after["batch_stats"]:
bs = batch_stat_after["batch_size"]
if bs in batch_stats_before_map:
batch_stat_before = batch_stats_before_map[bs]
for key in ["compute_input", "compute_infer", "compute_output"]:
if key in batch_stat_after and key in batch_stat_before:
count_after = int(batch_stat_after[key]["count"])
count_before = int(batch_stat_before[key]["count"])
batch_stat_after[key]["count"] = str(count_after - count_before)
ns_after = int(batch_stat_after[key]["ns"])
ns_before = int(batch_stat_before[key]["ns"])
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
return stats_diff
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Grpc port of the triton server, default is 8001",
)
parser.add_argument(
"--reference-audio",
type=str,
default=None,
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--huggingface-dataset",
type=str,
default="yuekai/seed_tts",
help="dataset name in huggingface dataset hub",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="dataset split name, default is 'test'",
)
parser.add_argument(
"--manifest-path",
type=str,
default=None,
help="Path to the manifest dir which includes wav.scp trans.txt files.",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=[
"f5_tts",
"spark_tts",
"cosyvoice2",
"cosyvoice2_dit"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--num-tasks",
type=int,
default=1,
help="Number of concurrent tasks for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-wer",
action="store_true",
default=False,
help="""True to compute WER.
""",
)
parser.add_argument(
"--log-dir",
type=str,
required=False,
default="./tmp",
help="log directory",
)
parser.add_argument(
"--mode",
type=str,
default="offline",
choices=["offline", "streaming"],
help="Select offline or streaming benchmark mode."
)
parser.add_argument(
"--chunk-overlap-duration",
type=float,
default=0.1,
help="Chunk overlap duration for streaming reconstruction (in seconds)."
)
parser.add_argument(
"--use-spk2info-cache",
type=str,
default="False",
help="Use spk2info cache for reference audio.",
)
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate
def prepare_request_input_output(
protocol_client,
waveform,
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None,
use_spk2info_cache: bool = False
):
"""Prepares inputs for Triton inference (offline or streaming)."""
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
if padding_duration:
duration = len(waveform) / sample_rate
if reference_text:
estimated_target_duration = duration / len(reference_text) * len(target_text)
else:
estimated_target_duration = duration
required_total_samples = padding_duration * sample_rate * (
(int(estimated_target_duration + duration) // padding_duration) + 1
)
samples = np.zeros((1, required_total_samples), dtype=np.float32)
samples[0, : len(waveform)] = waveform
else:
samples = waveform.reshape(1, -1).astype(np.float32)
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput(
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
input_data_numpy = np.array([reference_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[2].set_data_from_numpy(input_data_numpy)
input_data_numpy = np.array([target_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
if use_spk2info_cache:
inputs = inputs[-1:]
return inputs, outputs
def run_sync_streaming_inference(
sync_triton_client: tritonclient.grpc.InferenceServerClient,
model_name: str,
inputs: list,
outputs: list,
request_id: str,
user_data: UserData,
chunk_overlap_duration: float,
save_sample_rate: int,
audio_save_path: str,
):
"""Helper function to run the blocking sync streaming call."""
start_time_total = time.time()
user_data.record_start_time()
sync_triton_client.async_stream_infer(
model_name,
inputs,
request_id=request_id,
outputs=outputs,
enable_empty_final_response=True,
)
audios = []
while True:
try:
result = user_data._completed_requests.get(timeout=200)
if isinstance(result, InferenceServerException):
print(f"Received InferenceServerException: {result}")
return None, None, None, None
response = result.get_response()
final = response.parameters["triton_final_response"].bool_param
if final is True:
break
audio_chunk = result.as_numpy("waveform").reshape(-1)
if audio_chunk.size > 0:
audios.append(audio_chunk)
else:
print("Warning: received empty audio chunk.")
except queue.Empty:
print(f"Timeout waiting for response for request id {request_id}")
return None, None, None, None
end_time_total = time.time()
total_request_latency = end_time_total - start_time_total
first_chunk_latency = user_data.get_first_chunk_latency()
second_chunk_latency = user_data.get_second_chunk_latency()
if audios:
if model_name == "spark_tts":
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
reconstructed_audio = None
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32)
elif len(audios) == 1:
reconstructed_audio = audios[0]
else:
reconstructed_audio = audios[0][:-cross_fade_samples]
for i in range(1, len(audios)):
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
if reconstructed_audio is not None and reconstructed_audio.size > 0:
actual_duration = len(reconstructed_audio) / save_sample_rate
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0
else:
reconstructed_audio = np.concatenate(audios)
actual_duration = len(reconstructed_audio) / save_sample_rate
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received.")
actual_duration = 0
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
async def send_streaming(
manifest_item_list: list,
name: str,
server_url: str,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
chunk_overlap_duration: float = 0.1,
padding_duration: int = None,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
sync_triton_client = None
user_data_map = {}
try:
print(f"{name}: Initializing sync client for streaming...")
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
try:
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
reference_text, target_text = item["reference_text"], item["target_text"]
inputs, outputs = prepare_request_input_output(
protocol_client,
waveform,
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
request_id = str(uuid.uuid4())
user_data = UserData()
user_data_map[request_id] = user_data
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
run_sync_streaming_inference,
sync_triton_client,
model_name,
inputs,
outputs,
request_id,
user_data,
chunk_overlap_duration,
save_sample_rate,
audio_save_path
)
if total_request_latency is not None:
print(
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
)
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
total_duration += actual_duration
else:
print(f"{name}: Item {i} failed.")
del user_data_map[request_id]
except FileNotFoundError:
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
except Exception as e:
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
import traceback
traceback.print_exc()
finally:
if sync_triton_client:
try:
print(f"{name}: Closing stream and sync client...")
sync_triton_client.stop_stream()
sync_triton_client.close()
except Exception as e:
print(f"{name}: Error closing sync client: {e}")
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
return total_duration, latency_data
async def send(
manifest_item_list: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
reference_text, target_text = item["reference_text"], item["target_text"]
inputs, outputs = prepare_request_input_output(
protocol_client,
waveform,
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
sequence_id = 100000000 + i + task_id * 10
start = time.time()
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
audio = response.as_numpy("waveform").reshape(-1)
actual_duration = len(audio) / save_sample_rate
end = time.time() - start
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, actual_duration))
total_duration += actual_duration
return total_duration, latency_data
def load_manifests(manifest_path):
with open(manifest_path, "r") as f:
manifest_list = []
for line in f:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
{
"audio_filepath": prompt_wav,
"reference_text": prompt_text,
"target_text": gt_text,
"target_audio_path": utt,
}
)
return manifest_list
def split_data(data, k):
n = len(data)
if n < k:
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
k = n
quotient = n // k
remainder = n % k
result = []
start = 0
for i in range(k):
if i < remainder:
end = start + quotient + 1
else:
end = start + quotient
result.append(data[start:end])
start = end
return result
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
triton_client = None
protocol_client = None
if args.mode == "offline":
print("Initializing gRPC client for offline mode...")
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient_aio
elif args.mode == "streaming":
print("Initializing gRPC client for streaming mode...")
protocol_client = grpcclient_sync
else:
raise ValueError(f"Invalid mode: {args.mode}")
if args.reference_audio:
args.num_tasks = 1
args.log_interval = 1
manifest_item_list = [
{
"reference_text": args.reference_text,
"target_text": args.target_text,
"audio_filepath": args.reference_audio,
"target_audio_path": "test",
}
]
elif args.huggingface_dataset:
import datasets
dataset = datasets.load_dataset(
args.huggingface_dataset,
split=args.split_name,
trust_remote_code=True,
)
manifest_item_list = []
for i in range(len(dataset)):
manifest_item_list.append(
{
"audio_filepath": dataset[i]["prompt_audio"],
"reference_text": dataset[i]["prompt_text"],
"target_audio_path": dataset[i]["id"],
"target_text": dataset[i]["target_text"],
}
)
else:
manifest_item_list = load_manifests(args.manifest_path)
stats_client = None
stats_before = None
try:
print("Initializing temporary async client for fetching stats...")
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
print("Fetching inference statistics before running tasks...")
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
except Exception as e:
print(f"Could not retrieve statistics before running tasks: {e}")
num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
tasks = []
start_time = time.time()
for i in range(num_tasks):
if args.mode == "offline":
task = asyncio.create_task(
send(
manifest_item_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
use_spk2info_cache=args.use_spk2info_cache,
)
)
elif args.mode == "streaming":
task = asyncio.create_task(
send_streaming(
manifest_item_list[i],
name=f"task-{i}",
server_url=url,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=10,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
chunk_overlap_duration=args.chunk_overlap_duration,
use_spk2info_cache=args.use_spk2info_cache,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
total_duration = 0.0
latency_data = []
for ans in ans_list:
if ans:
total_duration += ans[0]
latency_data.extend(ans[1])
else:
print("Warning: A task returned None, possibly due to an error.")
if total_duration == 0:
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
rtf = float('inf')
else:
rtf = elapsed / total_duration
s = f"Mode: {args.mode}\n"
s += f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
if latency_data:
if args.mode == "offline":
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
if latency_list:
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
else:
s += "No latency data collected for offline mode.\n"
elif args.mode == "streaming":
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
s += "\n--- Total Request Latency ---\n"
if total_latency_list:
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
else:
s += "No total request latency data collected.\n"
s += "\n--- First Chunk Latency ---\n"
if first_chunk_latency_list:
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
else:
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
s += "\n--- Second Chunk Latency ---\n"
if second_chunk_latency_list:
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
else:
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
else:
s += "No latency data collected.\n"
print(s)
if args.manifest_path:
name = Path(args.manifest_path).stem
elif args.split_name:
name = args.split_name
elif args.reference_audio:
name = Path(args.reference_audio).stem
else:
name = "results"
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
try:
if stats_client and stats_before:
print("Fetching inference statistics after running tasks...")
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
print("Calculating statistics difference...")
stats = subtract_stats(stats_after, stats_before)
print("Fetching model config...")
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
else:
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
except Exception as e:
print(f"Could not retrieve statistics or config: {e}")
finally:
if stats_client:
try:
print("Closing temporary async stats client...")
await stats_client.close()
except Exception as e:
print(f"Error closing async stats client: {e}")
if __name__ == "__main__":
async def run_main():
try:
await main()
except Exception as e:
print(f"An error occurred in main: {e}")
import traceback
traceback.print_exc()
asyncio.run(run_main())

View File

@@ -0,0 +1,172 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import numpy as np
import argparse
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--server-url",
type=str,
default="localhost:8000",
help="Address of the server",
)
parser.add_argument(
"--reference-audio",
type=str,
default="../../example/prompt_audio.wav",
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
help="",
)
parser.add_argument(
"--model-name",
type=str,
default="spark_tts",
choices=[
"f5_tts",
"spark_tts",
"cosyvoice2"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
waveform,
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None,
audio_save_dir: str = "./",
):
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
if padding_duration:
# padding to nearset 10 seconds
samples = np.zeros(
(
1,
padding_duration
* sample_rate
* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{
"name": "reference_wav",
"shape": samples.shape,
"datatype": "FP32",
"data": samples.tolist()
},
{
"name": "reference_wav_len",
"shape": lengths.shape,
"datatype": "INT32",
"data": lengths.tolist(),
},
{
"name": "reference_text",
"shape": [1, 1],
"datatype": "BYTES",
"data": [reference_text]
},
{
"name": "target_text",
"shape": [1, 1],
"datatype": "BYTES",
"data": [target_text]
}
]
}
return data
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
waveform, sr = sf.read(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(waveform, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
rsp = requests.post(
url,
headers={"Content-Type": "application/json"},
json=data,
verify=False,
params={"request_id": '0'}
)
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
if args.model_name == "spark_tts":
sample_rate = 16000
else:
sample_rate = 24000
sf.write(args.output_audio, audio, sample_rate, "PCM_16")

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-cosyvoice:25.06
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-cosyvoice:25.06
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"

View File

@@ -0,0 +1,97 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import s3tokenizer
torch.set_num_threads(1)
ORIGINAL_VOCAB_SIZE = 151663
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
mels = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array).to(self.device)
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
mels.append(s3tokenizer.log_mel_spectrogram(wav))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
codes = codes.clone() + ORIGINAL_VOCAB_SIZE
responses = []
for i in range(len(requests)):
prompt_speech_tokens = codes[i, :codes_lens[i].item()]
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
"prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_speech_tokens_tensor])
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "audio_tokenizer"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
}
]
output [
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,454 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import threading
import time
from uuid import uuid4
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
def forward_llm(self, input_ids):
"""
Prepares the response from the language model based on the provided
inputs. Creates a `pb_utils.InferenceRequest` object with passed
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
For each response from the language model:
- Checks for errors and raise an exception if any are found.
- Extracts the "output_ids" tensor from the response.
- Determines the finish reason based on the presence of the
end-of-sequence token or reaching the maximum length.
- Appends the generated token IDs to `output_ids`.
- If the finish reason is determined, decodes the output IDs to text
and prepares the final response.
The final response includes the generated text, finish reason,
completion tokens, prompt tokens, and total tokens.
Parameters
----------
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
Returns
-------
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 750
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
"random_seed": np.array([[42]], dtype=np.uint64),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
# Convert inputs to Triton tensors
input_tensor_list = [
pb_utils.Tensor(k, v) for k, v in input_dict.items()
]
# Create and execute inference request
llm_request = pb_utils.InferenceRequest(
model_name="tensorrt_llm",
requested_output_names=["output_ids", "sequence_length"],
inputs=input_tensor_list,
)
llm_responses = llm_request.exec(decoupled=self.decoupled)
if self.decoupled:
for llm_response in llm_responses:
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
else:
llm_response = llm_responses
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
def forward_token2wav(
self,
target_speech_tokens: torch.Tensor,
request_id: str,
prompt_speech_tokens: torch.Tensor = None,
prompt_speech_feat: torch.Tensor = None,
prompt_spk_embedding: torch.Tensor = None,
token_offset: int = None,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
inputs_tensor = [target_speech_tokens_tensor]
if token_offset is not None:
assert finalize is not None
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor.append(token_offset_tensor)
inputs_tensor.append(finalize_tensor)
if prompt_spk_embedding is not None:
assert prompt_speech_feat is not None
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
requested_output_names=['waveform'],
inputs=inputs_tensor,
request_id=request_id,
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def parse_input(self, text, prompt_text, prompt_speech_tokens):
total_text = f"{prompt_text}{text}"
prompt = self.prompt_template.format(input_text=total_text)
input_ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([input_ids], dtype=torch.int32)
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
return input_ids
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
for generated_ids in generated_ids_iter:
generated_ids = generated_ids.tolist()
if len(generated_ids) == 0:
break
semantic_token_ids_arr.extend(generated_ids)
llm_is_done_flag[0] = True
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
responses = []
for request in requests:
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
else:
# using pre-cached reference text
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
# Prepare prompt for LLM
input_ids = self.parse_input(
text=target_text,
prompt_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
)
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
token2wav_request_id = request_id or str(uuid4())
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
llm_is_done_flag = [False]
llm_thread = threading.Thread(
target=self._llm_gen_thread,
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
)
llm_thread.start()
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if llm_is_done_flag[0]:
break
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
prompt_speech_feat, prompt_spk_embedding, token_offset, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
time.sleep(0.02)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
llm_thread.join()
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
generated_ids = next(generated_ids_iter)
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
responses.append(inference_response)
if not self.decoupled:
return responses

View File

@@ -0,0 +1,73 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
parameters [
{
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: ${bls_instance_num}
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,394 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
self.http_client = httpx.AsyncClient()
self.api_base = "http://localhost:8000/v1/chat/completions"
self.speaker_cache = {}
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
buffer = ""
async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
async def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
index: Index of the request
target_speech_tokens: Target speech tokens tensor
request_id: Request ID
reference_wav: Reference waveform tensor
reference_wav_len: Reference waveform length tensor
finalize: Whether to finalize the request
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=[
"waveform",
],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index + 1},
)
inference_response = await inference_request.async_exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
async def _process_request(self, request):
request_id = request.request_id()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
if reference_text not in self.speaker_cache:
self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
prompt_speech_tokens = self.speaker_cache[reference_text]
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "equal":
this_token_hop_len = self.token_hop_len
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
else:
raise NotImplementedError("Offline TTS mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None
def finalize(self):
self.logger.log_info("Finalizing CosyVoice DIT model")
if hasattr(self, "http_client"):
asyncio.run(self.http_client.aclose())

View File

@@ -0,0 +1,73 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
parameters [
{
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: ${bls_instance_num}
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,153 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.utils.common import TrtContextWrapper
import onnxruntime
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_dir = model_params["model_dir"]
gpu = "l20"
enable_trt = True
if enable_trt:
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
else:
campplus_model = f'{model_dir}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
spk_feat = feat - feat.mean(dim=0, keepdim=True)
if isinstance(self.spk_model, onnxruntime.InferenceSession):
embedding = self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
embedding = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
embedding.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return embedding.half()
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
responses = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_array = torch.from_numpy(wav_array).to(self.device)
embedding = self._extract_spk_embedding(wav_array)
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
"prompt_spk_embedding", to_dlpack(embedding))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_spk_embedding_tensor])
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "speaker_embedding"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
}
]
output [
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,857 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
name: "tensorrt_llm"
backend: "${triton_backend}"
max_batch_size: ${triton_max_batch_size}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
dynamic_batching {
preferred_batch_size: [ ${triton_max_batch_size} ]
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
default_queue_policy: { max_queue_size: ${max_queue_size} }
}
input [
{
name: "input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_input_features"
data_type: ${encoder_input_features_data_type}
dims: [ -1, -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_output_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "request_output_len"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "num_return_sequences"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "draft_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
reshape: { shape: [ ] }
},
{
name: "draft_logits"
data_type: ${logits_datatype}
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "draft_acceptance_threshold"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "end_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "pad_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "bad_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "embedding_bias"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "beam_width"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_k"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_min"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_decay"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_reset_ids"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "len_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "early_stopping"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "repetition_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "min_length"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "beam_search_diversity_rate"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "presence_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "frequency_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "random_seed"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_log_probs"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_context_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_generation_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_perf_metrics"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "exclude_input_in_output"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "streaming"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "prompt_embedding_table"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_table_extra_ids"
data_type: TYPE_UINT64
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_vocab_size"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
{
name: "cross_attention_mask"
data_type: TYPE_BOOL
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# Mrope param when mrope is used
{
name: "mrope_rotary_cos_sin"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
},
{
name: "mrope_position_deltas"
data_type: TYPE_INT64
dims: [ 1 ]
optional: true
},
# the unique task ID for the given LoRA.
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
{
name: "lora_task_id"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
# each of the in / out tensors are first flattened and then concatenated together in the format above.
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
{
name: "lora_weights"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# module identifier (same size a first dimension of lora_weights)
# See LoraModule::ModuleType for model id mapping
#
# "attn_qkv": 0 # compbined qkv adapter
# "attn_q": 1 # q adapter
# "attn_k": 2 # k adapter
# "attn_v": 3 # v adapter
# "attn_dense": 4 # adapter for the dense layer in attention
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
#
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
{
name: "lora_config"
data_type: TYPE_INT32
dims: [ -1, 3 ]
optional: true
allow_ragged_batch: true
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
{
name: "skip_cross_attn_blocks"
data_type: TYPE_BOOL
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_starts"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_ends"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_priorities"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_durations_ms"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_decode_priority"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_decode_duration_ms"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "guided_decoding_guide_type"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "guided_decoding_guide"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_window_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_ngram_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_verification_set_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
}
]
output [
{
name: "output_ids"
data_type: TYPE_INT32
dims: [ -1, -1 ]
},
{
name: "sequence_length"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "cum_log_probs"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "output_log_probs"
data_type: TYPE_FP32
dims: [ -1, -1 ]
},
{
name: "context_logits"
data_type: ${logits_datatype}
dims: [ -1, -1 ]
},
{
name: "generation_logits"
data_type: ${logits_datatype}
dims: [ -1, -1, -1 ]
},
{
name: "batch_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "sequence_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
},
{
name: "kv_cache_alloc_new_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_reused_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_alloc_total_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "arrival_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "first_scheduled_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "first_token_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "last_token_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "acceptance_rate"
data_type: TYPE_FP32
dims: [ 1 ]
},
{
name: "total_accepted_draft_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "total_draft_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
}
]
instance_group [
{
count: 1
kind : KIND_CPU
}
]
parameters: {
key: "max_beam_width"
value: {
string_value: "${max_beam_width}"
}
}
parameters: {
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
value: {
string_value: "no"
}
}
parameters: {
key: "gpt_model_type"
value: {
string_value: "${batching_strategy}"
}
}
parameters: {
key: "gpt_model_path"
value: {
string_value: "${engine_dir}"
}
}
parameters: {
key: "encoder_model_path"
value: {
string_value: "${encoder_engine_dir}"
}
}
parameters: {
key: "max_tokens_in_paged_kv_cache"
value: {
string_value: "${max_tokens_in_paged_kv_cache}"
}
}
parameters: {
key: "max_attention_window_size"
value: {
string_value: "${max_attention_window_size}"
}
}
parameters: {
key: "sink_token_length"
value: {
string_value: "${sink_token_length}"
}
}
parameters: {
key: "batch_scheduler_policy"
value: {
string_value: "${batch_scheduler_policy}"
}
}
parameters: {
key: "kv_cache_free_gpu_mem_fraction"
value: {
string_value: "${kv_cache_free_gpu_mem_fraction}"
}
}
parameters: {
key: "cross_kv_cache_fraction"
value: {
string_value: "${cross_kv_cache_fraction}"
}
}
parameters: {
key: "kv_cache_host_memory_bytes"
value: {
string_value: "${kv_cache_host_memory_bytes}"
}
}
# kv_cache_onboard_blocks is for internal implementation.
parameters: {
key: "kv_cache_onboard_blocks"
value: {
string_value: "${kv_cache_onboard_blocks}"
}
}
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
# parameters: {
# key: "enable_trt_overlap"
# value: {
# string_value: "${enable_trt_overlap}"
# }
# }
parameters: {
key: "exclude_input_in_output"
value: {
string_value: "${exclude_input_in_output}"
}
}
parameters: {
key: "cancellation_check_period_ms"
value: {
string_value: "${cancellation_check_period_ms}"
}
}
parameters: {
key: "stats_check_period_ms"
value: {
string_value: "${stats_check_period_ms}"
}
}
parameters: {
key: "iter_stats_max_iterations"
value: {
string_value: "${iter_stats_max_iterations}"
}
}
parameters: {
key: "request_stats_max_iterations"
value: {
string_value: "${request_stats_max_iterations}"
}
}
parameters: {
key: "enable_kv_cache_reuse"
value: {
string_value: "${enable_kv_cache_reuse}"
}
}
parameters: {
key: "normalize_log_probs"
value: {
string_value: "${normalize_log_probs}"
}
}
parameters: {
key: "enable_chunked_context"
value: {
string_value: "${enable_chunked_context}"
}
}
parameters: {
key: "gpu_device_ids"
value: {
string_value: "${gpu_device_ids}"
}
}
parameters: {
key: "participant_ids"
value: {
string_value: "${participant_ids}"
}
}
parameters: {
key: "lora_cache_optimal_adapter_size"
value: {
string_value: "${lora_cache_optimal_adapter_size}"
}
}
parameters: {
key: "lora_cache_max_adapter_size"
value: {
string_value: "${lora_cache_max_adapter_size}"
}
}
parameters: {
key: "lora_cache_gpu_memory_fraction"
value: {
string_value: "${lora_cache_gpu_memory_fraction}"
}
}
parameters: {
key: "lora_cache_host_memory_bytes"
value: {
string_value: "${lora_cache_host_memory_bytes}"
}
}
parameters: {
key: "lora_prefetch_dir"
value: {
string_value: "${lora_prefetch_dir}"
}
}
parameters: {
key: "decoding_mode"
value: {
string_value: "${decoding_mode}"
}
}
parameters: {
key: "executor_worker_path"
value: {
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
}
}
parameters: {
key: "lookahead_window_size"
value: {
string_value: "${lookahead_window_size}"
}
}
parameters: {
key: "lookahead_ngram_size"
value: {
string_value: "${lookahead_ngram_size}"
}
}
parameters: {
key: "lookahead_verification_set_size"
value: {
string_value: "${lookahead_verification_set_size}"
}
}
parameters: {
key: "medusa_choices"
value: {
string_value: "${medusa_choices}"
}
}
parameters: {
key: "eagle_choices"
value: {
string_value: "${eagle_choices}"
}
}
parameters: {
key: "gpu_weights_percent"
value: {
string_value: "${gpu_weights_percent}"
}
}
parameters: {
key: "enable_context_fmha_fp32_acc"
value: {
string_value: "${enable_context_fmha_fp32_acc}"
}
}
parameters: {
key: "multi_block_mode"
value: {
string_value: "${multi_block_mode}"
}
}
parameters: {
key: "cuda_graph_mode"
value: {
string_value: "${cuda_graph_mode}"
}
}
parameters: {
key: "cuda_graph_cache_size"
value: {
string_value: "${cuda_graph_cache_size}"
}
}
parameters: {
key: "speculative_decoding_fast_logits"
value: {
string_value: "${speculative_decoding_fast_logits}"
}
}
parameters: {
key: "tokenizer_dir"
value: {
string_value: "${tokenizer_dir}"
}
}
parameters: {
key: "guided_decoding_backend"
value: {
string_value: "${guided_decoding_backend}"
}
}
parameters: {
key: "xgrammar_tokenizer_info_path"
value: {
string_value: "${xgrammar_tokenizer_info_path}"
}
}

View File

@@ -0,0 +1,277 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
self.model_dir = model_dir
self.fp16 = fp16
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.flow.half()
# streaming tts config
self.token_hop_len = 25
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
self.speech_window = np.hamming(2 * self.source_cache_len)
self.hift_cache_dict = defaultdict(lambda: None)
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load(self, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)
spk_info_path = os.path.join(model_dir, "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_dir}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
if prompt_speech_tokens_tensor is not None:
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
else:
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
# shift the speech tokens according to the original vocab size
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
if token_offset is not None:
token_offset = token_offset.as_numpy().item()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
if not finalize:
stream = True
else:
stream = False
request_id = request.request_id()
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
embedding=prompt_spk_embedding,
token_offset=token_offset,
uuid=request_id,
stream=stream,
finalize=finalize)
if finalize:
self.token2wav_model.model.hift_cache_dict.pop(request_id)
else:
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
generated_wave = audio_hat.squeeze(0).cpu().numpy()
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "token2wav"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "target_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
optional: true
},
{
name: "prompt_speech_feat"
data_type: TYPE_FP16
dims: [-1, 80]
optional: true
},
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
optional: true
},
{
name: "token_offset"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "finalize"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,142 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
from .token2wav_dit import CosyVoice2_Token2Wav
import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
Tensors with the same elements and properties will have the same ID.
"""
# Convert tensor to a byte string
tensor_bytes = tensor.numpy().tobytes()
# Create a SHA-256 hash of the byte string
hasher = hashlib.sha256()
hasher.update(tensor_bytes)
return hasher.hexdigest()
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
# FIXME: device id settings
self.token2wav_model = CosyVoice2_Token2Wav(
model_dir, enable_trt=True, streaming=True
)
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens.squeeze().tolist()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
request_id = request.request_id()
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array)
wav = wav_array[:, :wav_len].squeeze(0)
spk_id = get_spk_id_from_prompt_audio(wav)
audio_hat = self.token2wav_model.forward_streaming(
target_speech_tokens, finalize, request_id=request_id,
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
)
outputs = []
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
outputs.append(wav_tensor)
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
responses.append(inference_response)
return responses

View File

@@ -0,0 +1,510 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 token2wav.py --enable-trt || exit 1
"""
import torch
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time
import numpy as np
from hyperpyyaml import load_hyperpyyaml
def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel = fade_in_mel.clone()
fade_in_mel[..., :mel_overlap_len] = \
fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if dtype == torch.float16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes
for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
if dtype == torch.float16:
tensor_dtype = trt.DataType.HALF
elif dtype == torch.bfloat16:
tensor_dtype = trt.DataType.BF16
elif dtype == torch.float32:
tensor_dtype = trt.DataType.FLOAT
else:
raise ValueError('invalid dtype {}'.format(dtype))
# set input and output data type
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_tensor.dtype = tensor_dtype
for i in range(network.num_outputs):
output_tensor = network.get_output(i)
output_tensor.dtype = tensor_dtype
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
# save trt engine
with open(trt_model, "wb") as f:
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
self.device = device
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()
self.device_id = device_id
self.device = f"cuda:{device_id}"
with open(f"{model_dir}/flow.yaml", "r") as f:
configs = load_hyperpyyaml(f)
self.flow = configs['flow']
self.dtype = dtype
self.flow.to(self.dtype)
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
self.flow.to(self.device).eval()
self.hift = HiFTGenerator()
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(
f"{model_dir}/campplus.onnx", sess_options=option,
providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
gpu = "l20"
if enable_trt:
if streaming:
self.load_trt(
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
1,
self.dtype, streaming
)
else:
self.load_trt(
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
self.dtype
)
self.load_spk_trt(
f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False
)
self.streaming_flow_cache = {}
self.speaker_cache = {}
self.mel_cache_len = 8 # hard-coded, 160ms
self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave
self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
# hifigan cache for streaming tts
self.hift_cache_dict = {}
def forward_spk_embedding(self, spk_feat):
if isinstance(self.spk_model, onnxruntime.InferenceSession):
return self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device_id):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
output_tensor.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return output_tensor.cpu().numpy().flatten().tolist()
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
opt_batch_size = 2
max_batch_size = 16
if streaming:
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
opt_shape = [
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
(opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
(16, opt_batch_size * 2, 8, 100, 128)
]
max_shape = [
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
(max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
(16, max_batch_size * 2, 8, 1000, 128)
]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
opt_shape = [
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
]
max_shape = [
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
(max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
prompt_speech_mels_list.append(log_mel)
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
)
for i in range(len(prompt_speech_tokens)):
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
prompt_speech_tokens_list.append(speech_tokens_i)
return prompt_speech_tokens_list
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
spk_emb_for_flow = []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
spk_emb = self.forward_spk_embedding(spk_feat)
spk_emb_for_flow.append(spk_emb)
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
if self.dtype != torch.float32:
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
return spk_emb_for_flow
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
prompt_mels_for_flow = []
prompt_mels_lens_for_flow = []
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
assert len(audio.shape) == 1
audio = audio.unsqueeze(0)
if sample_rate != 24000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
prompt_mels_for_flow, batch_first=True, padding_value=0
) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
generated_speech_tokens_list: list[list[int]],
prompt_mels_for_flow: torch.Tensor,
prompt_mels_lens_for_flow: torch.Tensor,
spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
flow_inputs_lens = torch.tensor(flow_inputs_lens)
with torch.amp.autocast(self.device, dtype=torch.float16):
generated_mels, generated_mels_lens = self.flow.inference(
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
)
return generated_mels, generated_mels_lens
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
batch_size = generated_mels.shape[0]
generated_wavs = []
for i in range(batch_size):
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
wav, _ = self.hift(speech_feat=mel)
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
generated_mels, generated_mels_lens = self.forward_flow(
prompt_speech_tokens_list, generated_speech_tokens_list,
prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
def prepare_prompt_audio(
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
def get_prompt_audio_cache_for_streaming_tts(
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):
assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
cache = self.flow.setup_cache(
prompt_speech_tokens_tensor.to(self.device),
prompt_mels_for_flow.to(self.device),
spk_emb_for_flow.to(self.device),
n_timesteps=10
)
new_cache = {k: v.clone() for k, v in cache.items()}
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache
@torch.inference_mode()
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
):
if speaker_id not in self.speaker_cache:
assert prompt_audio is not None, "prompt_audio is required for new speaker"
assert prompt_audio_sample_rate == 16000
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
mel=torch.zeros(1, 80, 0, device='cuda'),
source=torch.zeros(1, 1, 0, device='cuda'),
speech=torch.zeros(1, 0, device='cuda'),
)
current_request_cache = self.streaming_flow_cache[request_id]
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
cache=current_request_cache,
last_chunk=last_chunk,
n_timesteps=10,
)
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
speech, source = self.hift(mel, hift_cache_source)
# overlap speech smooth
if hift_cache_speech.shape[-1] > 0:
speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
# update vocoder cache
self.hift_cache_dict[request_id] = dict(
mel=mel[..., -self.mel_cache_len:].clone().detach(),
source=source[:, :, -self.source_cache_len:].clone().detach(),
speech=speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
if last_chunk:
assert request_id in self.streaming_flow_cache
self.streaming_flow_cache.pop(request_id)
self.hift_cache_dict.pop(request_id)
return speech
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for item in batch:
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for _ in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "token2wav_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
priority_levels: 10
default_priority_level: 10
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "target_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
},
{
name: "finalize"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@@ -0,0 +1,652 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
import time
import requests
import asyncio
import httpx
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
async def send_request_async(client, url, payload):
response = await client.post(url, json=payload, timeout=None)
response.raise_for_status()
response_json = response.json()
return response_json['choices'][0]['message']['content']
async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
async with httpx.AsyncClient() as client:
tasks = []
for chat in chats:
payload = {
"model": model_name,
"messages": chat,
"max_tokens": 2048,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": False,
}
tasks.append(send_request_async(client, api_base, payload))
return await asyncio.gather(*tasks)
def extract_speech_ids(speech_tokens_str):
"""Extract speech IDs from token strings like <|s_23456|>"""
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
for token in cosy2_tokens:
speech_id_str += f"<|s_{token}|>"
return speech_id_str
def get_args():
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--token2wav-batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=0, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=None, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="LLM model path (includes both model and tokenizer)",
)
parser.add_argument(
"--token2wav-path",
required=True,
type=str,
help="CosyVoice2 token2wav model path",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The prompt text for CosyVoice2",
)
parser.add_argument(
"--prompt-speech-path",
type=str,
default=None,
help="The path to the prompt speech for CosyVoice2",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
parser.add_argument(
"--backend",
type=str,
default="hf",
choices=["hf", "trtllm", "vllm", "trtllm-serve"],
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
)
parser.add_argument(
"--engine-dir",
type=str,
default=None,
help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
)
parser.add_argument(
"--kv-cache-free-gpu-memory-fraction",
type=float,
default=0.6,
help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
)
parser.add_argument(
"--openai-api-base",
type=str,
default="http://localhost:8000/v1/chat/completions",
help="OpenAI API base URL (for trtllm-serve backend)",
)
parser.add_argument(
"--openai-model-name",
type=str,
default="trt_engines_bfloat16",
help="Model name to use with OpenAI API (for trtllm-serve backend)",
)
args = parser.parse_args()
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
collator_start_time = time.time()
total_audio_processing_time = 0
total_speech_tokenization_time = 0
total_text_tokenization_time = 0
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
prompt_text_after_apply_template_list = []
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
chat_list = []
for _, item in enumerate(batch):
audio_processing_start_time = time.time()
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
)
prompt_text_list.append(prompt_text)
full_text = prompt_text + target_text
full_text_list.append(full_text)
# remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
puncts = ['"', '(', ')', '', '', '', '', '', '\'']
for p in puncts:
if p in full_text:
full_text = full_text.replace(p, '')
print(f"removed {p} from {full_text}")
# get prompt audio for CosyVoice2 (convert to 16kHz)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
print(ref_audio_org.shape)
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
prompt_audio_list.append(ref_audio)
audio_processing_end_time = time.time()
total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
speech_tokenization_start_time = time.time()
if "prompt_audio_cosy2_tokens" in item:
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
else:
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
if len(mels) > 0:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
speech_tokenization_end_time = time.time()
total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
text_tokenization_start_time = time.time()
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
{"role": "user", "content": full_text_list[i]},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
chat_list.append(chat)
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids_list.append(input_ids.squeeze(0))
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
text_tokenization_end_time = time.time()
total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
ids = [item["id"] for item in batch]
return {
"input_ids": input_ids_list,
"ids": ids,
"prompt_text": prompt_text_list,
"prompt_audio_list": prompt_audio_list,
"prompt_text_after_apply_template": prompt_text_after_apply_template_list,
"audio_processing_time": total_audio_processing_time,
"speech_tokenization_time": total_speech_tokenization_time,
"text_tokenization_time": total_text_tokenization_time,
"chat_list": chat_list
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main(args):
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
local_rank, world_size, rank = 0, 1, 0
device = torch.device(f"cuda:{local_rank}")
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
if args.backend == "hf":
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
runner = None
elif args.backend == "trtllm":
if args.engine_dir is None:
raise ValueError("--engine-dir is required when backend is 'trtllm'")
runtime_rank = tensorrt_llm.mpi_rank()
model = None
runner_kwargs = dict(
engine_dir=args.engine_dir,
rank=runtime_rank,
max_output_len=2048,
enable_context_fmha_fp32_acc=False,
max_batch_size=args.batch_size,
max_input_len=512,
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
cuda_graph_mode=False,
gather_generation_logits=False,
)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
elif args.backend == "vllm":
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
runner = None
elif args.backend == "trtllm-serve":
model = None
runner = None
else:
raise ValueError(f"Unsupported backend: {args.backend}")
if 'Step-Audio-2-mini' in args.token2wav_path:
from token2wav_dit import CosyVoice2_Token2Wav
else:
assert 'CosyVoice2-0.5B' in args.token2wav_path
from token2wav import CosyVoice2_Token2Wav
token2wav_model = CosyVoice2_Token2Wav(
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
)
if args.prompt_speech_path:
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
else:
prompt_speech_16k = None
s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
)
for _ in range(3):
print(f"Running {_} times")
total_llm_time = 0
total_token2wav_time = 0
total_data_load_time = 0
total_llm_post_processing_time = 0
total_audio_save_time = 0
total_audio_processing_time_in_collator = 0
total_speech_tokenization_time_in_collator = 0
total_text_tokenization_time_in_collator = 0
total_audio_samples = 0
start_time = time.time()
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
last_batch_end_time = time.time()
for batch in dataloader:
data_loaded_time = time.time()
total_data_load_time += data_loaded_time - last_batch_end_time
total_audio_processing_time_in_collator += batch["audio_processing_time"]
total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
with torch.no_grad():
llm_start_time = time.time()
if args.backend == "hf":
input_ids_list = batch["input_ids"]
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
attention_mask = torch.ones_like(input_ids)
else:
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list_new = [
torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list_new)
attention_mask = torch.zeros_like(input_ids)
for i in range(len(input_ids_list)):
attention_mask[i, :len(input_ids_list[i])] = 1
input_ids = input_ids.to(device)
outputs = model.generate(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
max_new_tokens=2048,
do_sample=True,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=1.1,
top_k=args.top_k,
)
torch.cuda.synchronize()
elif args.backend == "trtllm":
batch_input_ids = list(batch["input_ids"])
input_lengths = [x.size(0) for x in batch_input_ids]
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
outputs = runner.generate(
batch_input_ids=batch_input_ids,
max_new_tokens=2048,
end_id=end_id,
pad_id=end_id,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=1.1,
num_return_sequences=1,
streaming=False,
output_sequence_lengths=True,
output_generation_logits=False,
return_dict=True,
return_all_generated_tokens=False
)
torch.cuda.synchronize()
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
num_output_sents, num_beams, _ = output_ids.size()
assert num_beams == 1
beam = 0
batch_size = len(batch["input_ids"])
num_return_sequences = num_output_sents // batch_size
assert num_return_sequences == 1
outputs = []
for i in range(batch_size * num_return_sequences):
batch_idx = i // num_return_sequences
seq_idx = i % num_return_sequences
output_begin = input_lengths[batch_idx]
output_end = sequence_lengths[i][beam]
outputs_i = output_ids[i][beam][:output_end].tolist()
outputs.append(outputs_i)
elif args.backend == "vllm":
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=1.1,
max_tokens=2048,
)
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
print(outputs)
for j, output in enumerate(outputs):
outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
elif args.backend == "trtllm-serve":
if args.batch_size > 1:
outputs = asyncio.run(send_batch_requests_async(
args.openai_api_base,
args.openai_model_name,
batch["chat_list"],
args.temperature,
args.top_p,
args.top_k,
))
else:
outputs = []
for chat in batch["chat_list"]:
payload = {
"model": args.openai_model_name,
"messages": chat,
"max_tokens": 2048,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": False,
}
response = requests.post(args.openai_api_base, json=payload)
response.raise_for_status()
response_json = response.json()
generated_content = response_json['choices'][0]['message']['content']
outputs.append(generated_content)
llm_end_time = time.time()
total_llm_time += (llm_end_time - llm_start_time)
items_for_token_2wav = []
for i in range(len(batch["ids"])):
llm_post_processing_start_time = time.time()
if args.backend == "trtllm-serve":
speech_tokens_str = outputs[i].strip().split('><')
if len(speech_tokens_str) > 1:
speech_tokens_str = [
t if t.startswith('<') else '<' + t for t in speech_tokens_str
]
speech_tokens_str = [
t if t.endswith('>') else t + '>' for t in speech_tokens_str
]
speech_ids = extract_speech_ids(speech_tokens_str)
else:
input_length = len(batch["input_ids"][i])
generated_ids = outputs[i][input_length:]
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
speech_ids = extract_speech_ids(speech_tokens_str)
print(i, speech_ids)
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
llm_post_processing_end_time = time.time()
total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
if current_prompt_audio is not None:
items_for_token_2wav.append({
"speech_ids": speech_ids,
"prompt_audio": current_prompt_audio.squeeze(0),
"id": batch["ids"][i]
})
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
if not t2w_batch:
continue
t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
t2w_ids = [item["id"] for item in t2w_batch]
token2wav_start_time = time.time()
generated_wavs = token2wav_model(
t2w_generated_speech_tokens_list,
t2w_prompt_audios_list,
t2w_prompt_audios_sample_rate,
)
token2wav_end_time = time.time()
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
audio_save_start_time = time.time()
for j, audio_hat in enumerate(generated_wavs):
generated_wave = audio_hat.squeeze().cpu().numpy()
total_audio_samples += len(generated_wave)
target_sample_rate = 24000
utt = t2w_ids[j]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
audio_save_end_time = time.time()
total_audio_save_time += audio_save_end_time - audio_save_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
last_batch_end_time = time.time()
if rank == 0:
progress_bar.close()
end_time = time.time()
target_sample_rate = 24000
total_audio_duration_seconds = total_audio_samples / target_sample_rate
log_file_path = os.path.join(args.output_dir, "log.txt")
with open(log_file_path, 'w') as f:
args_dict = vars(args)
log_data = {
"args": args_dict,
"data_load_time_seconds": total_data_load_time,
"audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
"speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
"text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
"llm_time_seconds": total_llm_time,
"llm_post_processing_time_seconds": total_llm_post_processing_time,
"token2wav_time_seconds": total_token2wav_time,
"audio_save_time_seconds": total_audio_save_time,
"total_audio_duration_seconds": total_audio_duration_seconds,
"pipeline_time_seconds": end_time - start_time,
}
print(log_data)
f.write(json.dumps(log_data, indent=4))
print(f"Metrics logged to {log_file_path}")
if __name__ == "__main__":
args = get_args()
if args.backend == "vllm":
from vllm import LLM, SamplingParams
elif args.backend == "trtllm":
import tensorrt_llm
from tensorrt_llm.runtime import ModelRunnerCpp
elif args.backend == "hf":
from transformers import AutoModelForCausalLM
elif args.backend == "trtllm-serve":
pass
else:
raise ValueError(f"Unsupported backend: {args.backend}")
main(args)

View File

@@ -0,0 +1,14 @@
hyperpyyaml
s3tokenizer
onnxruntime-gpu
omegaconf
conformer
hydra-core
lightning
gdown
wget
librosa
pyworld
openai-whisper
tritonclient
modelscope

View File

@@ -0,0 +1,142 @@
#!/bin/bash
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=0
cosyvoice_path=/workspace/CosyVoice
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
stage=$1
stop_stage=$2
huggingface_model_local_dir=./cosyvoice2_llm
model_scope_model_local_dir=./CosyVoice2-0.5B
trt_dtype=bfloat16
trt_weights_dir=./trt_weights_${trt_dtype}
trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2
use_spk2info_cache=False
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "Cloning CosyVoice"
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
cd $cosyvoice_path
git submodule update --init --recursive
cd runtime/triton_trtllm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading CosyVoice2-0.5B"
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint to TensorRT weights"
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
--output_dir $trt_weights_dir \
--dtype $trt_dtype || exit 1
echo "Building TensorRT engines"
trtllm-build --checkpoint_dir $trt_weights_dir \
--output_dir $trt_engines_dir \
--max_batch_size 16 \
--max_num_tokens 32768 \
--gemm_plugin $trt_dtype || exit 1
echo "Testing TensorRT engines"
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
--tokenizer_dir $huggingface_model_local_dir \
--top_k 50 --top_p 0.95 --temperature 0.8 \
--engine_dir=$trt_engines_dir || exit 1
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Creating model repository"
rm -rf $model_repo
mkdir -p $model_repo
cosyvoice2_dir="cosyvoice2"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/tensorrt_llm $model_repo
cp -r ./model_repo/token2wav $model_repo
if [ $use_spk2info_cache == "False" ]; then
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
fi
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
MODEL_DIR=$model_scope_model_local_dir
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=4
TRITON_MAX_BATCH_SIZE=16
DECOUPLED_MODE=True # True for streaming, False for offline
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
if [ $use_spk2info_cache == "False" ]; then
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Starting Triton server"
tritonserver --model-repository $model_repo
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Single request test http, only work for offline TTS mode"
python3 client_http.py \
--reference-audio ./assets/prompt_audio.wav \
--reference-text "吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。" \
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
--model-name cosyvoice2
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Running benchmark client grpc"
num_task=4
mode=streaming
BLS_INSTANCE_NUM=4
python3 client_grpc.py \
--server-addr localhost \
--model-name cosyvoice2 \
--num-tasks $num_task \
--mode $mode \
--use-spk2info-cache $use_spk2info_cache \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "stage 6: Offline inference benchmark"
n_gpus=1
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
backend=trtllm # hf, trtllm, vllm
batch_sizes=(16 8 4 2 1)
token2wav_batch_size=1
for batch_size in ${batch_sizes[@]}; do
for dataset in ${datasets[@]}; do
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
done
done
fi

View File

@@ -0,0 +1,225 @@
#!/bin/bash
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=0
cosyvoice_path=/workspace/CosyVoice
stepaudio2_path=/workspace/Step-Audio2
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
stage=$1
stop_stage=$2
huggingface_model_local_dir=./cosyvoice2_llm
model_scope_model_local_dir=./CosyVoice2-0.5B
step_audio_model_dir=./Step-Audio-2-mini
trt_dtype=bfloat16
trt_weights_dir=./trt_weights_${trt_dtype}
trt_engines_dir=./trt_engines_${trt_dtype}
model_repo=./model_repo_cosyvoice2_dit
bls_instance_num=10
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
echo "Cloning Step-Audio2-mini"
git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
echo "Cloning CosyVoice"
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
cd $cosyvoice_path
git submodule update --init --recursive
cd runtime/triton_trtllm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading CosyVoice2-0.5B"
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
echo "Step-Audio2-mini"
huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
cd $step_audio_model_dir/token2wav
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
cd -
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint to TensorRT weights"
python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
--output_dir $trt_weights_dir \
--dtype $trt_dtype || exit 1
echo "Building TensorRT engines"
trtllm-build --checkpoint_dir $trt_weights_dir \
--output_dir $trt_engines_dir \
--max_batch_size 64 \
--max_num_tokens 32768 \
--gemm_plugin $trt_dtype || exit 1
echo "Testing TensorRT engines"
python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
--tokenizer_dir $huggingface_model_local_dir \
--top_k 50 --top_p 0.95 --temperature 0.8 \
--engine_dir=$trt_engines_dir || exit 1
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Creating model repository async mode"
rm -rf $model_repo
mkdir -p $model_repo
cosyvoice2_dir="cosyvoice2_dit"
token2wav_dir="token2wav_dit"
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
cp -r ./model_repo/${token2wav_dir} $model_repo
cp -r ./model_repo/audio_tokenizer $model_repo
cp -r ./model_repo/speaker_embedding $model_repo
ENGINE_PATH=$trt_engines_dir
MAX_QUEUE_DELAY_MICROSECONDS=0
MODEL_DIR=$model_scope_model_local_dir
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=$bls_instance_num
TRITON_MAX_BATCH_SIZE=1
DECOUPLED_MODE=True # Only streaming TTS mode is supported using Nvidia Triton for now
STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
tritonserver --model-repository $model_repo --http-port 18000 &
wait
# Test using curl
# curl http://localhost:8000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "",
# "messages":[{"role": "user", "content": "Where is New York?"},
# {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
# "max_tokens": 512,
# "temperature": 0.8,
# "top_p": 0.95,
# "top_k": 50,
# "stop": ["<|eos1|>"],
# "repetition_penalty": 1.2,
# "stream": false
# }'
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Running benchmark client"
num_task=4
mode=streaming
BLS_INSTANCE_NUM=$bls_instance_num
python3 client_grpc.py \
--server-addr localhost \
--server-port 8001 \
--model-name cosyvoice2_dit \
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
backend=trtllm # hf, trtllm, vllm, trtllm-serve
batch_sizes=(16)
token2wav_batch_size=1
for batch_size in ${batch_sizes[@]}; do
for dataset in ${datasets[@]}; do
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
CUDA_VISIBLE_DEVICES=1 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $step_audio_model_dir/token2wav \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
done
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
export CUDA_VISIBLE_DEVICES=1
# Note: Using pre-computed cosyvoice2 tokens
python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
# Offline Token2wav inference
python3 token2wav_dit.py --enable-trt
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Disaggregated Server: LLM and Token2wav on different GPUs"
echo "Starting LLM server on GPU 0"
export CUDA_VISIBLE_DEVICES=0
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64 --kv_cache_free_gpu_memory_fraction 0.4 &
echo "Starting Token2wav server on GPUs 1-3"
Token2wav_num_gpus=3
http_port=17000
grpc_port=18000
metrics_port=16000
for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
echo "Starting server on GPU $i"
http_port=$((http_port + 1))
grpc_port=$((grpc_port + 1))
metrics_port=$((metrics_port + 1))
# Two instances of Token2wav server on the same GPU
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
http_port=$((http_port + 1))
grpc_port=$((grpc_port + 1))
metrics_port=$((metrics_port + 1))
CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
done
wait
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Running benchmark client for Disaggregated Server"
per_gpu_instances=2
mode=streaming
BLS_INSTANCE_NUM=$bls_instance_num
Token2wav_num_gpus=(1 2 3)
concurrent_tasks=(1 2 3 4 5 6)
for n_gpu in ${Token2wav_num_gpus[@]}; do
echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
for concurrent_task in ${concurrent_tasks[@]}; do
num_instances=$((per_gpu_instances * n_gpu))
for i in $(seq 1 $num_instances); do
port=$(($i + 18000))
python3 client_grpc.py \
--server-addr localhost \
--server-port $port \
--model-name cosyvoice2_dit \
--num-tasks $concurrent_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
done
wait
done
done
fi

View File

@@ -0,0 +1,330 @@
import argparse
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoConfig
import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import QWenForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=None, required=True)
parser.add_argument('--tp_size',
type=int,
default=1,
help='N-way tensor parallelism size')
parser.add_argument('--pp_size',
type=int,
default=1,
help='N-way pipeline parallelism size')
parser.add_argument('--cp_size',
type=int,
default=1,
help='N-way context parallelism size')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'float16', 'bfloat16', 'float32'],
help="The data type for the model weights and activations if not quantized. "
"If 'auto', the data type is automatically inferred from the source model; "
"however, if the source dtype is float32, it is converted to float16.")
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
help='Quantize weights for the various GEMMs to INT4/INT8.'
'See --weight_only_precision to set the precision')
parser.add_argument(
'--disable_weight_only_quant_plugin',
default=False,
action="store_true",
help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_gptq'],
help='Define the precision for the weights when using weight-only quantization.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--calib_dataset',
type=str,
default='ccdv/cnn_dailymail',
help="The huggingface dataset name or the local directory of the dataset for calibration."
)
parser.add_argument(
"--smoothquant",
"-sq",
type=float,
default=None,
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
" to Smoothquant the model, and output int8 weights."
" A good first try is 0.5. Must be in [0, 1]")
parser.add_argument(
'--per_channel',
action="store_true",
default=False,
help='By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
action="store_true",
default=False,
help='By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--int8_kv_cache',
default=False,
action="store_true",
help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
)
parser.add_argument(
'--per_group',
default=False,
action="store_true",
help='By default, we use a single static scaling factor to scale weights in the int4 range. '
'per_group chooses at run time, and for each group, a custom scaling factor. '
'The flag is built for GPTQ/AWQ quantization.')
parser.add_argument('--group_size',
type=int,
default=128,
help='Group size used in GPTQ quantization.')
parser.add_argument("--load_model_on_cpu", action="store_true")
parser.add_argument(
'--use_parallel_embedding',
action="store_true",
default=False,
help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
'To shard it along hidden dimension, set embedding_sharding_dim=1'
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
)
parser.add_argument('--output_dir',
type=str,
default='tllm_checkpoint',
help='The path to save the TensorRT-LLM checkpoint')
parser.add_argument(
'--workers',
type=int,
default=1,
help='The number of workers for converting checkpoint in parallel')
parser.add_argument(
'--moe_tp_size',
type=int,
default=-1,
help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
)
parser.add_argument(
'--moe_ep_size',
type=int,
default=-1,
help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
)
args = parser.parse_args()
return args
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
'''return config dict with quantization info based on the command line args
'''
quant_config = QuantConfig()
if args.use_weight_only:
if args.weight_only_precision == 'int8':
quant_config.quant_algo = QuantAlgo.W8A16
elif args.weight_only_precision == 'int4':
quant_config.quant_algo = QuantAlgo.W4A16
elif args.smoothquant:
quant_config.smoothquant_val = args.smoothquant
if args.per_channel:
if args.per_token:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
else:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
else:
if args.per_token:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
else:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
if args.int8_kv_cache:
quant_config.kv_cache_quant_algo = QuantAlgo.INT8
if args.weight_only_precision == 'int4_gptq':
quant_config.group_size = args.group_size
quant_config.has_zero_point = True
quant_config.pre_quant_scale = False
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
return quant_config
def update_quant_config_from_hf(quant_config, hf_config,
override_fields) -> tuple[QuantConfig, dict]:
hf_config_dict = hf_config.to_dict()
if hf_config_dict.get('quantization_config'):
# update the quant_algo, and clamp_val.
if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
logger.info(
"Load quantization configs from huggingface model_config.")
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
quant_config.group_size = hf_config_dict['quantization_config'].get(
'group_size', 128)
quant_config.has_zero_point = hf_config_dict[
'quantization_config'].get('zero_point', False)
override_fields.update({"use_autoawq": True})
elif hf_config_dict['quantization_config'].get(
'quant_method') == 'gptq':
logger.info(
"Load quantization configs from huggingface model_config.")
desc_act = hf_config_dict['quantization_config'].get(
'desc_act', False)
if desc_act:
raise ValueError("GPTQ with desc_act=True is not implemented!")
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
quant_config.group_size = hf_config_dict['quantization_config'].get(
'group_size', 128)
quant_config.has_zero_point = hf_config_dict[
'quantization_config'].get('sym', False)
return quant_config, override_fields
def args_to_build_options(args):
return {
'use_parallel_embedding': args.use_parallel_embedding,
'embedding_sharding_dim': args.embedding_sharding_dim,
'disable_weight_only_quant_plugin':
args.disable_weight_only_quant_plugin
}
def convert_and_save_hf(args):
model_dir = args.model_dir
world_size = args.tp_size * args.pp_size
# Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
# Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
# before the refactor is done.
override_fields = {}
override_fields.update(args_to_build_options(args))
quant_config = args_to_quant_config(args)
try:
hf_config = AutoConfig.from_pretrained(model_dir,
trust_remote_code=True)
quant_config, override_fields = update_quant_config_from_hf(
quant_config, hf_config, override_fields)
except BaseException:
logger.warning("AutoConfig cannot load the huggingface config.")
if args.smoothquant is not None or args.int8_kv_cache:
mapping = Mapping(world_size=world_size,
tp_size=args.tp_size,
pp_size=args.pp_size,
moe_tp_size=args.moe_tp_size,
moe_ep_size=args.moe_ep_size,
cp_size=args.cp_size)
QWenForCausalLM.quantize(args.model_dir,
args.output_dir,
dtype=args.dtype,
mapping=mapping,
quant_config=quant_config,
calib_dataset=args.calib_dataset,
**override_fields)
else:
def convert_and_save_rank(args, rank):
mapping = Mapping(world_size=world_size,
rank=rank,
tp_size=args.tp_size,
pp_size=args.pp_size,
moe_tp_size=args.moe_tp_size,
moe_ep_size=args.moe_ep_size)
qwen = QWenForCausalLM.from_hugging_face(model_dir,
args.dtype,
mapping=mapping,
quant_config=quant_config,
**override_fields)
qwen.config.mapping.cp_size = args.cp_size
qwen.config.mapping.attn_tp_size = -1
qwen.config.mapping.attn_cp_size = -1
qwen.config.mapping.world_size *= args.cp_size
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
del qwen
execute(args.workers, [convert_and_save_rank] * world_size, args)
release_gc()
def execute(workers, func, args):
if workers == 1:
for rank, f in enumerate(func):
f(args, rank)
else:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(
exceptions
) == 0, "Checkpoint conversion failed, please check error log."
def main():
print(tensorrt_llm.__version__)
args = parse_arguments()
if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
# moe default to tp-only
args.moe_tp_size = args.tp_size
args.moe_ep_size = 1
elif (args.moe_tp_size == -1):
args.moe_tp_size = args.tp_size // args.moe_ep_size
elif (args.moe_ep_size == -1):
args.moe_ep_size = args.tp_size // args.moe_tp_size
assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
), "moe_tp_size * moe_ep_size must equal to tp_size"
tik = time.time()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
assert args.model_dir is not None
convert_and_save_hf(args)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
print(f'Total time of converting checkpoints: {t}')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,69 @@
# /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
def split(string, delimiter):
"""Split a string using delimiter. Supports escaping.
Args:
string (str): The string to split.
delimiter (str): The delimiter to split the string with.
Returns:
list: A list of strings.
"""
result = []
current = ""
escape = False
for char in string:
if escape:
current += char
escape = False
elif char == delimiter:
result.append(current)
current = ""
elif char == "\\":
escape = True
else:
current += char
result.append(current)
return result
def main(file_path, substitutions, in_place):
with open(file_path) as f:
pbtxt = Template(f.read())
sub_dict = {
"max_queue_size": 0,
'max_queue_delay_microseconds': 0,
}
for sub in split(substitutions, ","):
key, value = split(sub, ":")
sub_dict[key] = value
assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
pbtxt = pbtxt.safe_substitute(sub_dict)
if in_place:
with open(file_path, "w") as f:
f.write(pbtxt)
else:
print(pbtxt)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument(
"substitutions",
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
)
parser.add_argument("--in_place",
"-i",
action="store_true",
help="do the operation in-place")
args = parser.parse_args()
main(**vars(args))

View File

@@ -0,0 +1,138 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import torch
import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import ModelRunnerCpp
from transformers import AutoTokenizer
def parse_arguments(args=None):
parser = argparse.ArgumentParser()
parser.add_argument(
'--input_text',
type=str,
nargs='+',
default=["Born in north-east France, Soyer trained as a"])
parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
parser.add_argument('--log_level', type=str, default="debug")
parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6)
parser.add_argument('--temperature', type=float, default=0.8)
parser.add_argument('--top_k', type=int, default=50)
parser.add_argument('--top_p', type=float, default=0.95)
return parser.parse_args(args=args)
def parse_input(tokenizer,
input_text=None,
prompt_template=None):
batch_input_ids = []
for curr_text in input_text:
if prompt_template is not None:
curr_text = prompt_template.format(input_text=curr_text)
input_ids = tokenizer.encode(
curr_text)
batch_input_ids.append(input_ids)
batch_input_ids = [
torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
]
logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):")
for i, input_ids in enumerate(batch_input_ids):
logger.debug(f"Request {i}: {input_ids.tolist()}")
return batch_input_ids
def main(args):
runtime_rank = tensorrt_llm.mpi_rank()
logger.set_level(args.log_level)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
prompt_template = "<|sos|>{input_text}<|task_id|>"
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>")
batch_input_ids = parse_input(tokenizer=tokenizer,
input_text=args.input_text,
prompt_template=prompt_template)
input_lengths = [x.size(0) for x in batch_input_ids]
runner_kwargs = dict(
engine_dir=args.engine_dir,
rank=runtime_rank,
max_output_len=1024,
enable_context_fmha_fp32_acc=False,
max_batch_size=len(batch_input_ids),
max_input_len=max(input_lengths),
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
cuda_graph_mode=False,
gather_generation_logits=False,
)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
with torch.no_grad():
outputs = runner.generate(
batch_input_ids=batch_input_ids,
max_new_tokens=1024,
end_id=end_id,
pad_id=end_id,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
num_return_sequences=1,
repetition_penalty=1.1,
random_seed=42,
streaming=False,
output_sequence_lengths=True,
output_generation_logits=False,
return_dict=True,
return_all_generated_tokens=False)
torch.cuda.synchronize()
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
num_output_sents, num_beams, _ = output_ids.size()
assert num_beams == 1
beam = 0
batch_size = len(input_lengths)
num_return_sequences = num_output_sents // batch_size
assert num_return_sequences == 1
for i in range(batch_size * num_return_sequences):
batch_idx = i // num_return_sequences
seq_idx = i % num_return_sequences
inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
input_text = tokenizer.decode(inputs)
print(f'Input [Text {batch_idx}]: \"{input_text}\"')
output_begin = input_lengths[batch_idx]
output_end = sequence_lengths[i][beam]
outputs = output_ids[i][beam][output_begin:output_end].tolist()
output_text = tokenizer.decode(outputs)
print(f'Output [Text {batch_idx}]: \"{output_text}\"')
logger.debug(str(outputs))
if __name__ == '__main__':
args = parse_arguments()
main(args)

View File

@@ -0,0 +1,122 @@
import torch
import os
import argparse
from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
import torchaudio
import time
from token2wav_dit import CosyVoice2_Token2Wav
import soundfile as sf
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
prompt_speech_tokens_list, prompt_text_list = [], []
for item in batch:
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
prompt_text_list.append(item['prompt_text'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = args.dataset_name
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
CHUNK_SIZE = 25
token_frame_rate = 25
OVERLAP_SIZE = 0
warmup_times = 3
for _ in range(warmup_times):
start_time = time.time()
total_forward_count = 0
for batch in data_loader:
tts_speech_list = []
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
assert prompt_audio_sample_rate == 16000
prompt_text = prompt_text_list[0]
prompt_speech_tokens = prompt_speech_tokens_list[0]
semantic_token_ids_arr, token_offset = [], 0
flow_prompt_speech_token_len = len(prompt_speech_tokens)
buffer = generated_speech_tokens
output_wavs = []
chunk_index = 0
while True:
if args.strategy == "equal":
this_chunk_size = CHUNK_SIZE
elif args.strategy == "exponential":
this_chunk_size = token_frame_rate * (2 ** chunk_index)
if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
wavs = token2wav_model.forward_streaming(
buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
prompt_audio_sample_rate=prompt_audio_sample_rate
)
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
output_wavs.append(wavs)
total_forward_count += 1
chunk_index += 1
else:
wavs = token2wav_model.forward_streaming(
buffer, True, request_id=id, speaker_id=f"{id}",
prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
)
output_wavs.append(wavs)
total_forward_count += 1
# chunk_index += 1
break
for i, wav in enumerate(output_wavs):
output_wavs[i] = wav.cpu().numpy().squeeze()
audios = output_wavs
reconstructed_audio = np.concatenate(audios)
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
end_time = time.time()
if _ == 0:
token2wav_model.speaker_cache = {}
print(f"Warmup time: {end_time - start_time} seconds")
print("clear speaker cache")
elif _ == 1:
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
else:
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
print(f"Total flow matching forward calls: {total_forward_count}")

View File

@@ -0,0 +1,335 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 token2wav.py --enable-trt || exit 1
"""
import torch
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
import tensorrt as trt
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
# load onnx model
with open(onnx_model, "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes
for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
# set input and output data type
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_tensor.dtype = tensor_dtype
for i in range(network.num_outputs):
output_tensor = network.get_output(i)
output_tensor.dtype = tensor_dtype
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
# save trt engine
with open(trt_model, "wb") as f:
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
self.device = device
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
super().__init__()
self.device_id = device_id
self.device = f"cuda:{device_id}"
self.flow = CausalMaskedDiffWithXvec()
self.flow.half()
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
self.flow.to(self.device).eval()
self.hift = HiFTGenerator()
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
gpu = "l20"
if enable_trt:
self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
True)
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
def forward_spk_embedding(self, spk_feat):
if isinstance(self.spk_model, onnxruntime.InferenceSession):
return self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device_id):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
output_tensor.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return output_tensor.cpu().numpy().flatten().tolist()
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
(max_batch_size * 2, 80)]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
prompt_speech_tokens_list, prompt_speech_mels_list = [], []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
prompt_speech_mels_list.append(log_mel)
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
)
for i in range(len(prompt_speech_tokens)):
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
prompt_speech_tokens_list.append(speech_tokens_i)
return prompt_speech_tokens_list
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
spk_emb_for_flow = []
for audio in prompt_audios_list:
assert len(audio.shape) == 1
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
spk_emb = self.forward_spk_embedding(spk_feat)
spk_emb_for_flow.append(spk_emb)
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
return spk_emb_for_flow
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
prompt_mels_for_flow = []
prompt_mels_lens_for_flow = []
for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
assert len(audio.shape) == 1
audio = audio.unsqueeze(0)
if sample_rate != 24000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
flow_inputs_lens = torch.tensor(flow_inputs_lens)
with torch.amp.autocast(self.device, dtype=torch.float16):
generated_mels, generated_mels_lens = self.flow(
flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
streaming=False, finalize=True
)
return generated_mels, generated_mels_lens
def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
batch_size = generated_mels.shape[0]
generated_wavs = []
for i in range(batch_size):
mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
wav, _ = self.hift(speech_feat=mel)
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
):
# assert all item in prompt_audios_sample_rate is 16000
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
generated_mels, generated_mels_lens = self.forward_flow(
prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for _, item in enumerate(batch):
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
audio = torch.from_numpy(item['prompt_audio']['array']).float()
prompt_audios_list.append(audio)
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
ids.append(item['id'])
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
# mkdir output_dir if not exists
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for _ in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -0,0 +1 @@
model_repo/token2wav_dit/1/token2wav_dit.py

View File

@@ -29,27 +29,24 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
for utt in tqdm(utt_list):
data = open(utt2wav[utt], 'rb').read()
data_list.append(data)
wav_list = [utt2wav[utt] for utt in utt_list]
text_list = [utt2text[utt] for utt in utt_list]
spk_list = [utt2spk[utt] for utt in utt_list]
uttembedding_list = [utt2embedding[utt] for utt in utt_list]
spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
if args.dpo:
reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
# 保存到parquet,utt2parquet_file,spk2parquet_file
df = pd.DataFrame()
df['utt'] = utt_list
df['wav'] = wav_list
df['audio_data'] = data_list
df['text'] = text_list
df['spk'] = spk_list
df['utt_embedding'] = uttembedding_list
df['spk_embedding'] = spkembedding_list
df['speech_token'] = speech_token_list
df['wav'] = [utt2wav[utt] for utt in utt_list]
df['text'] = [utt2text[utt] for utt in utt_list]
df['spk'] = [utt2spk[utt] for utt in utt_list]
if utt2embedding is not None:
df['utt_embedding'] = [utt2embedding[utt] for utt in utt_list]
if spk2embedding is not None:
df['spk_embedding'] = [spk2embedding[utt2spk[utt]] for utt in utt_list]
if utt2speech_token is not None:
df['speech_token'] = [utt2speech_token[utt] for utt in utt_list]
if utt2instruct is not None:
df['instruct'] = [utt2instruct[utt] for utt in utt_list]
if args.dpo:
df['reject_speech_token'] = reject_speech_token_list
df['reject_speech_token'] = [utt2reject_speech_token.get(utt, None) for utt in utt_list]
df.to_parquet(parquet_file)
with open(utt2parquet_file, 'w') as f:
json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
@@ -91,11 +88,19 @@ if __name__ == "__main__":
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
if os.path.exists('{}/instruct'.format(args.src_dir)):
utt2instruct = {}
with open('{}/instruct'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2instruct[l[0]] = ' '.join(l[1:])
else:
utt2instruct = None
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/utt2embedding.pt'.format(args.src_dir)) else None
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/spk2embedding.pt'.format(args.src_dir)) else None
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}/utt2speech_token.pt'.format(args.src_dir)) else None
if args.dpo:
utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}_reject/utt2speech_token.pt'.format(args.src_dir)) else {}
utts = list(utt2wav.keys())
# Using process pool to speedup

View File

@@ -4,20 +4,36 @@ from vllm import ModelRegistry
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.common import set_all_random_seed
from tqdm import tqdm
def main():
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
def cosyvoice2_example():
""" CosyVoice2 vllm usage
"""
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
for i in tqdm(range(100)):
set_all_random_seed(i)
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
continue
def cosyvoice3_example():
""" CosyVoice3 vllm usage
"""
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B', load_trt=True, load_vllm=True, fp16=False)
for i in tqdm(range(100)):
set_all_random_seed(i)
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
'./asset/zero_shot_prompt.wav', stream=False)):
continue
def main():
# cosyvoice2_example()
cosyvoice3_example()
if __name__ == '__main__':
main()

Some files were not shown because too many files have changed in this diff Show More