mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Merge pull request #1583 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
This commit is contained in:
@@ -31,7 +31,7 @@
|
|||||||
|
|
||||||
- [x] 2025/08
|
- [x] 2025/08
|
||||||
|
|
||||||
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support
|
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
|
||||||
|
|
||||||
- [x] 2025/07
|
- [x] 2025/07
|
||||||
|
|
||||||
|
|||||||
6
examples/grpo/cosyvoice2/Dockerfile
Normal file
6
examples/grpo/cosyvoice2/Dockerfile
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||||
|
COPY requirements.txt /myworkspace/requirements.txt
|
||||||
|
RUN pip install -r /myworkspace/requirements.txt
|
||||||
|
RUN pip install -U nvidia-pytriton
|
||||||
|
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
|
||||||
|
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .
|
||||||
125
examples/grpo/cosyvoice2/README.md
Normal file
125
examples/grpo/cosyvoice2/README.md
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# CosyVoice2 LLM Reinforcement Learning Recipe
|
||||||
|
|
||||||
|
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Environment Setup](#environment-setup)
|
||||||
|
- [Data Preparation](#data-preparation)
|
||||||
|
- [Reward Function & ASR Server](#reward-function--asr-server)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Evaluation](#evaluation)
|
||||||
|
- [Export Model](#export-model)
|
||||||
|
- [Results](#results)
|
||||||
|
- [Acknowledgement](#acknowledgement)
|
||||||
|
|
||||||
|
## Environment Setup
|
||||||
|
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
|
||||||
|
```bash
|
||||||
|
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||||
|
```
|
||||||
|
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
|
||||||
|
|
||||||
|
## Data Preparation
|
||||||
|
|
||||||
|
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"text": "An example sentence to be synthesized."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
|
||||||
|
|
||||||
|
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 0 0
|
||||||
|
```
|
||||||
|
Create two JSONL files—`train.jsonl` and `test.jsonl`.
|
||||||
|
The script will then generate two Parquet files:
|
||||||
|
|
||||||
|
```
|
||||||
|
data/parquet_tiny/train.parquet
|
||||||
|
data/parquet_tiny/test.parquet
|
||||||
|
```
|
||||||
|
|
||||||
|
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
|
||||||
|
|
||||||
|
|
||||||
|
## Reward Function & ASR Server
|
||||||
|
|
||||||
|
To compute rewards, we run a lightweight server that:
|
||||||
|
|
||||||
|
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
|
||||||
|
2. Transcribes the waveform with **SenseVoice** ASR.
|
||||||
|
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
|
||||||
|
|
||||||
|
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 1 1
|
||||||
|
# Triton server listens on ports 8000/8001/8002
|
||||||
|
```
|
||||||
|
|
||||||
|
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Run stage `2` to start GRPO training:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 2 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Key CLI arguments passed to `verl.trainer.main_ppo`:
|
||||||
|
|
||||||
|
* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
|
||||||
|
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
|
||||||
|
* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
|
||||||
|
|
||||||
|
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
|
||||||
|
> [!TIP]
|
||||||
|
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 4 4
|
||||||
|
```
|
||||||
|
|
||||||
|
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
|
||||||
|
|
||||||
|
## Export Model
|
||||||
|
|
||||||
|
To use the RL-trained model with the official CosyVoice repository:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 5 5
|
||||||
|
```
|
||||||
|
|
||||||
|
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
|
||||||
|
> [!TIP]
|
||||||
|
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|
||||||
|
|-------|------------------------|------------------------------|---------|
|
||||||
|
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
|
||||||
|
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
|
||||||
|
|
||||||
|
## Acknowledgement
|
||||||
|
|
||||||
|
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).
|
||||||
71
examples/grpo/cosyvoice2/huggingface_to_pretrained.py
Normal file
71
examples/grpo/cosyvoice2/huggingface_to_pretrained.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||||
|
"""
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import torch
|
||||||
|
from safetensors import safe_open
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-cosyvoice2-llm-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The RL trained CosyVoice2 model path in HuggingFace format",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-path",
|
||||||
|
type=str,
|
||||||
|
default="./llm.pt",
|
||||||
|
help="The path to save the llm.pt",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
|
||||||
|
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
|
||||||
|
cosyvoice2_token_size = 6561 + 3
|
||||||
|
llm_embedding_vocab_size = 2
|
||||||
|
|
||||||
|
hf_tensors = {}
|
||||||
|
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k.startswith("lm_head.bias"):
|
||||||
|
# RL trained model disable bias for lm_head
|
||||||
|
continue
|
||||||
|
new_k = "llm.model." + k
|
||||||
|
hf_tensors[new_k] = f.get_tensor(k)
|
||||||
|
if k.startswith("lm_head"):
|
||||||
|
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||||
|
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
|
||||||
|
if k.startswith("model.embed_tokens"):
|
||||||
|
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||||
|
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
|
||||||
|
|
||||||
|
# use tie_word_embeddings=True
|
||||||
|
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
|
||||||
|
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
|
||||||
|
|
||||||
|
torch.save(hf_tensors, args.output_path)
|
||||||
397
examples/grpo/cosyvoice2/infer_dataset.py
Normal file
397
examples/grpo/cosyvoice2/infer_dataset.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Example Usage
|
||||||
|
dataset=zero_shot_zh
|
||||||
|
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
|
||||||
|
|
||||||
|
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||||
|
torchrun --nproc_per_node=8 \
|
||||||
|
infer_dataset.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $llm_path/merged_hf_model \
|
||||||
|
--token2wav-path $token2wav_path \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
from cosyvoice.utils.file_utils import load_wav
|
||||||
|
from datasets import load_dataset
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
import soundfile as sf
|
||||||
|
import s3tokenizer
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
try:
|
||||||
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
|
||||||
|
|
||||||
|
|
||||||
|
def audio_decode_cosyvoice2(
|
||||||
|
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate audio from tokens with optional tone and prompt embedding.
|
||||||
|
"""
|
||||||
|
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||||
|
"empty", prompt_text, prompt_speech_16k, 24000
|
||||||
|
)
|
||||||
|
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||||
|
token=audio_tokens.to(codec_decoder.model.device),
|
||||||
|
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token_len=torch.tensor(
|
||||||
|
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||||
|
).to(codec_decoder.model.device),
|
||||||
|
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||||
|
finalize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||||
|
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
return audio_hat
|
||||||
|
|
||||||
|
|
||||||
|
def extract_speech_ids(speech_tokens_str):
|
||||||
|
"""Extract speech IDs from token strings like <|s_23456|>"""
|
||||||
|
speech_ids = []
|
||||||
|
for token_str in speech_tokens_str:
|
||||||
|
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
||||||
|
num_str = token_str[4:-2]
|
||||||
|
num = int(num_str)
|
||||||
|
speech_ids.append(num)
|
||||||
|
else:
|
||||||
|
print(f"Unexpected token: {token_str}")
|
||||||
|
return speech_ids
|
||||||
|
|
||||||
|
|
||||||
|
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
|
||||||
|
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
|
||||||
|
speech_id_str = ""
|
||||||
|
for token in cosy2_tokens:
|
||||||
|
speech_id_str += f"<|s_{token}|>"
|
||||||
|
return speech_id_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir", required=True, type=str, help="dir to save result"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="batch size (per-device) for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers", type=int, default=1, help="workers for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefetch", type=int, default=5, help="prefetch for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model-name-or-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="LLM model path (includes both model and tokenizer)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token2wav-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="CosyVoice2 token2wav model path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-text",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The prompt text for CosyVoice2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-speech-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The path to the prompt speech for CosyVoice2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="top p for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="temperature for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="top k for sampling",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||||
|
"""Simplified data collator for batch_size=1 processing"""
|
||||||
|
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||||
|
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||||
|
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||||
|
mels, prompt_audio_cosy2tokens_list = [], []
|
||||||
|
for item in batch:
|
||||||
|
prompt_text, target_text = (
|
||||||
|
item["prompt_text"],
|
||||||
|
item["target_text"],
|
||||||
|
)
|
||||||
|
prompt_text_list.append(prompt_text)
|
||||||
|
# Combine prompt and target text
|
||||||
|
full_text = prompt_text + target_text
|
||||||
|
|
||||||
|
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||||
|
ref_audio_org, ref_sr = (
|
||||||
|
item["prompt_audio"]["array"],
|
||||||
|
item["prompt_audio"]["sampling_rate"],
|
||||||
|
)
|
||||||
|
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||||
|
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
|
||||||
|
print(ref_audio_org.shape)
|
||||||
|
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio = ref_audio_org
|
||||||
|
|
||||||
|
prompt_audio_list.append(ref_audio)
|
||||||
|
|
||||||
|
if "prompt_audio_cosy2_tokens" in item:
|
||||||
|
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||||
|
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||||
|
else:
|
||||||
|
# convert to float first
|
||||||
|
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||||
|
|
||||||
|
if len(mels) > 0:
|
||||||
|
mels, mels_lens = s3tokenizer.padding(mels)
|
||||||
|
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||||
|
for i in range(len(codes)):
|
||||||
|
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||||
|
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
|
||||||
|
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||||
|
# Create chat template for LLM generation
|
||||||
|
chat = [
|
||||||
|
{"role": "user", "content": full_text},
|
||||||
|
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||||
|
]
|
||||||
|
if 'system' in tokenizer.chat_template:
|
||||||
|
tokenizer.chat_template = TEMPLATE
|
||||||
|
input_ids = tokenizer.apply_chat_template(
|
||||||
|
chat,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors='pt',
|
||||||
|
continue_final_message=True
|
||||||
|
)
|
||||||
|
input_ids_list.append(input_ids.squeeze(0))
|
||||||
|
|
||||||
|
# For batch_size=1, no need to pad
|
||||||
|
if len(input_ids_list) == 1:
|
||||||
|
input_ids = input_ids_list[0].unsqueeze(0)
|
||||||
|
else:
|
||||||
|
# Handle batch > 1 if needed
|
||||||
|
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||||
|
input_ids_list = [
|
||||||
|
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
|
||||||
|
for input_ids in input_ids_list
|
||||||
|
]
|
||||||
|
input_ids = torch.stack(input_ids_list)
|
||||||
|
|
||||||
|
ids = [item["id"] for item in batch]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"ids": ids,
|
||||||
|
"prompt_text": prompt_text_list,
|
||||||
|
"prompt_audio_list": prompt_audio_list,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed():
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
rank = int(os.environ.get("RANK", 0))
|
||||||
|
print(
|
||||||
|
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||||
|
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
return world_size, local_rank, rank
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
world_size, local_rank, rank = init_distributed()
|
||||||
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
|
|
||||||
|
# Load LLM model and tokenizer directly
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
cosyvoice_codec = CosyVoice2(
|
||||||
|
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
|
||||||
|
)
|
||||||
|
if args.prompt_speech_path:
|
||||||
|
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||||
|
else:
|
||||||
|
prompt_speech_16k = None
|
||||||
|
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
|
||||||
|
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||||
|
dataset = load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
prefetch_factor=args.prefetch,
|
||||||
|
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||||
|
)
|
||||||
|
|
||||||
|
total_steps = len(dataset)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||||
|
|
||||||
|
for batch in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
input_ids = batch["input_ids"].to(device)
|
||||||
|
|
||||||
|
# Generate speech tokens using LLM
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
max_new_tokens=2048, # Max length for generation
|
||||||
|
do_sample=True,
|
||||||
|
top_p=args.top_p,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process each sample in the batch
|
||||||
|
for i in range(len(batch["ids"])):
|
||||||
|
# Extract generated tokens (excluding input)
|
||||||
|
input_length = input_ids[i].shape[0]
|
||||||
|
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
|
||||||
|
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Extract speech IDs from token strings like <|s_23456|>
|
||||||
|
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||||
|
|
||||||
|
if len(speech_ids) == 0:
|
||||||
|
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert to tensor for CosyVoice2
|
||||||
|
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
|
||||||
|
|
||||||
|
if args.prompt_text is not None:
|
||||||
|
current_prompt_text = args.prompt_text
|
||||||
|
current_prompt_audio = prompt_speech_16k
|
||||||
|
else:
|
||||||
|
current_prompt_text = batch["prompt_text"][i]
|
||||||
|
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||||
|
|
||||||
|
if current_prompt_audio is not None:
|
||||||
|
# Generate audio using CosyVoice2
|
||||||
|
audio_hat = audio_decode_cosyvoice2(
|
||||||
|
audio_tokens,
|
||||||
|
current_prompt_text,
|
||||||
|
current_prompt_audio,
|
||||||
|
cosyvoice_codec,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to numpy and save
|
||||||
|
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||||
|
target_sample_rate = 24000
|
||||||
|
|
||||||
|
utt = batch["ids"][i]
|
||||||
|
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||||
|
|
||||||
|
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
|
||||||
|
else:
|
||||||
|
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.update(world_size * len(batch["ids"]))
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
86
examples/grpo/cosyvoice2/prepare_data.py
Normal file
86
examples/grpo/cosyvoice2/prepare_data.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Preprocess the Text to Speech dataset to parquet format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
from verl.utils.hdfs_io import copy, makedirs
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
|
||||||
|
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
|
||||||
|
parser.add_argument("--local_dir", default=None, required=True)
|
||||||
|
parser.add_argument("--hdfs_dir", default=None)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load datasets from local JSON files
|
||||||
|
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
|
||||||
|
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
|
||||||
|
|
||||||
|
# add a row to each data item that represents a unique id
|
||||||
|
def make_map_fn(split):
|
||||||
|
def process_fn(example, idx):
|
||||||
|
text = example.pop("text")
|
||||||
|
|
||||||
|
# use cosyvoice2 official huggingface compatible checkpoint template
|
||||||
|
question = text
|
||||||
|
answer = ""
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
|
||||||
|
"prompt": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": question,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": answer,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"ability": "text-to-speech",
|
||||||
|
"reward_model": {"style": "rule", "ground_truth": text},
|
||||||
|
"extra_info": {
|
||||||
|
"split": split,
|
||||||
|
"index": idx,
|
||||||
|
"text": text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
return process_fn
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
|
||||||
|
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
|
||||||
|
|
||||||
|
local_dir = args.local_dir
|
||||||
|
hdfs_dir = args.hdfs_dir
|
||||||
|
|
||||||
|
print(train_dataset)
|
||||||
|
print(test_dataset)
|
||||||
|
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
|
||||||
|
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
|
||||||
|
|
||||||
|
if hdfs_dir is not None:
|
||||||
|
makedirs(hdfs_dir)
|
||||||
|
|
||||||
|
copy(src=local_dir, dst=hdfs_dir)
|
||||||
135
examples/grpo/cosyvoice2/pretrained_to_huggingface.py
Normal file
135
examples/grpo/cosyvoice2/pretrained_to_huggingface.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage: Instruct TTS
|
||||||
|
python3 infer.py \
|
||||||
|
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||||
|
--prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||||
|
--prompt-speech-path ./assets/prompt_audio.wav \
|
||||||
|
--model-path ./transformers_cosyvoice2_llm \
|
||||||
|
--input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
|
||||||
|
"""
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
import sys
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained-cosyvoice2-path",
|
||||||
|
type=str,
|
||||||
|
default="/workspace/CosyVoice2-0.5B",
|
||||||
|
help="Token2Wav path, default to %(default)r",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path",
|
||||||
|
type=str,
|
||||||
|
default='./transformers_cosyvoice2_llm',
|
||||||
|
help="The path to save the model",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
cosy2_model = CosyVoice2(
|
||||||
|
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = cosy2_model.model.llm.llm.model
|
||||||
|
|
||||||
|
speech_embedding = cosy2_model.model.llm.speech_embedding
|
||||||
|
llm_decoder = cosy2_model.model.llm.llm_decoder
|
||||||
|
llm_embedding = cosy2_model.model.llm.llm_embedding
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
|
||||||
|
special_tokens = {
|
||||||
|
'eos_token': '<|endoftext|>',
|
||||||
|
'pad_token': '<|endoftext|>',
|
||||||
|
'additional_special_tokens': [
|
||||||
|
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||||
|
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||||
|
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||||
|
'[quick_breath]',
|
||||||
|
"<laughter>", "</laughter>",
|
||||||
|
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||||
|
"[lipsmack]", "[mn]"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tokenizer.add_special_tokens(special_tokens)
|
||||||
|
|
||||||
|
original_tokenizer_vocab_size = len(tokenizer)
|
||||||
|
cosyvoice2_token_size = 6561
|
||||||
|
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
||||||
|
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
|
||||||
|
]
|
||||||
|
num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||||
|
|
||||||
|
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
|
||||||
|
vocab_size = llm.get_input_embeddings().weight.shape[0]
|
||||||
|
|
||||||
|
feature_size = speech_embedding.embedding_dim
|
||||||
|
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# set the weight and bias of the new lm_head to 0
|
||||||
|
new_lm_head.weight.data.zero_()
|
||||||
|
# make bias value -inf
|
||||||
|
new_lm_head.bias.data.fill_(-float('inf'))
|
||||||
|
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
|
||||||
|
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
|
||||||
|
|
||||||
|
llm.lm_head = new_lm_head
|
||||||
|
input_embeddings = llm.get_input_embeddings()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
|
||||||
|
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
|
||||||
|
|
||||||
|
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
|
||||||
|
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
|
||||||
|
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
|
||||||
|
llm.generation_config.eos_token_id = eos_token_ids
|
||||||
|
llm.generation_config.temperature = 1.0
|
||||||
|
llm.generation_config.top_p = 0.8
|
||||||
|
llm.generation_config.top_k = 25
|
||||||
|
|
||||||
|
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
|
||||||
|
llm.config.vocab_size = vocab_size
|
||||||
|
llm.config.tie_word_embeddings = False
|
||||||
|
llm.config.use_bias = True
|
||||||
|
llm.to(torch.bfloat16)
|
||||||
|
llm.save_pretrained(args.save_path)
|
||||||
|
|
||||||
|
TEMPLATE = (
|
||||||
|
"{%- for message in messages %}"
|
||||||
|
"{%- if message['role'] == 'user' %}"
|
||||||
|
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
|
||||||
|
"{%- elif message['role'] == 'assistant' %}"
|
||||||
|
"{{- message['content']}}"
|
||||||
|
"{%- endif %}"
|
||||||
|
"{%- endfor %}"
|
||||||
|
)
|
||||||
|
tokenizer.chat_template = TEMPLATE
|
||||||
|
tokenizer.save_pretrained(args.save_path)
|
||||||
31
examples/grpo/cosyvoice2/requirements.txt
Normal file
31
examples/grpo/cosyvoice2/requirements.txt
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
conformer==0.3.2
|
||||||
|
diffusers==0.29.0
|
||||||
|
gdown==5.1.0
|
||||||
|
gradio
|
||||||
|
hydra-core==1.3.2
|
||||||
|
HyperPyYAML==1.2.2
|
||||||
|
inflect==7.3.1
|
||||||
|
librosa==0.10.2
|
||||||
|
lightning==2.2.4
|
||||||
|
matplotlib==3.7.5
|
||||||
|
modelscope==1.15.0
|
||||||
|
networkx==3.1
|
||||||
|
omegaconf==2.3.0
|
||||||
|
onnx==1.16.0
|
||||||
|
onnxruntime-gpu==1.18.0
|
||||||
|
protobuf==4.25
|
||||||
|
pydantic==2.7.0
|
||||||
|
pyworld==0.3.4
|
||||||
|
rich==13.7.1
|
||||||
|
soundfile==0.12.1
|
||||||
|
tensorboard==2.14.0
|
||||||
|
wget==3.2
|
||||||
|
WeTextProcessing==1.0.3
|
||||||
|
s3tokenizer
|
||||||
|
tensorrt
|
||||||
|
sherpa_onnx
|
||||||
|
jiwer
|
||||||
|
zhon
|
||||||
|
numpy==1.25.2
|
||||||
|
pypinyin
|
||||||
|
openai-whisper
|
||||||
233
examples/grpo/cosyvoice2/reward_tts.py
Normal file
233
examples/grpo/cosyvoice2/reward_tts.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Reward calculation for CosyVoice2-0.5B.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ids(token_str: str) -> List[int]:
|
||||||
|
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
|
||||||
|
|
||||||
|
|
||||||
|
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
|
||||||
|
"""Send token IDs and ground-truth text to the Triton server and get reward."""
|
||||||
|
|
||||||
|
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
|
||||||
|
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
|
||||||
|
|
||||||
|
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs": [
|
||||||
|
{
|
||||||
|
"name": "TOKENS",
|
||||||
|
"shape": list(tokens_arr.shape),
|
||||||
|
"datatype": "INT32",
|
||||||
|
"data": tokens_arr.tolist(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "TOKEN_LENS",
|
||||||
|
"shape": list(lens_arr.shape),
|
||||||
|
"datatype": "INT32",
|
||||||
|
"data": lens_arr.tolist(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "GT_TEXT",
|
||||||
|
"shape": [1, 1],
|
||||||
|
"datatype": "BYTES",
|
||||||
|
"data": [ground_truth],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
rsp = requests.post(
|
||||||
|
REWARD_SERVER_URL,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
json=payload,
|
||||||
|
timeout=timeout,
|
||||||
|
verify=False,
|
||||||
|
params={"request_id": "0"},
|
||||||
|
)
|
||||||
|
rsp.raise_for_status()
|
||||||
|
result = rsp.json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Reward is returned as the first output
|
||||||
|
return float(result["outputs"][0]["data"][0])
|
||||||
|
except (KeyError, IndexError, TypeError):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_score(
|
||||||
|
data_source: str,
|
||||||
|
solution_str: str,
|
||||||
|
ground_truth: str,
|
||||||
|
extra_info: dict | None = None,
|
||||||
|
*,
|
||||||
|
debug_dump: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Return reward in [0, 1] using the Triton ASR service.
|
||||||
|
|
||||||
|
The reward is based on the pinyin-level WER between the ASR transcript
|
||||||
|
produced from *solution_str* and the provided *ground_truth* text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Decode token IDs
|
||||||
|
ids = _parse_ids(solution_str)
|
||||||
|
|
||||||
|
# Query remote server for reward
|
||||||
|
try:
|
||||||
|
reward = _remote_reward(ids, ground_truth)
|
||||||
|
except Exception as e:
|
||||||
|
reward = 0.0
|
||||||
|
|
||||||
|
if debug_dump:
|
||||||
|
print(
|
||||||
|
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
|
||||||
|
)
|
||||||
|
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
# CLI quick test
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Test TTS CER scoring with data from JSONL file",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--input", "-i",
|
||||||
|
type=str,
|
||||||
|
default="data/emilia_zh-cosy-tiny-test.jsonl",
|
||||||
|
help="Path to input JSONL file"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-samples", "-n",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum number of samples to process (default: all)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-interactive",
|
||||||
|
action="store_true",
|
||||||
|
help="Run in non-interactive mode (process all samples without prompts)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable debug mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def load_jsonl(file_path: str):
|
||||||
|
"""Load data from jsonl file."""
|
||||||
|
data = []
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
data.append(json.loads(line.strip()))
|
||||||
|
return data
|
||||||
|
|
||||||
|
def code_to_solution_str(code_list: List[int]) -> str:
|
||||||
|
"""Convert code list to solution string format."""
|
||||||
|
return ''.join([f"<|s_{code}|>" for code in code_list])
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load data from jsonl file
|
||||||
|
print(f"Loading data from: {args.input}")
|
||||||
|
data_list = load_jsonl(args.input)
|
||||||
|
print(f"Loaded {len(data_list)} samples")
|
||||||
|
|
||||||
|
# Limit samples if specified
|
||||||
|
if args.max_samples is not None:
|
||||||
|
data_list = data_list[:args.max_samples]
|
||||||
|
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
|
||||||
|
|
||||||
|
# Process each sample
|
||||||
|
begin_time = time.time()
|
||||||
|
for i, sample in enumerate(data_list):
|
||||||
|
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
|
||||||
|
print(f"Index: {sample.get('index', 'unknown')}")
|
||||||
|
print(f"Text: {sample['text']}")
|
||||||
|
|
||||||
|
# Extract required fields
|
||||||
|
code_list = sample['code']
|
||||||
|
ground_truth = sample['text']
|
||||||
|
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
|
||||||
|
|
||||||
|
# Convert code list to solution string
|
||||||
|
solution_str = code_to_solution_str(code_list)
|
||||||
|
print(f"Solution tokens: {len(code_list)} tokens")
|
||||||
|
if args.debug:
|
||||||
|
print(f"Solution string: {solution_str}")
|
||||||
|
else:
|
||||||
|
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
|
||||||
|
|
||||||
|
# Call compute_score function
|
||||||
|
try:
|
||||||
|
score = compute_score(
|
||||||
|
data_source=data_source,
|
||||||
|
solution_str=solution_str,
|
||||||
|
ground_truth=ground_truth,
|
||||||
|
extra_info=None,
|
||||||
|
debug_dump=args.debug
|
||||||
|
)
|
||||||
|
print(f"Final Score: {score:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error computing score: {e}")
|
||||||
|
|
||||||
|
# Ask user if they want to continue (for interactive mode)
|
||||||
|
if not args.no_interactive and i < len(data_list) - 1:
|
||||||
|
try:
|
||||||
|
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
|
||||||
|
if response == 'q':
|
||||||
|
break
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopped by user")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(f"\nProcessed {min(i+1, len(data_list))} samples")
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Time taken: {end_time - begin_time} seconds")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: File not found - {args.input}")
|
||||||
|
print("Please check the file path or use --input to specify correct path")
|
||||||
|
print("Run with --help for usage information")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
159
examples/grpo/cosyvoice2/run.sh
Normal file
159
examples/grpo/cosyvoice2/run.sh
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -eou pipefail
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=4
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
export PYTHONPATH=/workspace/CosyVoice
|
||||||
|
model_scope_model_path=./CosyVoice2-0.5B
|
||||||
|
sft_model_path=./transformers_cosyvoice2_llm
|
||||||
|
|
||||||
|
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
|
||||||
|
log "stage -2: install dependencies locally if pre-built docker image is not available"
|
||||||
|
conda create -n cosyvoice2 python=3.10 -y
|
||||||
|
conda activate cosyvoice2
|
||||||
|
# install verl
|
||||||
|
git clone https://github.com/yuekaizhang/verl.git -b thread
|
||||||
|
cd verl
|
||||||
|
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
|
||||||
|
pip install --no-deps -e .
|
||||||
|
cd -
|
||||||
|
# install requirements
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install -U nvidia-pytriton
|
||||||
|
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
|
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
|
||||||
|
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
|
||||||
|
python3 pretrained_to_huggingface.py \
|
||||||
|
--pretrained-cosyvoice2-path $model_scope_model_path \
|
||||||
|
--save-path $sft_model_path
|
||||||
|
|
||||||
|
# Or, you could use the following command to download the huggingface compatible checkpoint
|
||||||
|
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
|
||||||
|
|
||||||
|
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
|
||||||
|
fi
|
||||||
|
|
||||||
|
data_dir=data/parquet_aishell3
|
||||||
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
|
log "stage 0: prepare data into verl format"
|
||||||
|
mkdir -p $data_dir
|
||||||
|
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
|
||||||
|
# total 88035 samples
|
||||||
|
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
|
||||||
|
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
|
||||||
|
python prepare_data.py \
|
||||||
|
--train_file data/train.jsonl \
|
||||||
|
--test_file data/test.jsonl \
|
||||||
|
--local_dir $data_dir
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
|
log "stage 1: start token2wav asr server for reward function"
|
||||||
|
python3 token2wav_asr_server.py --number-of-devices 8
|
||||||
|
fi
|
||||||
|
|
||||||
|
exp_name=official_llm_aishell3_grpo
|
||||||
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
|
log "stage 2: grpo train"
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
export MKL_SERVICE_FORCE_INTEL=TRUE
|
||||||
|
n_gpus_per_node=8
|
||||||
|
micro_batch_size=4
|
||||||
|
train_batch_size=32
|
||||||
|
python3 -m verl.trainer.main_ppo \
|
||||||
|
algorithm.adv_estimator=grpo \
|
||||||
|
data.train_files=$data_dir/train.parquet \
|
||||||
|
data.val_files=$data_dir/test.parquet \
|
||||||
|
data.train_batch_size=$train_batch_size \
|
||||||
|
data.max_prompt_length=1024 \
|
||||||
|
data.max_response_length=512 \
|
||||||
|
data.truncation='error' \
|
||||||
|
actor_rollout_ref.model.use_remove_padding=False \
|
||||||
|
actor_rollout_ref.model.path=$sft_model_path \
|
||||||
|
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||||
|
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
|
||||||
|
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
|
||||||
|
actor_rollout_ref.actor.use_kl_loss=False \
|
||||||
|
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||||
|
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||||
|
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||||
|
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
|
||||||
|
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||||
|
actor_rollout_ref.rollout.name=vllm \
|
||||||
|
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||||
|
actor_rollout_ref.rollout.do_sample=true \
|
||||||
|
actor_rollout_ref.rollout.temperature=0.8 \
|
||||||
|
actor_rollout_ref.rollout.top_p=0.95 \
|
||||||
|
actor_rollout_ref.rollout.top_k=25 \
|
||||||
|
actor_rollout_ref.rollout.n=4 \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
|
||||||
|
reward_model.reward_manager=prime \
|
||||||
|
custom_reward_function.path=reward_tts.py \
|
||||||
|
custom_reward_function.name=compute_score \
|
||||||
|
trainer.project_name='cosyvoice2_grpo' \
|
||||||
|
trainer.experiment_name=$exp_name \
|
||||||
|
trainer.logger=['console','wandb'] \
|
||||||
|
trainer.n_gpus_per_node=$n_gpus_per_node \
|
||||||
|
trainer.nnodes=1 \
|
||||||
|
trainer.save_freq=100 \
|
||||||
|
trainer.test_freq=100 \
|
||||||
|
trainer.resume_mode='auto' \
|
||||||
|
trainer.total_epochs=1 \
|
||||||
|
trainer.val_before_train=False
|
||||||
|
fi
|
||||||
|
|
||||||
|
steps=(100 200 300 400 500)
|
||||||
|
for step in ${steps[@]}; do
|
||||||
|
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
|
||||||
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
|
log "stage 3: merge the model"
|
||||||
|
python -m verl.model_merger merge \
|
||||||
|
--backend fsdp \
|
||||||
|
--local_dir $llm_path/actor \
|
||||||
|
--target_dir $llm_path/merged_hf_model || exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
log "stage 4: Test the model"
|
||||||
|
dataset=zero_shot_zh # from CosyVoice3 test set
|
||||||
|
# dataset=test_zh # from seed_tts test set
|
||||||
|
output_dir=./outputs_${exp_name}_${step}_${dataset}
|
||||||
|
|
||||||
|
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||||
|
model_path=$llm_path/merged_hf_model
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||||
|
torchrun --nproc_per_node=8 \
|
||||||
|
infer_dataset.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $model_path \
|
||||||
|
--token2wav-path $token2wav_path \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
|
||||||
|
bash scripts/compute_wer.sh $output_dir ${dataset}
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
|
log "stage 5: Convert the RL trained model to CosyVoice repo format"
|
||||||
|
python3 huggingface_to_pretrained.py \
|
||||||
|
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
|
||||||
|
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||||
|
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
|
||||||
|
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
|
||||||
|
# Please be careful or use the huggingface format inference code.
|
||||||
|
fi
|
||||||
33
examples/grpo/cosyvoice2/scripts/compute_wer.sh
Normal file
33
examples/grpo/cosyvoice2/scripts/compute_wer.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
wav_dir=$1
|
||||||
|
wav_files=$(ls $wav_dir/*.wav)
|
||||||
|
# if wav_files is empty, then exit
|
||||||
|
if [ -z "$wav_files" ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
split_name=$2
|
||||||
|
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
|
||||||
|
|
||||||
|
if [ ! -d $model_path ]; then
|
||||||
|
pip install sherpa-onnx
|
||||||
|
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||||
|
mkdir models
|
||||||
|
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
|
||||||
|
fi
|
||||||
|
|
||||||
|
python3 scripts/offline-decode-files.py \
|
||||||
|
--tokens=$model_path/tokens.txt \
|
||||||
|
--paraformer=$model_path/model.int8.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=24000 \
|
||||||
|
--log-dir $wav_dir \
|
||||||
|
--feature-dim=80 \
|
||||||
|
--split-name $split_name \
|
||||||
|
--name sherpa_onnx \
|
||||||
|
$wav_files
|
||||||
|
|
||||||
|
# python3 scripts/paraformer-pytriton-client.py \
|
||||||
|
# --log-dir $wav_dir \
|
||||||
|
# --split-name $split_name \
|
||||||
|
# $wav_files
|
||||||
756
examples/grpo/cosyvoice2/scripts/offline-decode-files.py
Normal file
756
examples/grpo/cosyvoice2/scripts/offline-decode-files.py
Normal file
@@ -0,0 +1,756 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright (c) 2023 by manyeyes
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||||
|
file(s) with a non-streaming model.
|
||||||
|
|
||||||
|
(1) For paraformer
|
||||||
|
|
||||||
|
./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--paraformer=/path/to/paraformer.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=16000 \
|
||||||
|
--feature-dim=80 \
|
||||||
|
/path/to/0.wav \
|
||||||
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(2) For transducer models from icefall
|
||||||
|
|
||||||
|
./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--encoder=/path/to/encoder.onnx \
|
||||||
|
--decoder=/path/to/decoder.onnx \
|
||||||
|
--joiner=/path/to/joiner.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=16000 \
|
||||||
|
--feature-dim=80 \
|
||||||
|
/path/to/0.wav \
|
||||||
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(3) For CTC models from NeMo
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||||
|
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(4) For Whisper models
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--whisper-task=transcribe \
|
||||||
|
--num-threads=1 \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(5) For CTC models from WeNet
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||||
|
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(6) For tdnn models of the yesno recipe from icefall
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feature-dim=23 \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||||
|
|
||||||
|
Please refer to
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
|
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||||
|
used in this file.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict, Iterable, TextIO, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
from datasets import load_dataset
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
import kaldialign
|
||||||
|
from zhon.hanzi import punctuation
|
||||||
|
import string
|
||||||
|
punctuation_all = punctuation + string.punctuation
|
||||||
|
Pathlike = Union[str, Path]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_punctuation(text: str) -> str:
|
||||||
|
for x in punctuation_all:
|
||||||
|
if x == '\'':
|
||||||
|
continue
|
||||||
|
text = text.replace(x, '')
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def store_transcripts(
|
||||||
|
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""Save predicted results and reference transcripts to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
File to save the results to.
|
||||||
|
texts:
|
||||||
|
An iterable of tuples. The first element is the cur_id, the second is
|
||||||
|
the reference transcript and the third element is the predicted result.
|
||||||
|
If it is a multi-talker ASR system, the ref and hyp may also be lists of
|
||||||
|
strings.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
with open(filename, "w", encoding="utf8") as f:
|
||||||
|
for cut_id, ref, hyp in texts:
|
||||||
|
if char_level:
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
print(f"{cut_id}:\tref={ref}", file=f)
|
||||||
|
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_error_stats(
|
||||||
|
f: TextIO,
|
||||||
|
test_set_name: str,
|
||||||
|
results: List[Tuple[str, str]],
|
||||||
|
enable_log: bool = True,
|
||||||
|
compute_CER: bool = False,
|
||||||
|
sclite_mode: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Write statistics based on predicted results and reference transcripts.
|
||||||
|
|
||||||
|
It will write the following to the given file:
|
||||||
|
|
||||||
|
- WER
|
||||||
|
- number of insertions, deletions, substitutions, corrects and total
|
||||||
|
reference words. For example::
|
||||||
|
|
||||||
|
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||||
|
reference words (2337 correct)
|
||||||
|
|
||||||
|
- The difference between the reference transcript and predicted result.
|
||||||
|
An instance is given below::
|
||||||
|
|
||||||
|
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||||
|
|
||||||
|
The above example shows that the reference word is `EDISON`,
|
||||||
|
but it is predicted to `ADDISON` (a substitution error).
|
||||||
|
|
||||||
|
Another example is::
|
||||||
|
|
||||||
|
FOR THE FIRST DAY (SIR->*) I THINK
|
||||||
|
|
||||||
|
The reference word `SIR` is missing in the predicted
|
||||||
|
results (a deletion error).
|
||||||
|
results:
|
||||||
|
An iterable of tuples. The first element is the cut_id, the second is
|
||||||
|
the reference transcript and the third element is the predicted result.
|
||||||
|
enable_log:
|
||||||
|
If True, also print detailed WER to the console.
|
||||||
|
Otherwise, it is written only to the given file.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||||
|
ins: Dict[str, int] = defaultdict(int)
|
||||||
|
dels: Dict[str, int] = defaultdict(int)
|
||||||
|
|
||||||
|
# `words` stores counts per word, as follows:
|
||||||
|
# corr, ref_sub, hyp_sub, ins, dels
|
||||||
|
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||||
|
num_corr = 0
|
||||||
|
ERR = "*"
|
||||||
|
|
||||||
|
if compute_CER:
|
||||||
|
for i, res in enumerate(results):
|
||||||
|
cut_id, ref, hyp = res
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
results[i] = (cut_id, ref, hyp)
|
||||||
|
|
||||||
|
for cut_id, ref, hyp in results:
|
||||||
|
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||||
|
for ref_word, hyp_word in ali:
|
||||||
|
if ref_word == ERR:
|
||||||
|
ins[hyp_word] += 1
|
||||||
|
words[hyp_word][3] += 1
|
||||||
|
elif hyp_word == ERR:
|
||||||
|
dels[ref_word] += 1
|
||||||
|
words[ref_word][4] += 1
|
||||||
|
elif hyp_word != ref_word:
|
||||||
|
subs[(ref_word, hyp_word)] += 1
|
||||||
|
words[ref_word][1] += 1
|
||||||
|
words[hyp_word][2] += 1
|
||||||
|
else:
|
||||||
|
words[ref_word][0] += 1
|
||||||
|
num_corr += 1
|
||||||
|
ref_len = sum([len(r) for _, r, _ in results])
|
||||||
|
sub_errs = sum(subs.values())
|
||||||
|
ins_errs = sum(ins.values())
|
||||||
|
del_errs = sum(dels.values())
|
||||||
|
tot_errs = sub_errs + ins_errs + del_errs
|
||||||
|
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||||
|
|
||||||
|
if enable_log:
|
||||||
|
logging.info(
|
||||||
|
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||||
|
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||||
|
f"{del_errs} del, {sub_errs} sub ]"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"%WER = {tot_err_rate}", file=f)
|
||||||
|
print(
|
||||||
|
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||||
|
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||||
|
f"words ({num_corr} correct)",
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||||
|
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||||
|
for cut_id, ref, hyp in results:
|
||||||
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
|
combine_successive_errors = True
|
||||||
|
if combine_successive_errors:
|
||||||
|
ali = [[[x], [y]] for x, y in ali]
|
||||||
|
for i in range(len(ali) - 1):
|
||||||
|
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||||
|
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||||
|
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||||
|
ali[i] = [[], []]
|
||||||
|
ali = [
|
||||||
|
[
|
||||||
|
list(filter(lambda a: a != ERR, x)),
|
||||||
|
list(filter(lambda a: a != ERR, y)),
|
||||||
|
]
|
||||||
|
for x, y in ali
|
||||||
|
]
|
||||||
|
ali = list(filter(lambda x: x != [[], []], ali))
|
||||||
|
ali = [
|
||||||
|
[
|
||||||
|
ERR if x == [] else " ".join(x),
|
||||||
|
ERR if y == [] else " ".join(y),
|
||||||
|
]
|
||||||
|
for x, y in ali
|
||||||
|
]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{cut_id}:\t"
|
||||||
|
+ " ".join(
|
||||||
|
(
|
||||||
|
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||||
|
for ref_word, hyp_word in ali
|
||||||
|
)
|
||||||
|
),
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||||
|
|
||||||
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||||
|
print(f"{count} {ref} -> {hyp}", file=f)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("DELETIONS: count ref", file=f)
|
||||||
|
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||||
|
print(f"{count} {ref}", file=f)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("INSERTIONS: count hyp", file=f)
|
||||||
|
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||||
|
print(f"{count} {hyp}", file=f)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||||
|
for _, word, counts in sorted(
|
||||||
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||||
|
):
|
||||||
|
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||||
|
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||||
|
ref_count = corr + ref_sub + dels
|
||||||
|
hyp_count = corr + hyp_sub + ins
|
||||||
|
|
||||||
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||||
|
return float(tot_err_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="Path to tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The file containing hotwords, one words/phrases per line, like
|
||||||
|
HELLO WORLD
|
||||||
|
你好世界
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords-score",
|
||||||
|
type=float,
|
||||||
|
default=1.5,
|
||||||
|
help="""
|
||||||
|
The hotword score of each token for biasing word/phrase. Used only if
|
||||||
|
--hotwords-file is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modeling-unit",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||||
|
Used only when hotwords-file is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-vocab",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||||
|
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||||
|
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||||
|
and modeling-unit is bpe or cjkchar+bpe.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from Paraformer",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nemo-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from NeMo CTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--wenet-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from WeNet CTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tdnn-model",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-threads",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of threads for neural network computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-language",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="""It specifies the spoken language in the input audio file.
|
||||||
|
Example values: en, fr, de, zh, jp.
|
||||||
|
Available languages for multilingual models can be found at
|
||||||
|
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
If not specified, we infer the language from the input audio file.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-task",
|
||||||
|
default="transcribe",
|
||||||
|
choices=["transcribe", "translate"],
|
||||||
|
type=str,
|
||||||
|
help="""For multilingual models, if you specify translate, the output
|
||||||
|
will be in English.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-tail-paddings",
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="""Number of tail padding frames.
|
||||||
|
We have removed the 30-second constraint from whisper, so you need to
|
||||||
|
choose the amount of tail padding frames by yourself.
|
||||||
|
Use -1 to use a default value for tail padding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="True to show debug messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="""Sample rate of the feature extractor. Must match the one
|
||||||
|
expected by the model. Note: The input sound files can have a
|
||||||
|
different sample rate from this argument.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--feature-dim",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="Feature dimension. Must match the one expected by the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||||
|
"format with a single channel, and each sample has 16-bit, "
|
||||||
|
"i.e., int16_t. "
|
||||||
|
"The sample rate of the file can be arbitrary and does not need to "
|
||||||
|
"be 16 kHz",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--name",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The directory containing the input sound files to decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-dir",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The directory containing the input sound files to decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--label",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="wav_base_name label",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset related arguments for loading labels when label file is not provided
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-name",
|
||||||
|
type=str,
|
||||||
|
default="yuekai/seed_tts_cosy2",
|
||||||
|
help="Huggingface dataset name for loading labels",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="Dataset split name for loading labels",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_file_exists(filename: str):
|
||||||
|
assert Path(filename).is_file(), (
|
||||||
|
f"{filename} does not exist!\n"
|
||||||
|
"Please refer to "
|
||||||
|
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wave_filename:
|
||||||
|
Path to a wave file. It should be single channel and can be of type
|
||||||
|
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- A 1-D array of dtype np.float32 containing the samples,
|
||||||
|
which are normalized to the range [-1, 1].
|
||||||
|
- Sample rate of the wave file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||||
|
assert (
|
||||||
|
samples.ndim == 1
|
||||||
|
), f"Expected single channel, but got {samples.ndim} channels."
|
||||||
|
|
||||||
|
samples_float32 = samples.astype(np.float32)
|
||||||
|
|
||||||
|
return samples_float32, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text_alimeeting(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Text normalization similar to M2MeT challenge baseline.
|
||||||
|
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
text = text.replace('\u00A0', '') # test_hard
|
||||||
|
text = text.replace(" ", "")
|
||||||
|
text = text.replace("<sil>", "")
|
||||||
|
text = text.replace("<%>", "")
|
||||||
|
text = text.replace("<->", "")
|
||||||
|
text = text.replace("<$>", "")
|
||||||
|
text = text.replace("<#>", "")
|
||||||
|
text = text.replace("<_>", "")
|
||||||
|
text = text.replace("<space>", "")
|
||||||
|
text = text.replace("`", "")
|
||||||
|
text = text.replace("&", "")
|
||||||
|
text = text.replace(",", "")
|
||||||
|
if re.search("[a-zA-Z]", text):
|
||||||
|
text = text.upper()
|
||||||
|
text = text.replace("A", "A")
|
||||||
|
text = text.replace("a", "A")
|
||||||
|
text = text.replace("b", "B")
|
||||||
|
text = text.replace("c", "C")
|
||||||
|
text = text.replace("k", "K")
|
||||||
|
text = text.replace("t", "T")
|
||||||
|
text = text.replace(",", "")
|
||||||
|
text = text.replace("丶", "")
|
||||||
|
text = text.replace("。", "")
|
||||||
|
text = text.replace("、", "")
|
||||||
|
text = text.replace("?", "")
|
||||||
|
text = remove_punctuation(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert_file_exists(args.tokens)
|
||||||
|
assert args.num_threads > 0, args.num_threads
|
||||||
|
|
||||||
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||||
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||||
|
paraformer=args.paraformer,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feature_dim,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Started!")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
streams, results = [], []
|
||||||
|
total_duration = 0
|
||||||
|
|
||||||
|
for i, wave_filename in enumerate(args.sound_files):
|
||||||
|
assert_file_exists(wave_filename)
|
||||||
|
samples, sample_rate = read_wave(wave_filename)
|
||||||
|
duration = len(samples) / sample_rate
|
||||||
|
total_duration += duration
|
||||||
|
s = recognizer.create_stream()
|
||||||
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
|
||||||
|
streams.append(s)
|
||||||
|
if i % 10 == 0:
|
||||||
|
recognizer.decode_streams(streams)
|
||||||
|
results += [s.result.text for s in streams]
|
||||||
|
streams = []
|
||||||
|
print(f"Processed {i} files")
|
||||||
|
# process the last batch
|
||||||
|
if streams:
|
||||||
|
recognizer.decode_streams(streams)
|
||||||
|
results += [s.result.text for s in streams]
|
||||||
|
end_time = time.time()
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
results_dict = {}
|
||||||
|
for wave_filename, result in zip(args.sound_files, results):
|
||||||
|
print(f"{wave_filename}\n{result}")
|
||||||
|
print("-" * 10)
|
||||||
|
wave_basename = Path(wave_filename).stem
|
||||||
|
results_dict[wave_basename] = result
|
||||||
|
|
||||||
|
elapsed_seconds = end_time - start_time
|
||||||
|
rtf = elapsed_seconds / total_duration
|
||||||
|
print(f"num_threads: {args.num_threads}")
|
||||||
|
print(f"decoding_method: {args.decoding_method}")
|
||||||
|
print(f"Wave duration: {total_duration:.3f} s")
|
||||||
|
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
|
print(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load labels either from file or from dataset
|
||||||
|
labels_dict = {}
|
||||||
|
|
||||||
|
if args.label:
|
||||||
|
# Load labels from file (original functionality)
|
||||||
|
print(f"Loading labels from file: {args.label}")
|
||||||
|
with open(args.label, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
# fields = line.strip().split(" ")
|
||||||
|
# fields = [item for item in fields if item]
|
||||||
|
# assert len(fields) == 4
|
||||||
|
# prompt_text, prompt_audio, text, audio_path = fields
|
||||||
|
|
||||||
|
fields = line.strip().split("|")
|
||||||
|
fields = [item for item in fields if item]
|
||||||
|
assert len(fields) == 4
|
||||||
|
audio_path, prompt_text, prompt_audio, text = fields
|
||||||
|
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
|
||||||
|
else:
|
||||||
|
# Load labels from dataset (new functionality)
|
||||||
|
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
|
||||||
|
if 'zero' in args.split_name:
|
||||||
|
dataset_name = "yuekai/CV3-Eval"
|
||||||
|
else:
|
||||||
|
dataset_name = "yuekai/seed_tts_cosy2"
|
||||||
|
dataset = load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
audio_id = item["id"]
|
||||||
|
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
|
||||||
|
|
||||||
|
print(f"Loaded {len(labels_dict)} labels from dataset")
|
||||||
|
|
||||||
|
# Perform evaluation if labels are available
|
||||||
|
if labels_dict:
|
||||||
|
|
||||||
|
final_results = []
|
||||||
|
for key, value in results_dict.items():
|
||||||
|
if key in labels_dict:
|
||||||
|
final_results.append((key, labels_dict[key], value))
|
||||||
|
else:
|
||||||
|
print(f"Warning: No label found for {key}, skipping...")
|
||||||
|
|
||||||
|
if final_results:
|
||||||
|
store_transcripts(
|
||||||
|
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
|
||||||
|
)
|
||||||
|
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
|
||||||
|
write_error_stats(f, "test-set", final_results, enable_log=True)
|
||||||
|
|
||||||
|
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
|
||||||
|
print(f.readline()) # WER
|
||||||
|
print(f.readline()) # Detailed errors
|
||||||
|
else:
|
||||||
|
print("No matching labels found for evaluation")
|
||||||
|
else:
|
||||||
|
print("No labels available for evaluation")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
346
examples/grpo/cosyvoice2/token2wav_asr_server.py
Normal file
346
examples/grpo/cosyvoice2/token2wav_asr_server.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Pytriton server for token2wav conversion and ASR"""
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
from omnisense.models import OmniSenseVoiceSmall
|
||||||
|
from pytriton.proxy.types import Request
|
||||||
|
from pytriton.triton import Triton, TritonConfig
|
||||||
|
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
|
||||||
|
from pytriton.decorators import batch
|
||||||
|
import argparse
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import Any, List
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from scipy.signal import resample
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from jiwer import wer
|
||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||||
|
|
||||||
|
# Chinese text normalizer (cached globally)
|
||||||
|
zh_tn_model = ZhNormalizer(
|
||||||
|
cache_dir="./cache",
|
||||||
|
remove_erhua=False,
|
||||||
|
remove_interjections=False,
|
||||||
|
remove_puncts=True,
|
||||||
|
overwrite_cache=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
|
||||||
|
logger = logging.getLogger("token2wav_asr_server")
|
||||||
|
|
||||||
|
|
||||||
|
class _ASR_Server:
|
||||||
|
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||||
|
|
||||||
|
def __init__(self, device_id: int):
|
||||||
|
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||||
|
|
||||||
|
@batch
|
||||||
|
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
|
||||||
|
"""
|
||||||
|
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||||
|
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||||
|
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||||
|
"""
|
||||||
|
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
|
||||||
|
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
|
||||||
|
|
||||||
|
results = self._model.transcribe_single_batch(
|
||||||
|
wavs,
|
||||||
|
language="zh",
|
||||||
|
textnorm="woitn",
|
||||||
|
)
|
||||||
|
texts = [result.text for result in results]
|
||||||
|
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||||
|
return {"TRANSCRIPTS": transcripts}
|
||||||
|
|
||||||
|
|
||||||
|
def audio_decode_cosyvoice2(
|
||||||
|
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate audio from tokens with optional tone and prompt embedding.
|
||||||
|
"""
|
||||||
|
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||||
|
"empty", prompt_text, prompt_speech_16k, 24000
|
||||||
|
)
|
||||||
|
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||||
|
token=audio_tokens.to(codec_decoder.model.device),
|
||||||
|
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token_len=torch.tensor(
|
||||||
|
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||||
|
).to(codec_decoder.model.device),
|
||||||
|
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||||
|
finalize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||||
|
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
return audio_hat
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_prompt_from_dataset(dataset):
|
||||||
|
"""
|
||||||
|
Get random prompt text and speech from the pre-loaded dataset.
|
||||||
|
Returns (prompt_text, prompt_speech_16k)
|
||||||
|
"""
|
||||||
|
random_idx = random.randint(0, len(dataset) - 1)
|
||||||
|
sample = dataset[random_idx]
|
||||||
|
|
||||||
|
# Extract audio data
|
||||||
|
audio_data = sample["audio"]
|
||||||
|
audio_array = audio_data["array"]
|
||||||
|
sample_rate = audio_data["sampling_rate"]
|
||||||
|
|
||||||
|
# Convert audio to 16kHz if needed
|
||||||
|
if sample_rate != 16000:
|
||||||
|
num_samples = int(len(audio_array) * (16000 / sample_rate))
|
||||||
|
audio_array = resample(audio_array, num_samples)
|
||||||
|
|
||||||
|
# Convert to torch tensor
|
||||||
|
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
|
||||||
|
prompt_text = sample["text"]
|
||||||
|
# remove space in prompt_text
|
||||||
|
prompt_text = prompt_text.replace(" ", "")
|
||||||
|
return prompt_text, prompt_speech_16k
|
||||||
|
|
||||||
|
|
||||||
|
class _Token2Wav_ASR:
|
||||||
|
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||||
|
|
||||||
|
def __init__(self, device_id: int):
|
||||||
|
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||||
|
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
|
||||||
|
|
||||||
|
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
|
||||||
|
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
|
||||||
|
# current CUDA context to the desired card before the object is created.
|
||||||
|
# Afterwards, all parameters loaded with the generic "cuda" device will
|
||||||
|
# reside on this GPU. We keep the selected id in `self.device_id` and
|
||||||
|
# will set the context again for every forward call to avoid race
|
||||||
|
# conditions when several instances are used in the same process.
|
||||||
|
|
||||||
|
self.device_id = device_id
|
||||||
|
|
||||||
|
# Construct the TTS codec decoder under the correct CUDA device context
|
||||||
|
with torch.cuda.device(self.device_id):
|
||||||
|
self.codec_decoder = CosyVoice2(
|
||||||
|
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@batch
|
||||||
|
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
|
||||||
|
"""
|
||||||
|
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||||
|
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||||
|
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||||
|
"""
|
||||||
|
# Ensure the default CUDA device is set correctly for this invocation
|
||||||
|
torch.cuda.set_device(self.device_id)
|
||||||
|
|
||||||
|
if self.device_id == 0:
|
||||||
|
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
|
||||||
|
|
||||||
|
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
|
||||||
|
|
||||||
|
# Decode ground-truth text strings (BYTES → str)
|
||||||
|
if GT_TEXT.ndim == 2:
|
||||||
|
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||||
|
else:
|
||||||
|
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||||
|
|
||||||
|
wavs = []
|
||||||
|
for tokens in tokens_list:
|
||||||
|
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
|
||||||
|
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
|
||||||
|
audio_hat = audio_decode_cosyvoice2(
|
||||||
|
audio_tokens,
|
||||||
|
prompt_text,
|
||||||
|
prompt_speech_16k,
|
||||||
|
self.codec_decoder,
|
||||||
|
)
|
||||||
|
# resample to 16000 using soundfile
|
||||||
|
audio_hat = audio_hat.squeeze(0).float().cpu()
|
||||||
|
audio_hat = audio_hat.numpy()
|
||||||
|
num_samples = int(len(audio_hat) * (16000 / 24000))
|
||||||
|
audio_hat = resample(audio_hat, num_samples)
|
||||||
|
wavs.append(audio_hat)
|
||||||
|
|
||||||
|
results = self.asr_model.transcribe_single_batch(
|
||||||
|
wavs,
|
||||||
|
language="zh",
|
||||||
|
textnorm="woitn",
|
||||||
|
)
|
||||||
|
texts = [result.text for result in results]
|
||||||
|
|
||||||
|
# ---------------- Reward computation ----------------
|
||||||
|
rewards = []
|
||||||
|
for gt_text, hyp_text in zip(gt_texts, texts):
|
||||||
|
gt_norm = zh_tn_model.normalize(gt_text).lower()
|
||||||
|
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
|
||||||
|
|
||||||
|
gt_pinyin = lazy_pinyin(
|
||||||
|
gt_norm,
|
||||||
|
style=Style.TONE3,
|
||||||
|
tone_sandhi=True,
|
||||||
|
neutral_tone_with_five=True,
|
||||||
|
)
|
||||||
|
hyp_pinyin = lazy_pinyin(
|
||||||
|
hyp_norm,
|
||||||
|
style=Style.TONE3,
|
||||||
|
tone_sandhi=True,
|
||||||
|
neutral_tone_with_five=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
|
||||||
|
reward_val = 1.0 - np.tanh(3.0 * c)
|
||||||
|
reward_val = max(0.0, min(1.0, reward_val))
|
||||||
|
rewards.append(reward_val)
|
||||||
|
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
|
||||||
|
|
||||||
|
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||||
|
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
|
||||||
|
|
||||||
|
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_function_factory(device_ids: List[int], model_name: str):
|
||||||
|
"""Creates a list of inference functions, one for each requested device ID."""
|
||||||
|
infer_funcs = []
|
||||||
|
for device_id in device_ids:
|
||||||
|
if model_name == "sensevoice":
|
||||||
|
infer_funcs.append(_ASR_Server(device_id=device_id))
|
||||||
|
else:
|
||||||
|
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
|
||||||
|
return infer_funcs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Batch size of request.",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--number-of-instances-per-device",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of model instances to load.",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--number-of-devices",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of devices to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="token2wav_asr",
|
||||||
|
choices=["token2wav_asr", "sensevoice"],
|
||||||
|
help="Model name.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
log_level = logging.DEBUG if args.verbose else logging.INFO
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
|
||||||
|
|
||||||
|
triton_config = TritonConfig(
|
||||||
|
http_port=8000,
|
||||||
|
grpc_port=8001,
|
||||||
|
metrics_port=8002,
|
||||||
|
)
|
||||||
|
|
||||||
|
device_ids = [i for i in range(args.number_of_devices)]
|
||||||
|
device_ids = device_ids * args.number_of_instances_per_device
|
||||||
|
|
||||||
|
with Triton(config=triton_config) as triton:
|
||||||
|
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
|
||||||
|
if args.model_name == "sensevoice":
|
||||||
|
triton.bind(
|
||||||
|
model_name="sensevoice",
|
||||||
|
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||||
|
inputs=[
|
||||||
|
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
|
||||||
|
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||||
|
],
|
||||||
|
config=ModelConfig(
|
||||||
|
max_batch_size=args.max_batch_size,
|
||||||
|
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||||
|
),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
triton.bind(
|
||||||
|
model_name="token2wav_asr",
|
||||||
|
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||||
|
inputs=[
|
||||||
|
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
|
||||||
|
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||||
|
],
|
||||||
|
config=ModelConfig(
|
||||||
|
max_batch_size=args.max_batch_size,
|
||||||
|
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||||
|
),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
logger.info("Serving inference")
|
||||||
|
triton.serve()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user