59 Commits

Author SHA1 Message Date
Xiang Lyu
f08872a82f Merge pull request #1814 from hexisyztem/main
[BUG FIX] Use float64 to avoid precision errors and drop the CPU computation so it no longer drags down performance
2026-02-04 13:10:40 +08:00
lyuxiang.lx
f2ddcbe7f9 fix ras 2026-02-04 11:44:39 +08:00
禾息
0d990d6074 [BUG FIX] Use float64 to avoid precision errors and drop the CPU computation so it no longer drags down performance 2026-01-30 18:10:36 +08:00
lyuxiang.lx
c93d3dda01 update 2026-01-29 17:31:04 +00:00
lyuxiang.lx
84e41729ea update 2026-01-29 10:29:22 +00:00
lyuxiang.lx
f26cde56df update 2026-01-29 06:13:36 +00:00
lyuxiang.lx
66b80dbccb online feature 2026-01-28 15:19:07 +00:00
lyuxiang.lx
1822c5c908 fix fm train bug 2026-01-19 10:48:24 +08:00
lyuxiang.lx
1dcc59676f add japanese 2026-01-12 07:06:45 +00:00
lyuxiang.lx
7fdd80dc64 fix padding 2026-01-07 07:12:52 +00:00
lyuxiang.lx
f97d50d559 fix sequence logic 2026-01-07 07:06:04 +00:00
lyuxiang.lx
652132ebaa Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2026-01-05 11:02:01 +08:00
lyuxiang.lx
1ceb2b7b1e fix bug 2026-01-05 11:01:19 +08:00
Xiang Lyu
55e3e370a0 Merge pull request #1758 from orbisai0security/fix/V-005-pickle-deserialization
[Security] Fix CRITICAL vulnerability: V-005
2025-12-31 11:45:44 +08:00
lyuxiang.lx
46788c7379 update readme 2025-12-31 11:28:47 +08:00
Xiang Lyu
881177287c Merge pull request #1640 from Jzz1943/main
support vLLM >=0.11.0 (V1 engine) for better performance
2025-12-31 10:40:02 +08:00
Xiang Lyu
f88a14e41d Merge branch 'main' into main 2025-12-31 10:37:18 +08:00
lyuxiang.lx
dac6566fc3 update tokens 2025-12-30 16:14:36 +00:00
lyuxiang.lx
cc91e40db8 more silent_token 2025-12-30 15:53:33 +00:00
lyuxiang.lx
ab7f1f4a86 update silent token 2025-12-30 15:12:18 +00:00
lyuxiang.lx
e15222b17c refine code 2025-12-30 09:24:52 +00:00
lyuxiang.lx
cfa1c115b2 add silent_token 2025-12-30 09:18:17 +00:00
lyuxiang.lx
dd5cdb6ebf Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-30 14:03:52 +08:00
lyuxiang.lx
2d7ef0b719 remove instruct warning 2025-12-30 12:36:18 +08:00
Xiang Lyu
ba5db602a9 Merge pull request #1622 from GoyoUijin/Fix/token2wav-cache-thread-unsafe
fix triton token2wav model cache thread unsafety
2025-12-30 11:24:25 +08:00
orbisai0security
5b94675f62 fix: resolve critical vulnerability V-005
Automatically generated security fix
2025-12-29 13:25:05 +00:00
lyuxiang.lx
4c19646b9a update dataset 2025-12-29 12:46:34 +00:00
lyuxiang.lx
63a06227d1 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-29 18:32:54 +08:00
lyuxiang.lx
3b44913782 fix bug 2025-12-29 10:30:54 +00:00
Xiang Lyu
055f64d002 Merge pull request #1728 from yigitcatak/main
fixed a typo in Dockerfile regarding environment variable
2025-12-29 18:24:39 +08:00
lyuxiang.lx
4d7295a9a7 fix cosyvoice3 training 2025-12-29 10:03:14 +00:00
lyuxiang.lx
8524c81acd fix cv3 train 2025-12-29 10:03:14 +00:00
Xiang Lyu
a14e063ead Merge pull request #1722 from majiayu000/fix/issue-1683-ja-language-tag
docs: fix Japanese language tag in example.py comment
2025-12-29 16:12:54 +08:00
lyuxiang.lx
2db78e7058 fix export_jit.py 2025-12-23 17:23:23 +08:00
lyuxiang.lx
7538c6a73d update export_jit 2025-12-23 15:23:29 +08:00
yigitcatak
823ae2c60d fix a typo in Dockerfile 2025-12-23 16:22:48 +09:00
lyuxiang.lx
59cb2bf16c fix jit_export 2025-12-23 14:06:40 +08:00
majiayu000
80bebb1978 docs: fix Japanese language tag in example.py comment
Changed <|jp|> to <|ja|> to match the actual tokenizer implementation.

The LANGUAGES dict in cosyvoice/tokenizer/tokenizer.py defines 'ja' for Japanese,
not 'jp'. This fixes the misleading comment that could cause issues like #621.

Fixes #1683
2025-12-22 19:11:47 +08:00
Xiang Lyu
bc34459bb8 Merge pull request #1693 from whiteshirt0429/main
Fix CosyVoice3 config error
2025-12-22 13:56:19 +08:00
lyuxiang.lx
9f27b42cd9 update 2025-12-17 18:58:50 +08:00
lyuxiang.lx
a7d6e2251a update libritts cosyvoice3.yaml 2025-12-17 17:15:10 +08:00
lyuxiang.lx
7baefaf0f2 update libritts cosyvoice3.yaml 2025-12-17 17:14:17 +08:00
di.wu
ff0d05c380 Fix CosyVoice3 config error 2025-12-17 14:57:17 +08:00
lyuxiang.lx
f5816b4e51 update readme 2025-12-17 03:16:00 +00:00
lyuxiang.lx
8b54619760 update 2025-12-16 15:00:03 +08:00
lyuxiang.lx
2abd42220e add x-transformer version 2025-12-16 14:45:02 +08:00
lyuxiang.lx
2d6bb9bd80 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-15 16:16:35 +08:00
lyuxiang.lx
0b80c0746a update 2025-12-15 16:13:53 +08:00
Xiangang Li
e98b828f33 Update README.md 2025-12-15 15:54:32 +08:00
lyuxiang.lx
4d4c787be0 update 2025-12-15 15:33:52 +08:00
lyuxiang.lx
781a49acb4 update metric 2025-12-15 14:52:32 +08:00
lyuxiang.lx
9476a063b3 update metric 2025-12-15 14:48:17 +08:00
Xiang Lyu
3426ceb70f Merge pull request #1671 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-12-15 14:00:40 +08:00
lyuxiang.lx
089343ab0a update 2025-12-15 12:44:06 +08:00
Xiang Lyu
0c50894d49 Merge pull request #1670 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-12-15 12:17:39 +08:00
zhongze.jiang
6816fc6a6f support vLLM >=0.11.0 (V1 engine only) 2025-11-10 16:30:42 +08:00
김의진
e8bf717333 Fix: generate token2wav_request_id from cosyvoice2
- Since all token2wav requests within a single cosyvoice2 request must share the same request_id, modify the logic so that a new request_id is generated only if it does not already exist, and ensure that the same request_id is sent consistently.
2025-10-27 18:12:17 +09:00
김의진
fa2781405f Revert "fix triton token2wav model cache thread unsafety"
This reverts commit cd26dd1932.
2025-10-27 18:07:30 +09:00
김의진
cd26dd1932 fix triton token2wav model cache thread unsafety 2025-10-27 17:20:14 +09:00
40 changed files with 435 additions and 363 deletions

View File

@@ -1,12 +1,12 @@
-[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
+![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)
 ## 👉🏻 CosyVoice 👈🏻
-**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [Modelscope](https://www.modelscope.cn/studios/FunAudioLLM/Fun-CosyVoice3-0.5B); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
+**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
-**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
+**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)
-**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
+**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)
 ## Highlight🔥
@@ -60,23 +60,25 @@
 - [x] Fastapi server and client
 ## Evaluation
-| Model | Model Size | CER (%) ↓ (test-zh) | WER (%) ↓ (test-en) | CER (%) ↓ (test-hard) |
-|-------|------------|---------------------|---------------------|-----------------------|
-| Human | - | 1.26 | 2.14 | - |
-| Seed-TTS | - | 1.12 | 2.25 | 7.59 |
-| MiniMax-Speech | - | 0.83 | 1.65 | - |
-| F5-TTS | 0.3B | 1.52 | 2.00 | 8.67 |
-| SparkTTS | 0.5B | 1.20 | 1.98 | - |
-| CosyVoice2 | 0.5B | 1.45 | 2.57 | 6.83 |
-| FireRedTTS-2 | 1.5B | 1.14 | 1.95 | - |
-| IndexTTS2 | 1.5B | 1.01 | 1.52 | 7.12 |
-| VibeVoice | 1.5B | 1.16 | 3.04 | - |
-| HiggsAudio-v2 | 3B | 1.50 | 2.44 | - |
-| VoxPCM | 0.5B | 0.93 | 1.85 | 8.87 |
-| GLM-TTS | 1.5B | 1.03 | - | - |
-| GLM-TTS_RL | 1.5B | 0.89 | - | - |
-| Fun-CosyVoice3-0.5B-2512 | 0.5B | 1.21 | 2.24 | 6.71 |
-| Fun-CosyVoice3-0.5B-2512_RL | 0.5B | 0.81 | 1.68 | 5.44 |
+| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
+| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
+| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
+| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
+| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
+| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
+| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
+| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
+| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
+| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
+| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
+| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
+| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
+| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
+| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
+| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
+| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |
 ## Install
@@ -111,7 +113,7 @@
 We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
 ``` python
-# SDK模型下载
+# modelscope SDK model download
 from modelscope import snapshot_download
 snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
 snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
@@ -119,6 +121,15 @@ snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-3
 snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
 snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
 snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
+# for oversea users, huggingface SDK model download
+from huggingface_hub import snapshot_download
+snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
+snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
+snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
+snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
+snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
+snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
 ```
 Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
@@ -140,15 +151,19 @@ Follow the code in `example.py` for detailed usage of each model.
 python example.py
 ```
-#### CosyVoice2 vllm Usage
+#### vLLM Usage
-If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vllm version do not support CosyVoice2 inference.
-Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torch==2.7.0`. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
+CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
+Older vllm version(<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.
+Notice that `vllm` has a lot of specific requirements. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
 ``` sh
 conda create -n cosyvoice_vllm --clone cosyvoice
 conda activate cosyvoice_vllm
-pip install vllm==v0.9.0 transformers==4.51.3 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+# for vllm==0.9.0
+pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+# for vllm>=0.11.0
+pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
 python vllm_example.py
 ```
@@ -165,7 +180,7 @@ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
 #### Advanced Usage
-For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
+For advanced users, we have provided training and inference scripts in `examples/libritts`.
 #### Build for deployment
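
The updated README keeps pointing readers at `example.py` for per-model usage. For orientation, here is a minimal zero-shot sketch assembled only from pieces visible elsewhere in this changeset (the `AutoModel` import, the `inference_zero_shot` signature, and the `'tts_speech'` output key); the prompt file, the 16 kHz `load_wav` call, and the `sample_rate` attribute are assumptions, so treat `example.py` as the authoritative reference.

``` python
# Hedged sketch, not the project's official example. Assumptions: load_wav(path, 16000)
# returns the prompt tensor expected by inference_zero_shot, each output carries a
# 'tts_speech' tensor, and the model exposes a sample_rate attribute.
import torchaudio
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import load_wav

model = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
prompt_wav = load_wav('./asset/zero_shot_prompt.wav', 16000)  # placeholder prompt audio
# CosyVoice3 expects <|endofprompt|> somewhere in the text (see the warning added in cosyvoice.py below)
tts_text = 'You are a helpful assistant.<|endofprompt|>Hello, nice to meet you.'
for i, out in enumerate(model.inference_zero_shot(tts_text, 'transcript of the prompt audio', prompt_wav, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), out['tts_speech'], model.sample_rate)
```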

View File

@@ -24,9 +24,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import AutoModel
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 from cosyvoice.utils.file_utils import logging
-from cosyvoice.utils.class_utils import get_model_type


 def get_args():
@@ -61,15 +59,7 @@ def main():
     model = AutoModel(model_dir=args.model_dir)
-    if get_model_type(model.model) == CosyVoiceModel:
-        # 1. export flow encoder
-        flow_encoder = model.model.flow.encoder
-        script = get_optimized_script(flow_encoder)
-        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
-        script = get_optimized_script(flow_encoder.half())
-        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
-        logging.info('successfully export flow_encoder')
-    elif get_model_type(model.model) == CosyVoice2Model:
+    if model.__class__.__name__ == 'CosyVoice':
         # 1. export llm text_encoder
         llm_text_encoder = model.model.llm.text_encoder
         script = get_optimized_script(llm_text_encoder)
@@ -93,6 +83,14 @@ def main():
         script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
+    elif model.__class__.__name__ == 'CosyVoice2':
+        # 1. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
     else:
         raise ValueError('unsupported model type')

View File

@@ -49,6 +49,7 @@ def get_args():
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
     parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+    parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--model_dir', required=True, help='save model dir')
     parser.add_argument('--tensorboard_dir',
@@ -96,6 +97,7 @@ def get_args():
 @record
 def main():
     args = get_args()
+    os.environ['onnx_path'] = args.onnx_path
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
     # gan train has some special initialization logic
@@ -104,12 +106,10 @@ def main():
     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
     if gan is True:
         override_dict.pop('hift')
-    try:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
-    except Exception:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides=override_dict)
+    if args.qwen_pretrain_path is not None:
+        override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
+    with open(args.config, 'r') as f:
+        configs = load_hyperpyyaml(f, overrides=override_dict)
     if gan is True:
         configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))

View File

@@ -89,6 +89,8 @@ class CosyVoice:
             start_time = time.time()

     def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+        if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
+            logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
@@ -114,7 +116,7 @@
             start_time = time.time()

     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)

View File

@@ -20,18 +20,9 @@ import numpy as np
import whisper import whisper
from typing import Callable from typing import Callable
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os import os
import re import re
import inflect import inflect
try:
import ttsfrd
use_ttsfrd = True
except ImportError:
print("failed to import ttsfrd, use wetext instead")
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
use_ttsfrd = False
from cosyvoice.utils.file_utils import logging, load_wav from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -56,21 +47,33 @@ class CosyVoiceFrontEnd:
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"]) "CPUExecutionProvider"])
if os.path.exists(spk2info): if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device) self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else: else:
self.spk2info = {} self.spk2info = {}
self.allowed_special = allowed_special self.allowed_special = allowed_special
self.use_ttsfrd = use_ttsfrd self.inflect_parser = inflect.engine()
if self.use_ttsfrd: # NOTE compatible when no text frontend tool is avaliable
try:
import ttsfrd
self.frd = ttsfrd.TtsFrontendEngine() self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
'failed to initialize ttsfrd resource' 'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg') self.frd.set_lang_type('pinyinvg')
else: self.text_frontend = 'ttsfrd'
self.zh_tn_model = ZhNormalizer(remove_erhua=False) logging.info('use ttsfrd frontend')
self.en_tn_model = EnNormalizer() except:
self.inflect_parser = inflect.engine() try:
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
self.en_tn_model = EnNormalizer()
self.text_frontend = 'wetext'
logging.info('use wetext frontend')
except:
self.text_frontend = ''
logging.info('no frontend is avaliable')
def _extract_text_token(self, text): def _extract_text_token(self, text):
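
The hunk above replaces the old module-level `ttsfrd` import with a runtime fallback chain: try `ttsfrd`, fall back to `wetext`, and finally run with no text frontend at all. A standalone sketch of that selection pattern (package names taken from the diff; the broad `except` mirrors the patch):

``` python
# Minimal sketch of the frontend fallback added above (not the class itself).
def select_text_frontend():
    try:
        import ttsfrd  # noqa: F401 -- preferred normalizer, needs the ttsfrd resource
        return 'ttsfrd'
    except Exception:
        try:
            from wetext import Normalizer  # noqa: F401 -- open-source zh/en normalizer
            return 'wetext'
        except Exception:
            return ''  # no frontend available; raw text is passed through

print(select_text_frontend())
```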
if isinstance(text, Generator): if isinstance(text, Generator):
@@ -131,12 +134,13 @@ class CosyVoiceFrontEnd:
if text_frontend is False or text == '': if text_frontend is False or text == '':
return [text] if split is True else text return [text] if split is True else text
text = text.strip() text = text.strip()
if self.use_ttsfrd: if self.text_frontend == 'ttsfrd':
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]] texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
text = ''.join(texts) text = ''.join(texts)
else: else:
if contains_chinese(text): if contains_chinese(text):
text = self.zh_tn_model.normalize(text) if self.text_frontend == 'wetext':
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "") text = text.replace("\n", "")
text = replace_blank(text) text = replace_blank(text)
text = replace_corner_mark(text) text = replace_corner_mark(text)
@@ -147,7 +151,8 @@ class CosyVoiceFrontEnd:
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False)) token_min_n=60, merge_len=20, comma_split=False))
else: else:
text = self.en_tn_model.normalize(text) if self.text_frontend == 'wetext':
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser) text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False)) token_min_n=60, merge_len=20, comma_split=False))
@@ -178,7 +183,7 @@ class CosyVoiceFrontEnd:
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len, 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding} 'llm_embedding': embedding, 'flow_embedding': embedding}
else: else:
model_input = self.spk2info[zero_shot_spk_id] model_input = {**self.spk2info[zero_shot_spk_id]}
model_input['text'] = tts_text_token model_input['text'] = tts_text_token
model_input['text_len'] = tts_text_token_len model_input['text_len'] = tts_text_token_len
return model_input return model_input

View File

@@ -60,14 +60,15 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []

     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
         self.llm.to(self.device).eval()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
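
The `weights_only=True` arguments above (and the matching change in `frontend.py`) line up with the V-005 pickle-deserialization fix in the commit log: by default `torch.load` unpickles arbitrary Python objects, so a tampered checkpoint can execute code at load time, whereas `weights_only=True` restricts loading to tensors and plain containers. A self-contained sketch of the pattern (the module and path are placeholders):

``` python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)                     # stand-in module for illustration
torch.save(model.state_dict(), 'ckpt.pt')   # an ordinary checkpoint: tensors only

# weights_only=True refuses to unpickle arbitrary objects, so a crafted pickle
# payload inside the checkpoint cannot run code when the file is loaded.
state_dict = torch.load('ckpt.pt', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict, strict=True)
```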
@@ -98,26 +99,33 @@
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        cur_silent_token_num, max_silent_token_num = 0, 5
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
-                for i in self.llm.inference_bistream(text=text,
-                                                     prompt_text=prompt_text.to(self.device),
-                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    self.tts_speech_token_dict[uuid].append(i)
+                assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
+                token_generator = self.llm.inference_bistream(text=text,
+                                                              prompt_text=prompt_text.to(self.device),
+                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                              embedding=llm_embedding.to(self.device))
             else:
-                for i in self.llm.inference(text=text.to(self.device),
-                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_text=prompt_text.to(self.device),
-                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device),
-                                            uuid=uuid):
-                    self.tts_speech_token_dict[uuid].append(i)
+                token_generator = self.llm.inference(text=text.to(self.device),
+                                                     text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_text=prompt_text.to(self.device),
+                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                     embedding=llm_embedding.to(self.device),
+                                                     uuid=uuid)
+            for i in token_generator:
+                if i in self.silent_tokens:
+                    cur_silent_token_num += 1
+                    if cur_silent_token_num > max_silent_token_num:
+                        continue
+                else:
+                    cur_silent_token_num = 0
+                self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

     def vc_job(self, source_speech_token, uuid):
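
The rewritten `llm_job` above routes both streaming and non-streaming generation through a single `token_generator` loop and drops any run of more than five consecutive silent tokens before they reach `tts_speech_token_dict`. The filter is easy to isolate; a runnable sketch using the FSQ silent-token IDs listed further down in this diff:

``` python
def drop_long_silences(token_generator, silent_tokens, max_silent_token_num=5):
    """Yield tokens unchanged, but once a silent run exceeds max_silent_token_num,
    suppress the remainder of that run (mirrors the logic added above)."""
    cur_silent_token_num = 0
    for token in token_generator:
        if token in silent_tokens:
            cur_silent_token_num += 1
            if cur_silent_token_num > max_silent_token_num:
                continue  # skip excess silence/breath tokens
        else:
            cur_silent_token_num = 0
        yield token

silent_tokens = {1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323}
stream = [7, 1, 1, 1, 1, 1, 1, 1, 9]                          # seven silent tokens in a row
print(list(drop_long_silences(iter(stream), silent_tokens)))  # [7, 1, 1, 1, 1, 1, 9]
```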
@@ -248,6 +256,10 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference
+        self.token_max_hop_len = 4 * self.token_hop_len
+        self.stream_scale_factor = 2
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
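
The new `token_max_hop_len` and `stream_scale_factor` fields above work together with the `self.token_hop_len = min(...)` update a few hunks below: streaming starts with a 25-token hop for low first-chunk latency, then doubles the hop after each chunk up to a 4x cap, so later chunks re-run the flow decoder less often. A tiny sketch of the schedule (values taken from the diff):

``` python
token_hop_len = 25
token_max_hop_len = 4 * token_hop_len   # 100
stream_scale_factor = 2

hops = []
for _ in range(5):                      # five streaming chunks
    hops.append(token_hop_len)
    token_hop_len = min(token_max_hop_len, token_hop_len * stream_scale_factor)
print(hops)                             # [25, 50, 100, 100, 100]
```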
@@ -260,6 +272,7 @@
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []

     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -344,6 +357,7 @@
                                                  stream=stream,
                                                  finalize=False)
                    token_offset += this_token_hop_len
+                   self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                        break
@@ -394,6 +408,10 @@ class CosyVoice3Model(CosyVoice2Model):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference
+        self.token_max_hop_len = 4 * self.token_hop_len
+        self.stream_scale_factor = 2
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
@@ -401,6 +419,8 @@
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        # FSQ silent and breath token
+        self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]

     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):

View File

@@ -145,7 +145,11 @@ def Dataset(data_list_file,
                               shuffle=shuffle,
                               partition=partition)
     # map partial arg to padding func
-    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
+    for i in range(1, len(data_pipeline)):
+        if data_pipeline[i].func.__name__ == 'compute_fbank' and gan is True:
+            data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
+        if data_pipeline[i].func.__name__ == 'padding':
+            data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset

View File

@@ -16,17 +16,19 @@ import random
import pyarrow.parquet as pq import pyarrow.parquet as pq
from io import BytesIO from io import BytesIO
import numpy as np
import whisper
import torch import torch
import torchaudio import torchaudio
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F import torch.nn.functional as F
import pyworld as pw import pyworld as pw
from cosyvoice.utils.onnx import embedding_extractor, online_feature
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def parquet_opener(data, mode='train', tts_data={}): def parquet_opener(data, mode='train'):
""" Give url or local file, return file descriptor """ Give url or local file, return file descriptor
Inplace operation. Inplace operation.
@@ -44,12 +46,8 @@ def parquet_opener(data, mode='train', tts_data={}):
df = df.to_pandas() df = df.to_pandas()
for i in range(len(df)): for i in range(len(df)):
sample.update(dict(df.loc[i])) sample.update(dict(df.loc[i]))
if mode == 'train': # NOTE do not return sample directly, must initialize a new dict
# NOTE do not return sample directly, must initialize a new dict yield {**sample}
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex: except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex)) logging.warning('Failed to open {}, ex info {}'.format(url, ex))
@@ -96,9 +94,9 @@ def filter(data,
continue continue
if len(sample['text_token']) > token_max_length: if len(sample['text_token']) > token_max_length:
continue continue
if len(sample['speech_token']) == 0: if online_feature is False and len(sample['speech_token']) == 0:
continue continue
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0: if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
continue continue
if num_frames != 0: if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio: if len(sample['text_token']) / num_frames < min_output_input_ratio:
@@ -159,7 +157,7 @@ def truncate(data, truncate_length=24576, mode='train'):
def compute_fbank(data, def compute_fbank(data,
feat_extractor, feat_extractor,
token_mel_ratio=0, num_frames=-1,
mode='train'): mode='train'):
""" Extract fbank """ Extract fbank
@@ -174,14 +172,28 @@ def compute_fbank(data,
assert 'speech' in sample assert 'speech' in sample
assert 'utt' in sample assert 'utt' in sample
assert 'text_token' in sample assert 'text_token' in sample
waveform = sample['speech'] # NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) if num_frames != -1:
if token_mel_ratio != 0: index = int(np.ceil(sample['speech'].shape[1] / num_frames))
# trim to align speech_token and speech_feat sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0])) sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
feat = feat[:token_mel_ratio * token_len] yield sample
sample["speech_token"] = sample["speech_token"][:token_len]
sample['speech_feat'] = feat
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
""" Extract whisper fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
if num_frames != -1:
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
yield sample yield sample
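
For online token extraction, `compute_fbank` above now zero-pads the waveform up to a whole number of `num_frames`-sample blocks so the 25 Hz speech tokens and the mel features stay aligned, and the new `compute_whisper_fbank` asserts that alignment. A sketch of just the padding step (the `num_frames` value here is illustrative, not taken from a config):

``` python
import math
import torch

def pad_to_block(speech: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Zero-pad a [1, T] waveform so T becomes a multiple of num_frames."""
    n_blocks = math.ceil(speech.shape[1] / num_frames)
    pad = n_blocks * num_frames - speech.shape[1]
    return torch.cat([speech, torch.zeros(1, pad)], dim=1)

wav = torch.randn(1, 24500)
print(pad_to_block(wav, num_frames=960).shape)   # torch.Size([1, 24960])
```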
@@ -220,8 +232,13 @@ def parse_embedding(data, normalize, mode='train'):
Iterable[{key, feat, label}] Iterable[{key, feat, label}]
""" """
for sample in data: for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
embedding = embedding_extractor.inference(sample['speech_16k'])
sample['spk_embedding'] = sample['utt_embedding'] = embedding
else:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize: if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
@@ -244,8 +261,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if 'instruct' in sample: if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special) sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
else:
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
yield sample yield sample
@@ -260,13 +275,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
Iterable[{key, feat, label}] Iterable[{key, feat, label}]
""" """
buf = [] buf = []
yield_size = int(shuffle_size / 2)
for sample in data: for sample in data:
buf.append(sample) buf.append(sample)
if len(buf) >= shuffle_size: if len(buf) >= shuffle_size:
random.shuffle(buf) random.shuffle(buf)
for x in buf: for x in buf[:yield_size]:
yield x yield x
buf = [] buf = buf[yield_size:]
# The sample left over # The sample left over
random.shuffle(buf) random.shuffle(buf)
for x in buf: for x in buf:
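
The shuffle change above stops flushing the whole buffer at once: when the buffer fills, it is shuffled and only the first half is yielded, with the second half carried over, so samples from different parquet shards keep mixing. A standalone sketch of the same pattern:

``` python
import random

def streaming_shuffle(data, shuffle_size=8):
    buf, yield_size = [], shuffle_size // 2
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf[:yield_size]   # emit half, keep half for further mixing
            buf = buf[yield_size:]
    random.shuffle(buf)                   # flush whatever is left at the end
    yield from buf

print(list(streaming_shuffle(range(20))))
```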
@@ -372,70 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
""" """
for sample in data: for sample in data:
assert isinstance(sample, list) assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
dtype=torch.int32) batch = {}
order = torch.argsort(speech_feat_len, descending=True) batch['utts'] = [sample[i]['utt'] for i in order]
batch['text'] = [sample[i]['text'] for i in order]
utts = [sample[i]['utt'] for i in order]
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
speech = pad_sequence(speech, batch_first=True, padding_value=0)
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order] text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0) batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order] speech_feat = [sample[i]['speech_feat'] for i in order]
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32) batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0) batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = { if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
"utts": utts, instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
"speech": speech, batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
"speech_len": speech_len, batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
"speech_token": speech_token, if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
"speech_token_len": speech_token_len, whisper_feat = [sample[i]['whisper_feat'] for i in order]
"speech_feat": speech_feat, batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
"speech_feat_len": speech_feat_len, batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
"text": text, if torch.tensor(['speech_token' in sample[i] for i in order]).all():
"text_token": text_token, speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
"text_token_len": text_token_len, batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
"instruct_token": instruct_token, batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
"instruct_token_len": instruct_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
if gan is True: if gan is True:
# in gan train, we need pitch_feat # in gan train, we need speech/pitch_feat
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
pitch_feat = [sample[i]['pitch_feat'] for i in order] pitch_feat = [sample[i]['pitch_feat'] for i in order]
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32) batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
pitch_feat = pad_sequence(pitch_feat, batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
batch_first=True,
padding_value=0)
batch["pitch_feat"] = pitch_feat
batch["pitch_feat_len"] = pitch_feat_len
else:
# only gan train needs speech, delete it to save memory
del batch["speech"]
del batch["speech_len"]
if dpo is True: if dpo is True:
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order] reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32) batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
reject_speech_token = pad_sequence(reject_speech_token, batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
batch_first=True,
padding_value=0)
batch['reject_speech_token'] = reject_speech_token
batch['reject_speech_token_len'] = reject_speech_token_len
if use_spk_embedding is True: if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"] batch["embedding"] = batch["spk_embedding"]
else: else:

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import os, logging
import random import random
from typing import Dict, Optional from typing import Dict, Optional
import torch import torch
@@ -19,6 +19,7 @@ import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from omegaconf import DictConfig from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
class MaskedDiffWithXvec(torch.nn.Module): class MaskedDiffWithXvec(torch.nn.Module):
@@ -179,14 +180,19 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.only_mask_loss = only_mask_loss self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len self.pre_lookahead_len = pre_lookahead_len
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
def forward( def forward(
self, self,
batch: dict, batch: dict,
device: torch.device, device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]: ) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device) if 'speech_token' not in batch:
token_len = batch['speech_token_len'].to(device) token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device) feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device) feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device) embedding = batch['embedding'].to(device)
@@ -308,14 +314,19 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
self.decoder = decoder self.decoder = decoder
self.only_mask_loss = only_mask_loss self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio self.token_mel_ratio = token_mel_ratio
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
def forward( def forward(
self, self,
batch: dict, batch: dict,
device: torch.device, device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]: ) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device) if 'speech_token' not in batch:
token_len = batch['speech_token_len'].to(device) token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device) feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device) feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device) embedding = batch['embedding'].to(device)
@@ -332,8 +343,9 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
token = self.input_embedding(torch.clamp(token, min=0)) * mask token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode # text encode
h, h_lengths = self.encoder(token, token_len, streaming=streaming) h = self.pre_lookahead_layer(token)
h = self.encoder_proj(h) h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
# get conditions # get conditions
conds = torch.zeros(feat.shape, device=token.device) conds = torch.zeros(feat.shape, device=token.device)
@@ -344,7 +356,6 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
conds[i, :index] = feat[i, :index] conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2) conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss( loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(), feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1), mask.unsqueeze(1),

View File

@@ -174,8 +174,7 @@ class ConditionalCFM(BASECFM):
# random timestep # random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0) # sample noise p(x_0)
z = torch.randn_like(x1) z = torch.randn_like(x1)

View File

@@ -713,8 +713,8 @@ class CausalHiFTGenerator(HiFTGenerator):
     @torch.inference_mode()
     def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
         # mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
-        self.f0_predictor.to('cpu')
-        f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
+        self.f0_predictor.to(torch.float64)
+        f0 = self.f0_predictor(speech_feat.to(torch.float64), finalize=finalize).to(speech_feat)
         # f0->source
         s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
         s, _, _ = self.m_source(s)
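
This hunk is the change behind the "[BUG FIX] Use float64 ..." commits at the top of the log: rather than moving the causal f0 predictor to the CPU for numerical stability (which dragged down performance), it now stays on its original device, runs in float64, and the result is cast back to the feature dtype. A minimal sketch of the cast-up/cast-back pattern with a stand-in module:

``` python
import torch
import torch.nn as nn

f0_predictor = nn.Conv1d(80, 1, kernel_size=3, padding=1)   # stand-in for the real predictor
speech_feat = torch.randn(1, 80, 120)                        # [batch, mels, frames]

# Run the precision-sensitive module in float64, then cast the output back to
# the dtype/device of the input features (mirrors the two-line fix above).
f0_predictor.to(torch.float64)
f0 = f0_predictor(speech_feat.to(torch.float64)).to(speech_feat)
print(f0.dtype, f0.shape)   # torch.float32 torch.Size([1, 1, 120])
```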

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import queue import os, queue
import random import random
import time import time
import threading import threading
@@ -28,6 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
class TransformerLM(torch.nn.Module): class TransformerLM(torch.nn.Module):
@@ -300,19 +301,29 @@ class Qwen2LM(TransformerLM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

-    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
         speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        # NOTE add instruct_token in CosyVoice3
+        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
+            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
+        else:
+            instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
+            instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
+            instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
-                this_lm_target, this_lm_input = [], []
-                this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(sos_emb.squeeze(dim=0))
+                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
+                this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -333,8 +344,8 @@ class Qwen2LM(TransformerLM):
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
@@ -356,8 +367,11 @@ class Qwen2LM(TransformerLM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)

         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -658,6 +672,8 @@ class CosyVoice3LM(Qwen2LM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(200)]
         self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

     def forward(
             self,
@@ -673,14 +689,19 @@ class CosyVoice3LM(Qwen2LM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
         # NOTE should append instruct_token to sequence, not implemented yet
         instruct_token = batch['instruct_token'].to(device)
         instruct_token_len = batch['instruct_token_len'].to(device)

         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)

         # 3. sos and task_id
         sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
@@ -691,14 +712,14 @@ class CosyVoice3LM(Qwen2LM):
         # 3. prepare llm_input/target
         lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
+                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
         lm_target = lm_target.to(device)

         # 4. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
         logits = self.llm_decoder(lm_output)
         loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}

     @torch.inference_mode()
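
Taken together, the llm.py changes above place the optional instruct tokens between sos and the text tokens, so a CosyVoice3 unistream training sample is laid out as sos + instruct + text + task_id + speech + eos, with IGNORE_ID masking the loss over everything before the speech tokens. A minimal, self-contained sketch of that layout with toy tensors (all sizes and token ids below are placeholders, not the repo's real values):

import torch

IGNORE_ID = -1
emb_dim = 4                                   # toy embedding size
sos_emb = torch.zeros(1, emb_dim)             # placeholder sos embedding
task_id_emb = torch.ones(1, emb_dim)          # placeholder task_id embedding
instruct_emb = torch.randn(2, emb_dim)        # e.g. an embedded "You are a helpful assistant.<|endofprompt|>"
text_emb = torch.randn(3, emb_dim)
speech_emb = torch.randn(5, emb_dim)
speech_token = [11, 12, 13, 14, 15]
eos_token = 0                                 # placeholder id

lm_input = torch.concat([sos_emb, instruct_emb, text_emb, task_id_emb, speech_emb], dim=0)
lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_emb.size(0) + text_emb.size(0)) + speech_token + [eos_token])
assert lm_input.size(0) == lm_target.size(0)  # input and target stay the same length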

View File

@@ -139,6 +139,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
     top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
     rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
     if rep_num >= win_size * tau_r:
+        weighted_scores[top_ids] = -float('inf')
         top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
     return top_ids
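
The added line is the whole fix: once a candidate has repeated at least win_size * tau_r times in the recent window, its score is forced to -inf before the fallback draw, so random_sampling cannot return the very token that triggered the repetition check. A standalone illustration of the masking step (toy scores, not the repo's sampling utilities):

import torch

weighted_scores = torch.tensor([2.0, 0.1, 0.5, 0.3])   # toy logits over 4 tokens
top_ids = torch.tensor(0)                               # candidate that repeated too often
weighted_scores[top_ids] = -float('inf')                # exclude it from the re-draw
probs = weighted_scores.softmax(dim=0)
assert probs[top_ids] == 0                              # token 0 can no longer be sampled
resampled = torch.multinomial(probs, 1)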

cosyvoice/utils/onnx.py Normal file
View File

@@ -0,0 +1,54 @@
+import onnxruntime
+import torch, random
+import os
+import torchaudio.compliance.kaldi as kaldi
+
+class SpeechTokenExtractor():
+    def __init__(self, model_path):
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
+                                                                     sess_options=option,
+                                                                     providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
+
+    def inference(self, feat, feat_lengths, device):
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.transpose(1, 2).detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          feat_lengths.detach().cpu().numpy()})[0]
+        return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
+
+
+class EmbeddingExtractor():
+    def __init__(self, model_path):
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.max_len = 10 * 16000
+        self.campplus_session = onnxruntime.InferenceSession(model_path,
+                                                             sess_options=option,
+                                                             providers=["CPUExecutionProvider"])
+
+    def inference(self, speech):
+        if speech.shape[1] > self.max_len:
+            start_index = random.randint(0, speech.shape[1] - self.max_len)
+            speech = speech[:, start_index: start_index + self.max_len]
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        return torch.tensor(embedding).to(speech.device)
+
+
+# singleton mode, only initialized once
+onnx_path = os.environ.get('onnx_path')
+if onnx_path is not None:
+    embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
+else:
+    embedding_extractor, online_feature = None, False
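
With this module, speaker embeddings and speech tokens can be produced on the fly during training instead of being pre-extracted: setting the onnx_path environment variable (presumably exported through the new --onnx_path training flag in the updated recipes) flips online_feature to True, and the LM classes then tokenize the whisper fbank features in each batch. A rough usage sketch, assuming a CUDA-capable onnxruntime and a model dir that contains campplus.onnx and speech_tokenizer_v2.batch.onnx; the feature shapes below are dummies:

import os
os.environ['onnx_path'] = 'pretrained_models/CosyVoice2-0.5B'   # must be set before the import below
import torch
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path

if online_feature:
    extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
    whisper_feat = torch.randn(2, 400, 128)                      # (batch, frames, mel bins); dummy shape for illustration
    whisper_feat_len = torch.tensor([400, 320], dtype=torch.int32)
    speech_token, speech_token_len = extractor.inference(whisper_feat, whisper_feat_len, device='cuda')
    # speech_token_len comes back as feat length / 4, matching the tokenizer's downsampling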

View File

@@ -53,7 +53,7 @@ def init_distributed(args):
 def init_dataset_and_dataloader(args, configs, gan, dpo):
     data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
     train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
     # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
     train_data_loader = DataLoader(train_dataset,
@@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
         raise ValueError("unknown scheduler: " + configs['train_conf'])

     if configs['train_conf']['optim_d'] == 'adam':
-        optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+        optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
     elif configs['train_conf']['optim_d'] == 'adamw':
-        optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+        optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
     else:
         raise ValueError("unknown optimizer: " + configs['train_conf'])

     if configs['train_conf']['scheduler_d'] == 'warmuplr':
         scheduler_type = WarmupLR
-        scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
+        scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
     elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
-        scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
+        scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
     elif configs['train_conf']['scheduler'] == 'constantlr':
         scheduler_type = ConstantLR
         scheduler_d = ConstantLR(optimizer_d)

View File

@@ -23,6 +23,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
+from typing import Optional
+
+from packaging.version import parse as vparse
+import vllm
+
+# vLLM-0.11.0+ only support V1 engine
+VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
+if VLLM_V1_ENGINE_ONLY:
+    from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.model_executor.models.qwen2 import *
@@ -87,10 +96,14 @@ class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
+        sampling_metadata: Optional[SamplingMetadata] = None,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata, self.lm_head.bias)
+        if VLLM_V1_ENGINE_ONLY:
+            logits = self.logits_processor(self.lm_head, hidden_states,
+                                           self.lm_head.bias)
+        else:
+            logits = self.logits_processor(self.lm_head, hidden_states,
+                                           sampling_metadata, self.lm_head.bias)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
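
The guard keeps a single compute_logits for both engine generations: on vLLM >= 0.11.0 only the V1 engine exists and sampling metadata is no longer passed in, so the argument becomes optional and is simply not forwarded. The version check in isolation (version strings here are only for illustration):

from packaging.version import parse as vparse

def is_v1_engine_only(installed_version: str) -> bool:
    # vLLM 0.11.0 and later ship only the V1 engine
    return vparse(installed_version) >= vparse("0.11.0")

assert is_v1_engine_only("0.11.0") is True
assert is_v1_engine_only("0.10.2") is False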

View File

@@ -4,7 +4,7 @@ ARG VENV_NAME="cosyvoice"
 ENV VENV=$VENV_NAME
 ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
-ENV DEBIAN_FRONTEN=noninteractive
+ENV DEBIAN_FRONTEND=noninteractive
 ENV PYTHONUNBUFFERED=1

 SHELL ["/bin/bash", "--login", "-c"]

View File

@@ -18,7 +18,7 @@ def cosyvoice_example():
     # zero_shot usage
     for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
         torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-    # cross_lingual usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
+    # cross_lingual usage, <|zh|><|en|><|ja|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
     for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
                                                             './asset/cross_lingual_prompt.wav')):
         torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
@@ -95,6 +95,12 @@ def cosyvoice3_example():
                                                             './asset/zero_shot_prompt.wav', stream=False)):
         torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+    # NOTE for Japanese usage, you must translate it to katakana.
+    # 歴史的世界においては、過去は単に過ぎ去ったものではない、プラトンのいう如く非有が有である。 -> レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン イウ ゴトク ヒ ユー ガ ユー デ アル。
+    for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン イウ ゴトク ヒ ユー ガ ユー デ アル。',
+                                                            './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('japanese_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)


 def main():
     # cosyvoice_example()

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -40,11 +40,10 @@ def main():
     with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
         for k, v in spk2utt.items():
             f.write('{} {}\n'.format(k, ' '.join(v)))
-    if args.instruct is True:
+    if args.instruct != '':
         with open('{}/instruct'.format(args.des_dir), 'w') as f:
             for k, v in utt2text.items():
-                # NOTE in CosyVoice3, we add instruct in sequence
-                f.write('{} You are a helpful assistant.<|endofprompt|>\n'.format(k, v))
+                f.write('{} {}\n'.format(k, args.instruct))
     return
@@ -55,8 +54,7 @@ if __name__ == "__main__":
     parser.add_argument('--des_dir',
                         type=str)
     parser.add_argument('--instruct',
-                        action='store_true',
-                        default=False,
-                        help='create instruct file or not')
+                        type=str,
+                        default='')
     args = parser.parse_args()
     main()

View File

@@ -27,7 +27,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_embedding.py --dir data/$x \
+    ../../../tools/extract_embedding.py --dir data/$x \
       --onnx_path $pretrained_model_dir/campplus.onnx
   done
 fi
@@ -35,7 +35,7 @@ fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
   echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_speech_token.py --dir data/$x \
+    ../../../tools/extract_speech_token.py --dir data/$x \
       --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
   done
 fi
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
     mkdir -p data/$x/parquet
-    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+    ../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
       --num_processes 10 \
       --src_dir data/$x \
       --des_dir data/$x/parquet
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   for model in llm flow hifigan; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
-      cosyvoice/bin/train.py \
+      ../../../cosyvoice/bin/train.py \
       --train_engine $train_engine \
       --config conf/cosyvoice.yaml \
       --train_data data/train.data.list \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -139,7 +139,7 @@ tokenize: !name:cosyvoice.dataset.processor.tokenize
     get_tokenizer: !ref <get_tokenizer>
     allowed_special: !ref <allowed_special>
 filter: !name:cosyvoice.dataset.processor.filter
-    max_length: 40960
+    max_length: 6000
     min_length: 100
     token_max_length: 200
     token_min_length: 1
@@ -158,7 +158,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
-    token_mel_ratio: 2
+    num_frames: 960
+compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -183,6 +185,7 @@ data_pipeline: [
     !ref <resample>,
     !ref <compute_fbank>,
     !ref <parse_embedding>,
+    !ref <compute_whisper_fbank>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -24,27 +24,12 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   done
 fi

-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
-  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_embedding.py --dir data/$x \
-      --onnx_path $pretrained_model_dir/campplus.onnx
-  done
-fi
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
-  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_speech_token.py --dir data/$x \
-      --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
-  done
-fi
-
+# NOTE embedding/token extraction is not necessary now as we support online feature extraction
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
     mkdir -p data/$x/parquet
-    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+    ../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
       --num_processes 10 \
       --src_dir data/$x \
       --des_dir data/$x/parquet
@@ -66,16 +51,16 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   fi
   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
-  # NOTE will update llm/hift training later
   for model in llm flow hifigan; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
-      cosyvoice/bin/train.py \
+      ../../../cosyvoice/bin/train.py \
       --train_engine $train_engine \
       --config conf/cosyvoice2.yaml \
       --train_data data/train.data.list \
       --cv_data data/dev.data.list \
       --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --onnx_path $pretrained_model_dir \
       --model $model \
       --checkpoint $pretrained_model_dir/$model.pt \
       --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \

View File

@@ -36,7 +36,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_embedding.py --dir data/$x \
+    ../../../tools/extract_embedding.py --dir data/$x \
       --onnx_path $pretrained_model_dir/campplus.onnx
   done
 fi
@@ -44,7 +44,7 @@ fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
   echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
   for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
-    tools/extract_speech_token.py --dir data/$x \
+    ../../../tools/extract_speech_token.py --dir data/$x \
       --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
   done
 fi
@@ -53,7 +53,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
     mkdir -p data/$x/parquet
-    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+    ../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
       --num_processes 10 \
       --dpo \
       --src_dir data/$x \
@@ -80,11 +80,12 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   for model in llm; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
-      cosyvoice/bin/train.py \
+      ../../../cosyvoice/bin/train.py \
       --train_engine $train_engine \
       --config conf/cosyvoice2.yaml \
       --train_data data/train.data.list \
       --cv_data data/dev.data.list \
+      --onnx_path $pretrained_model_dir \
       --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
       --model $model \
       --checkpoint $pretrained_model_dir/$model.pt \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -20,7 +20,7 @@ num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size,
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
 # for system/third_party class/function, we do not require this.
-llm: !new:cosyvoice.llm.llm.Qwen2LM
+llm: !new:cosyvoice.llm.llm.CosyVoice3LM
     llm_input_size: !ref <llm_input_size>
     llm_output_size: !ref <llm_output_size>
     speech_token_size: 6561
@@ -35,8 +35,8 @@ llm: !new:cosyvoice.llm.llm.Qwen2LM
         win_size: 10
         tau_r: 0.1

-flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
-    input_size: 512
+flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithDiT
+    input_size: 80
     output_size: 80
     spk_embed_dim: !ref <spk_embed_dim>
     output_type: 'mel'
@@ -45,22 +45,10 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
     only_mask_loss: True
     token_mel_ratio: !ref <token_mel_ratio>
     pre_lookahead_len: 3
-    encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
-        output_size: 512
-        attention_heads: 8
-        linear_units: 2048
-        num_blocks: 6
-        dropout_rate: 0.1
-        positional_dropout_rate: 0.1
-        attention_dropout_rate: 0.1
-        normalize_before: True
-        input_layer: 'linear'
-        pos_enc_layer_type: 'rel_pos_espnet'
-        selfattention_layer_type: 'rel_selfattn'
-        input_size: 512
-        use_cnn_module: False
-        macaron_style: False
-        static_chunk_size: !ref <chunk_size>
+    pre_lookahead_layer: !new:cosyvoice.transformer.upsample_encoder.PreLookaheadLayer
+        in_channels: 80
+        channels: 1024
+        pre_lookahead_len: 3
     decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
         in_channels: 240
         n_spks: 1
@@ -73,20 +61,20 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
             training_cfg_rate: 0.2
             inference_cfg_rate: 0.7
             reg_loss_type: 'l1'
-        estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
-            in_channels: 320
+        estimator: !new:cosyvoice.flow.DiT.dit.DiT
+            dim: 1024
+            depth: 22
+            heads: 16
+            dim_head: 64
+            ff_mult: 2
+            mel_dim: 80
+            mu_dim: 80
+            spk_dim: 80
             out_channels: 80
-            channels: [256]
-            dropout: 0.0
-            attention_head_dim: 64
-            n_blocks: 4
-            num_mid_blocks: 12
-            num_heads: 8
-            act_fn: 'gelu'
             static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
             num_decoding_left_chunks: !ref <num_decoding_left_chunks>

-hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+hift: !new:cosyvoice.hifigan.generator.CausalHiFTGenerator
     in_channels: 80
     base_channels: 512
     nb_harmonics: 8
@@ -105,7 +93,8 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
     source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
     lrelu_slope: 0.1
     audio_limit: 0.99
+    conv_pre_look_right: 4
-    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.CausalConvRNNF0Predictor
         num_class: 1
         in_channels: 80
         cond_channels: 512
@@ -134,19 +123,20 @@ parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
 get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
     token_path: !ref <qwen_pretrain_path>
     skip_special_tokens: True
+    version: cosyvoice3
 allowed_special: 'all'
 tokenize: !name:cosyvoice.dataset.processor.tokenize
     get_tokenizer: !ref <get_tokenizer>
     allowed_special: !ref <allowed_special>
 filter: !name:cosyvoice.dataset.processor.filter
-    max_length: 40960
+    max_length: 6000
     min_length: 100
     token_max_length: 200
     token_min_length: 1
 resample: !name:cosyvoice.dataset.processor.resample
     resample_rate: !ref <sample_rate>
 truncate: !name:cosyvoice.dataset.processor.truncate
-    truncate_length: 24480 # must be a multiplier of hop_size
+    truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
 feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     n_fft: 1920
     num_mels: 80
@@ -154,11 +144,13 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     hop_size: 480
     win_size: 1920
     fmin: 0
-    fmax: 8000
+    fmax: null
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
-    token_mel_ratio: 2
+    num_frames: 960
+compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -183,6 +175,7 @@ data_pipeline: [
     !ref <resample>,
     !ref <compute_fbank>,
     !ref <parse_embedding>,
+    !ref <compute_whisper_fbank>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,
@@ -231,4 +224,4 @@ train_conf_gan:
     grad_clip: 5
     accum_grad: 1 # in gan training, accum_grad must be 1
     log_interval: 100
     save_per_step: -1
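
The comment on truncate_length explains the new value: with hop_size 480 and token_mel_ratio 2, a truncated waveform has to cover a whole number of mel frames and a whole number of speech tokens, i.e. be a multiple of 480 * 2 = 960 samples. 24960 is 26 * 960, while the old 24480 is 25.5 * 960. A quick check:

hop_size, token_mel_ratio = 480, 2
for truncate_length in (24480, 24960):
    ok = truncate_length % (hop_size * token_mel_ratio) == 0
    print(truncate_length, 'is a multiple of', hop_size * token_mel_ratio, '->', ok)
# prints False for 24480 and True for 24960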

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -7,7 +7,7 @@ stop_stage=3
 data_url=www.openslr.org/resources/60
 data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
-pretrained_model_dir=../../../pretrained_models/CosyVoice3-0.5B
+pretrained_model_dir=../../../pretrained_models/Fun-CosyVoice3-0.5B

 if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
   echo "Data Download"
@@ -20,40 +20,25 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
     mkdir -p data/$x
-    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct
-  done
-fi
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
-  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_embedding.py --dir data/$x \
-      --onnx_path $pretrained_model_dir/campplus.onnx
-  done
-fi
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
-  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-    tools/extract_speech_token.py --dir data/$x \
-      --onnx_path $pretrained_model_dir/speech_tokenizer_v3.onnx
+    # NOTE in CosyVoice3, we add instruct in sequence
+    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct "You are a helpful assistant.<|endofprompt|>"
   done
 fi

+# NOTE embedding/token extraction is not necessary now as we support online feature extraction
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
     mkdir -p data/$x/parquet
-    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+    ../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
       --num_processes 10 \
-      --instruct \
       --src_dir data/$x \
       --des_dir data/$x/parquet
   done
 fi

 # train llm
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0"
 num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 job_id=1986
 dist_backend="nccl"
@@ -67,16 +52,16 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   fi
   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
-  # NOTE will update llm/hift training later
   for model in llm flow hifigan; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
-      cosyvoice/bin/train.py \
+      ../../../cosyvoice/bin/train.py \
       --train_engine $train_engine \
       --config conf/cosyvoice3.yaml \
       --train_data data/train.data.list \
       --cv_data data/dev.data.list \
       --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --onnx_path $pretrained_model_dir \
       --model $model \
       --checkpoint $pretrained_model_dir/$model.pt \
       --model_dir `pwd`/exp/cosyvoice3/$model/$train_engine \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -1 +0,0 @@
../../../cosyvoice

View File

@@ -27,7 +27,7 @@ fi
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
   for x in dev test train; do
-    tools/extract_embedding.py --dir data/$x \
+    ../../../tools/extract_embedding.py --dir data/$x \
       --onnx_path $pretrained_model_dir/campplus.onnx
   done
 fi
@@ -35,7 +35,7 @@ fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
   echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
   for x in dev test train; do
-    tools/extract_speech_token.py --dir data/$x \
+    ../../../tools/extract_speech_token.py --dir data/$x \
       --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
   done
 fi
@@ -44,7 +44,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
   for x in dev test train; do
     mkdir -p data/$x/parquet
-    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+    ../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
       --num_processes 10 \
       --src_dir data/$x \
       --des_dir data/$x/parquet
@@ -69,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   for model in llm flow hifigan; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
-      cosyvoice/bin/train.py \
+      ../../../cosyvoice/bin/train.py \
       --train_engine $train_engine \
       --config conf/cosyvoice.yaml \
       --train_data data/train.data.list \

View File

@@ -1 +0,0 @@
../../../tools

View File

@@ -17,6 +17,7 @@ lightning==2.2.4
 matplotlib==3.7.5
 modelscope==1.20.0
 networkx==3.1
+numpy==1.26.4
 omegaconf==2.3.0
 onnx==1.16.0
 onnxruntime-gpu==1.18.0; sys_platform == 'linux'
@@ -35,6 +36,7 @@ tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
 torch==2.3.1
 torchaudio==2.3.1
 transformers==4.51.3
+x-transformers==2.11.24
 uvicorn==0.30.0
 wetext==0.0.4
 wget==3.2

View File

@@ -24,7 +24,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import AutoModel
 from cosyvoice.utils.file_utils import load_wav

 app = FastAPI()
@@ -88,14 +88,8 @@ if __name__ == '__main__':
                         default=50000)
     parser.add_argument('--model_dir',
                         type=str,
-                        default='iic/CosyVoice-300M',
+                        default='iic/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    try:
-        cosyvoice = CosyVoice(args.model_dir)
-    except Exception:
-        try:
-            cosyvoice = CosyVoice2(args.model_dir)
-        except Exception:
-            raise TypeError('no valid model_type!')
+    cosyvoice = AutoModel(model_dir=args.model_dir)
     uvicorn.run(app, host="0.0.0.0", port=args.port)
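
AutoModel replaces the old try/except ladder over CosyVoice and CosyVoice2, so the servers no longer have to probe each class until one loads; the same one-liner works for any supported model directory. A minimal usage sketch (the model_dir value is only an example):

import sys
sys.path.append('third_party/Matcha-TTS')                 # same path setup the servers rely on
from cosyvoice.cli.cosyvoice import AutoModel

cosyvoice = AutoModel(model_dir='iic/CosyVoice2-0.5B')    # local path or modelscope repo id
print(cosyvoice.sample_rate)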

View File

@@ -25,7 +25,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import AutoModel

 logging.basicConfig(level=logging.DEBUG,
                     format='%(asctime)s %(levelname)s %(message)s')
@@ -33,13 +33,7 @@ logging.basicConfig(level=logging.DEBUG,
 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
-        try:
-            self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc)
-        except Exception:
-            try:
-                self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc)
-            except Exception:
-                raise TypeError('no valid model_type!')
+        self.cosyvoice = AutoModel(model_dir=args.model_dir)
         logging.info('grpc service initialized')

     def Inference(self, request, context):
@@ -90,7 +84,7 @@ if __name__ == '__main__':
                         default=4)
     parser.add_argument('--model_dir',
                         type=str,
-                        default='iic/CosyVoice-300M',
+                        default='iic/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
     main()

View File

@@ -28,6 +28,7 @@ import json
 import os
 import threading
 import time
+from uuid import uuid4

 import numpy as np
 import torch
@@ -364,6 +365,7 @@ class TritonPythonModel:
         # Generate semantic tokens with LLM
         generated_ids_iter = self.forward_llm(input_ids)
+        token2wav_request_id = request_id or str(uuid4())

         if self.decoupled:
             response_sender = request.get_response_sender()
@@ -392,7 +394,7 @@ class TritonPythonModel:
                     this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                     sub_tts_speech = self.forward_token2wav(
-                        this_tts_speech_token, request_id, prompt_speech_tokens,
+                        this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
                         prompt_speech_feat, prompt_spk_embedding, token_offset, False
                     )
@@ -427,7 +429,7 @@ class TritonPythonModel:
                 time.sleep(0.02)
                 this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
+                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
@@ -441,7 +443,7 @@ class TritonPythonModel:
             if generated_ids is None or len(generated_ids) == 0:
                 raise pb_utils.TritonModelException("Generated IDs is None or empty")
-            audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
+            audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)

             # Prepare response
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
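
Falling back to uuid4 gives every inference its own token2wav request id even when the incoming Triton request carries an empty request_id, so concurrent requests no longer share a cache key inside forward_token2wav. The pattern in isolation:

from uuid import uuid4

def make_token2wav_request_id(request_id: str) -> str:
    # an empty id would collide across concurrent requests; fall back to a fresh UUID
    return request_id or str(uuid4())

assert make_token2wav_request_id('abc') == 'abc'
assert make_token2wav_request_id('') != make_token2wav_request_id('')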

View File

@@ -29,31 +29,24 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     for utt in tqdm(utt_list):
         data = open(utt2wav[utt], 'rb').read()
         data_list.append(data)
-    wav_list = [utt2wav[utt] for utt in utt_list]
-    text_list = [utt2text[utt] for utt in utt_list]
-    spk_list = [utt2spk[utt] for utt in utt_list]
-    uttembedding_list = [utt2embedding[utt] for utt in utt_list]
-    spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
-    speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
-    if args.dpo:
-        reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
-    if args.instruct:
-        instruct_list = [utt2instruct[utt] for utt in utt_list]

     # 保存到parquet,utt2parquet_file,spk2parquet_file
     df = pd.DataFrame()
     df['utt'] = utt_list
-    df['wav'] = wav_list
     df['audio_data'] = data_list
-    df['text'] = text_list
-    df['spk'] = spk_list
-    df['utt_embedding'] = uttembedding_list
-    df['spk_embedding'] = spkembedding_list
-    df['speech_token'] = speech_token_list
+    df['wav'] = [utt2wav[utt] for utt in utt_list]
+    df['text'] = [utt2text[utt] for utt in utt_list]
+    df['spk'] = [utt2spk[utt] for utt in utt_list]
+    if utt2embedding is not None:
+        df['utt_embedding'] = [utt2embedding[utt] for utt in utt_list]
+    if spk2embedding is not None:
+        df['spk_embedding'] = [spk2embedding[utt2spk[utt]] for utt in utt_list]
+    if utt2speech_token is not None:
+        df['speech_token'] = [utt2speech_token[utt] for utt in utt_list]
+    if utt2instruct is not None:
+        df['instruct'] = [utt2instruct[utt] for utt in utt_list]
     if args.dpo:
-        df['reject_speech_token'] = reject_speech_token_list
-    if args.instruct:
-        df['instruct'] = instruct_list
+        df['reject_speech_token'] = [utt2reject_speech_token.get(utt, None) for utt in utt_list]
     df.to_parquet(parquet_file)
     with open(utt2parquet_file, 'w') as f:
         json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
@@ -72,10 +65,6 @@ if __name__ == "__main__":
                         type=int,
                         default=1,
                         help='num processes for make parquets')
-    parser.add_argument('--instruct',
-                        action='store_true',
-                        default=False,
-                        help='has instruct file or not')
     parser.add_argument('--src_dir',
                         type=str)
     parser.add_argument('--des_dir',
@@ -99,16 +88,19 @@ if __name__ == "__main__":
         for l in f:
             l = l.replace('\n', '').split()
             utt2spk[l[0]] = l[1]
-    if args.instruct is True:
+    if os.path.exists('{}/instruct'.format(args.src_dir)):
+        utt2instruct = {}
         with open('{}/instruct'.format(args.src_dir)) as f:
             for l in f:
                 l = l.replace('\n', '').split()
                 utt2instruct[l[0]] = ' '.join(l[1:])
-    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
-    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
-    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
+    else:
+        utt2instruct = None
+    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/utt2embedding.pt'.format(args.src_dir)) else None
+    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) if os.path.exists('{}/spk2embedding.pt'.format(args.src_dir)) else None
+    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}/utt2speech_token.pt'.format(args.src_dir)) else None
     if args.dpo:
-        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
+        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir)) if os.path.exists('{}_reject/utt2speech_token.pt'.format(args.src_dir)) else {}
     utts = list(utt2wav.keys())
     # Using process pool to speedup
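
Because embeddings, speech tokens and instruct text are now optional, the parquet files written by this script only contain those columns when the corresponding inputs existed, and downstream readers should check before using them. A small sketch of such a check (the file name is hypothetical):

import pandas as pd

df = pd.read_parquet('data/train-clean-100/parquet/parquet_000000000.tar')   # hypothetical file name
has_offline_tokens = 'speech_token' in df.columns    # absent when tokens are extracted online
has_instruct = 'instruct' in df.columns              # present only if an instruct file was prepared
print(len(df), has_offline_tokens, has_instruct)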

View File

@@ -57,9 +57,6 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
         prompt_wav = None
     # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
     if mode_checkbox_group in ['自然语言控制']:
-        if cosyvoice.instruct is False:
-            gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
-            yield (cosyvoice.sample_rate, default_data)
         if instruct_text == '':
             gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
             yield (cosyvoice.sample_rate, default_data)
@@ -67,9 +64,6 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
     if mode_checkbox_group in ['跨语种复刻']:
-        if cosyvoice.instruct is True:
-            gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
-            yield (cosyvoice.sample_rate, default_data)
         if instruct_text != '':
             gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
         if prompt_wav is None:
@@ -167,7 +161,7 @@ if __name__ == '__main__':
                         default=8000)
     parser.add_argument('--model_dir',
                         type=str,
-                        default='pretrained_models/CosyVoice3-0.5B',
+                        default='pretrained_models/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
     cosyvoice = AutoModel(model_dir=args.model_dir)