update readme

This commit is contained in:
root
2025-07-30 11:05:49 +00:00
parent 62d082634e
commit 0bc48c1180
6 changed files with 54 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
# CosyVoice2 LLM Reinforcement Learning Recipe
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08 % to 3.36 %.
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
## Table of Contents
@@ -18,6 +18,7 @@ We recommend using the pre-built Docker image below. Alternatively, you can manu
```bash
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
```
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
## Data Preparation
@@ -43,16 +44,16 @@ data/parquet_tiny/train.parquet
data/parquet_tiny/test.parquet
```
Each sample is automatically wrapped into a cosyvoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
## Reward Function & ASR Server
To compute rewards we run a lightweight server that:
To compute rewards, we run a lightweight server that:
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
2. Transcribes the waveform with **SenseVoice** ASR.
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score in the range \[0-1\].
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
@@ -61,7 +62,7 @@ bash run.sh 1 1
# Triton server listens on ports 8000/8001/8002
```
The custom reward implementation lives in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
## Training
@@ -78,10 +79,12 @@ Key CLI arguments passed to `verl.trainer.main_ppo`:
* `custom_reward_function.path=reward_tts.py` custom reward function described above.
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
> [!TIP]
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
## Evaluation
After training completes, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
```bash
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
@@ -107,15 +110,16 @@ bash run.sh 5 5
```
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
## Results
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|-------|------------------------|------------------------------|---------|
| CosyVoice2 LLM (official) | 1.45 % | 4.08 % | See the [paper](https://arxiv.org/abs/2412.10117) |
| CosyVoice2 LLM + GRPO | 1.37 % | **3.36 %** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3) |
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
## Acknowledgement
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).