Merge pull request #1583 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
This commit is contained in:
Xiang Lyu
2025-09-17 10:57:18 +08:00
committed by GitHub
13 changed files with 2379 additions and 1 deletion

View File

@@ -31,7 +31,7 @@
  - [x] 2025/08
- - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support
+ - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
  - [x] 2025/07

View File

@@ -0,0 +1,6 @@
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
COPY requirements.txt /myworkspace/requirements.txt
RUN pip install -r /myworkspace/requirements.txt
RUN pip install -U nvidia-pytriton
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .

View File

@@ -0,0 +1,125 @@
# CosyVoice2 LLM Reinforcement Learning Recipe
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
## Table of Contents
- [Environment Setup](#environment-setup)
- [Data Preparation](#data-preparation)
- [Reward Function & ASR Server](#reward-function--asr-server)
- [Training](#training)
- [Evaluation](#evaluation)
- [Export Model](#export-model)
- [Results](#results)
- [Acknowledgement](#acknowledgement)
## Environment Setup
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
```bash
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
```
If Docker is not available, refer to stage `-2` in `run.sh` to install the dependencies locally.
## Data Preparation
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
```jsonc
{
"text": "An example sentence to be synthesized."
}
```
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
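If you prefer Python over the `wget`/`head`/`tail` commands in `run.sh` stage `0`, here is a minimal sketch (assuming `huggingface_hub` is installed; the 80000/100 split sizes follow `run.sh`):
```python
# Optional sketch: fetch one voxbox metadata JSONL and split it into train/test files.
import os
from huggingface_hub import hf_hub_download

os.makedirs("data", exist_ok=True)
path = hf_hub_download(
    repo_id="SparkAudio/voxbox",
    repo_type="dataset",
    filename="metadata/aishell-3.jsonl",
)
lines = open(path, encoding="utf-8").read().splitlines()
with open("data/train.jsonl", "w", encoding="utf-8") as f:
    f.write("\n".join(lines[:80000]) + "\n")
with open("data/test.jsonl", "w", encoding="utf-8") as f:
    f.write("\n".join(lines[-100:]) + "\n")
```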
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
```bash
bash run.sh 0 0
```
The stage expects two JSONL files, `train.jsonl` and `test.jsonl`; stage `0` of `run.sh` creates them by downloading and splitting the AISHELL-3 metadata.
It then generates two Parquet files:
```
data/parquet_aishell3/train.parquet
data/parquet_aishell3/test.parquet
```
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
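For reference, here is a minimal sketch of that wrapping; it mirrors `prepare_data.py` included later in this PR, and the `wrap_sample` helper plus the `"aishell-3"` data-source string are illustrative:
```python
# Illustrative sketch of the CosyVoice2-style sample produced by prepare_data.py.
def wrap_sample(text: str, idx: int, split: str) -> dict:
    return {
        "data_source": "aishell-3",  # assumption: any identifying string works here
        "prompt": [
            {"role": "user", "content": text},     # text to be synthesized
            {"role": "assistant", "content": ""},  # the LLM must continue with speech tokens
        ],
        "ability": "text-to-speech",
        "reward_model": {"style": "rule", "ground_truth": text},
        "extra_info": {"split": split, "index": idx, "text": text},
    }
```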
## Reward Function & ASR Server
To compute rewards, we run a lightweight server that:
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
2. Transcribes the waveform with **SenseVoice** ASR.
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1 (a rough sketch of this mapping is shown below).
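A rough sketch of that mapping, assuming the server scores transcripts with `pypinyin` and `jiwer` (the authoritative logic lives in the Triton `token2wav_asr` model, not in this snippet):
```python
# Hypothetical sketch of the pinyin-level reward: transliterate both texts,
# compute the word error rate over pinyin tokens, and clip the reward to [0, 1].
from jiwer import wer
from pypinyin import lazy_pinyin

def pinyin_reward(transcript: str, ground_truth: str) -> float:
    hyp = " ".join(lazy_pinyin(transcript))
    ref = " ".join(lazy_pinyin(ground_truth))
    return max(0.0, 1.0 - wer(ref, hyp))
```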
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
```bash
bash run.sh 1 1
# Triton server listens on ports 8000/8001/8002
```
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
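Once the stage `1` server is running, you can sanity-check the reward function directly; the token string and text below are placeholders:
```python
# Minimal check of reward_tts.compute_score against a running reward server.
from reward_tts import compute_score

score = compute_score(
    data_source="demo",
    solution_str="<|s_123|><|s_456|><|s_789|>",  # placeholder CosyVoice2 speech tokens
    ground_truth="今天天气怎么样",  # placeholder reference text
)
print(f"reward = {score:.4f}")  # a value in [0, 1]
```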
## Training
Run stage `2` to start GRPO training:
```bash
bash run.sh 2 2
```
Key CLI arguments passed to `verl.trainer.main_ppo`:
* `algorithm.adv_estimator=grpo`: use GRPO instead of PPO.
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`: the Parquet files produced in the data preparation stage.
* `custom_reward_function.path=reward_tts.py`: the custom reward function described above.
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
> [!TIP]
> Note: the lm_head bias is disabled during training to make the model compatible with vLLM and Transformers' Qwen model.
## Evaluation
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
```bash
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
```
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
```bash
bash run.sh 4 4
```
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
> [!TIP]
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
## Export Model
To use the RL-trained model with the official CosyVoice repository:
```bash
bash run.sh 5 5
```
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
## Results
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|-------|------------------------|------------------------------|---------|
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
## Acknowledgement
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).

View File

@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python3 huggingface_to_pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
"""
from argparse import ArgumentParser
import torch
from safetensors import safe_open
from transformers import AutoTokenizer
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--hf-cosyvoice2-llm-path",
type=str,
default=None,
help="The RL trained CosyVoice2 model path in HuggingFace format",
)
parser.add_argument(
"--output-path",
type=str,
default="./llm.pt",
help="The path to save the llm.pt",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
cosyvoice2_token_size = 6561 + 3
llm_embedding_vocab_size = 2
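# Layout of the extended vocabulary in the HF checkpoint (see pretrained_to_huggingface.py):
# rows [0, speech_start_idx) are the original Qwen text tokens, the next 6561 + 3 rows are
# the CosyVoice2 speech tokens plus three eos tokens, and the following 2 rows hold the
# <|sos|>/<|task_id|> embeddings that map back to llm_embedding.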
hf_tensors = {}
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
for k in f.keys():
if k.startswith("lm_head.bias"):
# RL trained model disable bias for lm_head
continue
new_k = "llm.model." + k
hf_tensors[new_k] = f.get_tensor(k)
if k.startswith("lm_head"):
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
if k.startswith("model.embed_tokens"):
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
# use tie_word_embeddings=True
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
torch.save(hf_tensors, args.output_path)

View File

@@ -0,0 +1,397 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
dataset=zero_shot_zh
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
token2wav_path=/workspace/CosyVoice2-0.5B
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
infer_dataset.py \
--output-dir $output_dir \
--llm-model-name-or-path $llm_path/merged_hf_model \
--token2wav-path $token2wav_path \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
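# Fallback chat template: data_collator swaps this in when the loaded tokenizer still
# carries a generic Qwen template (see the 'system' check in data_collator below).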
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def extract_speech_ids(speech_tokens_str):
"""Extract speech IDs from token strings like <|s_23456|>"""
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
for token in cosy2_tokens:
speech_id_str += f"<|s_{token}|>"
return speech_id_str
def get_args():
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=1, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="LLM model path (includes both model and tokenizer)",
)
parser.add_argument(
"--token2wav-path",
required=True,
type=str,
help="CosyVoice2 token2wav model path",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The prompt text for CosyVoice2",
)
parser.add_argument(
"--prompt-speech-path",
type=str,
default=None,
help="The path to the prompt speech for CosyVoice2",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
args = parser.parse_args()
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
mels, prompt_audio_cosy2tokens_list = [], []
for item in batch:
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
)
prompt_text_list.append(prompt_text)
# Combine prompt and target text
full_text = prompt_text + target_text
# get prompt audio for CosyVoice2 (convert to 16kHz)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
print(ref_audio_org.shape)
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
prompt_audio_list.append(ref_audio)
if "prompt_audio_cosy2_tokens" in item:
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
else:
# convert to float first
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
if len(mels) > 0:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
if 'system' in tokenizer.chat_template:
tokenizer.chat_template = TEMPLATE
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids_list.append(input_ids.squeeze(0))
# For batch_size=1, no need to pad
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
else:
# Handle batch > 1 if needed
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list)
ids = [item["id"] for item in batch]
return {
"input_ids": input_ids,
"ids": ids,
"prompt_text": prompt_text_list,
"prompt_audio_list": prompt_audio_list,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
# Load LLM model and tokenizer directly
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
cosyvoice_codec = CosyVoice2(
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
)
if args.prompt_speech_path:
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
else:
prompt_speech_16k = None
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
for batch in dataloader:
with torch.no_grad():
input_ids = batch["input_ids"].to(device)
# Generate speech tokens using LLM
outputs = model.generate(
input_ids,
max_new_tokens=2048, # Max length for generation
do_sample=True,
top_p=args.top_p,
temperature=args.temperature,
top_k=args.top_k,
)
# Process each sample in the batch
for i in range(len(batch["ids"])):
# Extract generated tokens (excluding input)
input_length = input_ids[i].shape[0]
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Extract speech IDs from token strings like <|s_23456|>
speech_ids = extract_speech_ids(speech_tokens_str)
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
# Convert to tensor for CosyVoice2
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
if current_prompt_audio is not None:
# Generate audio using CosyVoice2
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
current_prompt_text,
current_prompt_audio,
cosyvoice_codec,
)
# Convert to numpy and save
generated_wave = audio_hat.squeeze(0).cpu().numpy()
target_sample_rate = 24000
utt = batch["ids"][i]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
if rank == 0:
progress_bar.close()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,86 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Text to Speech dataset to parquet format
"""
import argparse
import os
import re
import datasets
from verl.utils.hdfs_io import copy, makedirs
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
parser.add_argument("--local_dir", default=None, required=True)
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
# Load datasets from local JSON files
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
text = example.pop("text")
# use cosyvoice2 official huggingface compatible checkpoint template
question = text
answer = ""
data = {
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
"prompt": [
{
"role": "user",
"content": question,
},
{
"role": "assistant",
"content": answer,
},
],
"ability": "text-to-speech",
"reward_model": {"style": "rule", "ground_truth": text},
"extra_info": {
"split": split,
"index": idx,
"text": text,
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
print(train_dataset)
print(test_dataset)
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: convert the official CosyVoice2-0.5B LLM into a Hugging Face compatible checkpoint
python3 pretrained_to_huggingface.py \
--pretrained-cosyvoice2-path /workspace/CosyVoice2-0.5B \
--save-path ./transformers_cosyvoice2_llm
"""
from cosyvoice.cli.cosyvoice import CosyVoice2
import sys
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--pretrained-cosyvoice2-path",
type=str,
default="/workspace/CosyVoice2-0.5B",
help="Token2Wav path, default to %(default)r",
)
parser.add_argument(
"--save-path",
type=str,
default='./transformers_cosyvoice2_llm',
help="The path to save the model",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
cosy2_model = CosyVoice2(
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
)
llm = cosy2_model.model.llm.llm.model
speech_embedding = cosy2_model.model.llm.speech_embedding
llm_decoder = cosy2_model.model.llm.llm_decoder
llm_embedding = cosy2_model.model.llm.llm_embedding
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
'additional_special_tokens': [
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]"
]
}
tokenizer.add_special_tokens(special_tokens)
original_tokenizer_vocab_size = len(tokenizer)
cosyvoice2_token_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
vocab_size = llm.get_input_embeddings().weight.shape[0]
feature_size = speech_embedding.embedding_dim
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
with torch.no_grad():
# set the weight and bias of the new lm_head to 0
new_lm_head.weight.data.zero_()
# make bias value -inf
new_lm_head.bias.data.fill_(-float('inf'))
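# An -inf bias on every other row removes text tokens from the softmax, so the converted
# model can only emit speech/eos tokens; the original llm_decoder weight and bias are
# copied into the speech-token rows just below.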
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
llm.lm_head = new_lm_head
input_embeddings = llm.get_input_embeddings()
with torch.no_grad():
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
llm.generation_config.eos_token_id = eos_token_ids
llm.generation_config.temperature = 1.0
llm.generation_config.top_p = 0.8
llm.generation_config.top_k = 25
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
llm.config.vocab_size = vocab_size
llm.config.tie_word_embeddings = False
llm.config.use_bias = True
llm.to(torch.bfloat16)
llm.save_pretrained(args.save_path)
TEMPLATE = (
"{%- for message in messages %}"
"{%- if message['role'] == 'user' %}"
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
"{%- elif message['role'] == 'assistant' %}"
"{{- message['content']}}"
"{%- endif %}"
"{%- endfor %}"
)
tokenizer.chat_template = TEMPLATE
tokenizer.save_pretrained(args.save_path)

View File

@@ -0,0 +1,31 @@
conformer==0.3.2
diffusers==0.29.0
gdown==5.1.0
gradio
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0
protobuf==4.25
pydantic==2.7.0
pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
wget==3.2
WeTextProcessing==1.0.3
s3tokenizer
tensorrt
sherpa_onnx
jiwer
zhon
numpy==1.25.2
pypinyin
openai-whisper

View File

@@ -0,0 +1,233 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reward calculation for CosyVoice2-0.5B.
"""
from __future__ import annotations
import re
import json
import time
import argparse
from typing import List
import numpy as np
import requests
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
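# HTTP inference endpoint (Triton v2 protocol) exposed by token2wav_asr_server.py, started in run.sh stage 1.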
def _parse_ids(token_str: str) -> List[int]:
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
"""Send token IDs and ground-truth text to the Triton server and get reward."""
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
payload = {
"inputs": [
{
"name": "TOKENS",
"shape": list(tokens_arr.shape),
"datatype": "INT32",
"data": tokens_arr.tolist(),
},
{
"name": "TOKEN_LENS",
"shape": list(lens_arr.shape),
"datatype": "INT32",
"data": lens_arr.tolist(),
},
{
"name": "GT_TEXT",
"shape": [1, 1],
"datatype": "BYTES",
"data": [ground_truth],
},
]
}
rsp = requests.post(
REWARD_SERVER_URL,
headers={"Content-Type": "application/json"},
json=payload,
timeout=timeout,
verify=False,
params={"request_id": "0"},
)
rsp.raise_for_status()
result = rsp.json()
try:
# Reward is returned as the first output
return float(result["outputs"][0]["data"][0])
except (KeyError, IndexError, TypeError):
return 0.0
def compute_score(
data_source: str,
solution_str: str,
ground_truth: str,
extra_info: dict | None = None,
*,
debug_dump: bool = False,
) -> float:
"""Return reward in [0, 1] using the Triton ASR service.
The reward is based on the pinyin-level WER between the ASR transcript
produced from *solution_str* and the provided *ground_truth* text.
"""
# Decode token IDs
ids = _parse_ids(solution_str)
# Query remote server for reward
try:
reward = _remote_reward(ids, ground_truth)
except Exception as e:
reward = 0.0
if debug_dump:
print(
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
)
return reward
# CLI quick test
if __name__ == "__main__":
import sys
def get_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Test TTS CER scoring with data from JSONL file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input", "-i",
type=str,
default="data/emilia_zh-cosy-tiny-test.jsonl",
help="Path to input JSONL file"
)
parser.add_argument(
"--max-samples", "-n",
type=int,
default=None,
help="Maximum number of samples to process (default: all)"
)
parser.add_argument(
"--no-interactive",
action="store_true",
help="Run in non-interactive mode (process all samples without prompts)"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode"
)
return parser.parse_args()
def load_jsonl(file_path: str):
"""Load data from jsonl file."""
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data.append(json.loads(line.strip()))
return data
def code_to_solution_str(code_list: List[int]) -> str:
"""Convert code list to solution string format."""
return ''.join([f"<|s_{code}|>" for code in code_list])
# Parse command line arguments
args = get_args()
try:
# Load data from jsonl file
print(f"Loading data from: {args.input}")
data_list = load_jsonl(args.input)
print(f"Loaded {len(data_list)} samples")
# Limit samples if specified
if args.max_samples is not None:
data_list = data_list[:args.max_samples]
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
# Process each sample
begin_time = time.time()
for i, sample in enumerate(data_list):
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
print(f"Index: {sample.get('index', 'unknown')}")
print(f"Text: {sample['text']}")
# Extract required fields
code_list = sample['code']
ground_truth = sample['text']
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
# Convert code list to solution string
solution_str = code_to_solution_str(code_list)
print(f"Solution tokens: {len(code_list)} tokens")
if args.debug:
print(f"Solution string: {solution_str}")
else:
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
# Call compute_score function
try:
score = compute_score(
data_source=data_source,
solution_str=solution_str,
ground_truth=ground_truth,
extra_info=None,
debug_dump=args.debug
)
print(f"Final Score: {score:.4f}")
except Exception as e:
print(f"Error computing score: {e}")
# Ask user if they want to continue (for interactive mode)
if not args.no_interactive and i < len(data_list) - 1:
try:
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
if response == 'q':
break
except KeyboardInterrupt:
print("\nStopped by user")
break
print(f"\nProcessed {min(i+1, len(data_list))} samples")
end_time = time.time()
print(f"Time taken: {end_time - begin_time} seconds")
except FileNotFoundError:
print(f"Error: File not found - {args.input}")
print("Please check the file path or use --input to specify correct path")
print("Run with --help for usage information")
except Exception as e:
print(f"Error: {e}")

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=4
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export PYTHONPATH=/workspace/CosyVoice
model_scope_model_path=./CosyVoice2-0.5B
sft_model_path=./transformers_cosyvoice2_llm
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
log "stage -2: install dependencies locally if pre-built docker image is not available"
conda create -n cosyvoice2 python=3.10 -y
conda activate cosyvoice2
# install verl
git clone https://github.com/yuekaizhang/verl.git -b thread
cd verl
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
pip install --no-deps -e .
cd -
# install requirements
pip install -r requirements.txt
pip install -U nvidia-pytriton
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
fi
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
python3 pretrained_to_huggingface.py \
--pretrained-cosyvoice2-path $model_scope_model_path \
--save-path $sft_model_path
# Or, you could use the following command to download the huggingface compatible checkpoint
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
fi
data_dir=data/parquet_aishell3
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: prepare data into verl format"
mkdir -p $data_dir
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
# total 88035 samples
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
python prepare_data.py \
--train_file data/train.jsonl \
--test_file data/test.jsonl \
--local_dir $data_dir
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: start token2wav asr server for reward function"
python3 token2wav_asr_server.py --number-of-devices 8
fi
exp_name=official_llm_aishell3_grpo
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: grpo train"
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export MKL_SERVICE_FORCE_INTEL=TRUE
n_gpus_per_node=8
micro_batch_size=4
train_batch_size=32
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$data_dir/train.parquet \
data.val_files=$data_dir/test.parquet \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.truncation='error' \
actor_rollout_ref.model.use_remove_padding=False \
actor_rollout_ref.model.path=$sft_model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.do_sample=true \
actor_rollout_ref.rollout.temperature=0.8 \
actor_rollout_ref.rollout.top_p=0.95 \
actor_rollout_ref.rollout.top_k=25 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
reward_model.reward_manager=prime \
custom_reward_function.path=reward_tts.py \
custom_reward_function.name=compute_score \
trainer.project_name='cosyvoice2_grpo' \
trainer.experiment_name=$exp_name \
trainer.logger=['console','wandb'] \
trainer.n_gpus_per_node=$n_gpus_per_node \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.resume_mode='auto' \
trainer.total_epochs=1 \
trainer.val_before_train=False
fi
steps=(100 200 300 400 500)
for step in ${steps[@]}; do
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: merge the model"
python -m verl.model_merger merge \
--backend fsdp \
--local_dir $llm_path/actor \
--target_dir $llm_path/merged_hf_model || exit 1
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Test the model"
dataset=zero_shot_zh # from CosyVoice3 test set
# dataset=test_zh # from seed_tts test set
output_dir=./outputs_${exp_name}_${step}_${dataset}
token2wav_path=/workspace/CosyVoice2-0.5B
model_path=$llm_path/merged_hf_model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
infer_dataset.py \
--output-dir $output_dir \
--llm-model-name-or-path $model_path \
--token2wav-path $token2wav_path \
--split-name ${dataset} || exit 1
bash scripts/compute_wer.sh $output_dir ${dataset}
fi
done
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: Convert the RL trained model to CosyVoice repo format"
python3 huggingface_to_pretrained.py \
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
# Please be careful or use the huggingface format inference code.
fi

View File

@@ -0,0 +1,33 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
exit 1
fi
split_name=$2
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
pip install sherpa-onnx
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
mkdir -p models
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
fi
python3 scripts/offline-decode-files.py \
--tokens=$model_path/tokens.txt \
--paraformer=$model_path/model.int8.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=24000 \
--log-dir $wav_dir \
--feature-dim=80 \
--split-name $split_name \
--name sherpa_onnx \
$wav_files
# python3 scripts/paraformer-pytriton-client.py \
# --log-dir $wav_dir \
# --split-name $split_name \
# $wav_files

View File

@@ -0,0 +1,756 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
(4) For Whisper models
python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--whisper-task=transcribe \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For CTC models from WeNet
python3 ./python-api-examples/offline-decode-files.py \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
(6) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple, Dict, Iterable, TextIO, Union
import numpy as np
import sherpa_onnx
import soundfile as sf
from datasets import load_dataset
import logging
from collections import defaultdict
import kaldialign
from zhon.hanzi import punctuation
import string
punctuation_all = punctuation + string.punctuation
Pathlike = Union[str, Path]
def remove_punctuation(text: str) -> str:
for x in punctuation_all:
if x == '\'':
continue
text = text.replace(x, '')
return text
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
If it is a multi-talker ASR system, the ref and hyp may also be lists of
strings.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf8") as f:
for cut_id, ref, hyp in texts:
if char_level:
ref = list("".join(ref))
hyp = list("".join(hyp))
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
compute_CER: bool = False,
sclite_mode: bool = False,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return the total error rate as a percentage (e.g. 2.34 means 2.34%).
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one words/phrases per line, like
HELLO WORLD
你好世界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The hotword score of each token for biasing word/phrase. Used only if
--hotwords-file is given.
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="",
help="""
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
Used only when hotwords-file is given.
""",
)
parser.add_argument(
"--bpe-vocab",
type=str,
default="",
help="""
The path to the bpe vocabulary, the bpe vocabulary is generated by
sentencepiece, you can also export the bpe vocabulary through a bpe model
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
and modeling-unit is bpe or cjkchar+bpe.
""",
)
parser.add_argument(
"--encoder",
default="",
type=str,
help="Path to the encoder model",
)
parser.add_argument(
"--decoder",
default="",
type=str,
help="Path to the decoder model",
)
parser.add_argument(
"--joiner",
default="",
type=str,
help="Path to the joiner model",
)
parser.add_argument(
"--paraformer",
default="",
type=str,
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
parser.add_argument(
"--whisper-tail-paddings",
default=-1,
type=int,
help="""Number of tail padding frames.
We have removed the 30-second constraint from whisper, so you need to
choose the amount of tail padding frames by yourself.
Use -1 to use a default value for tail padding.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="""Sample rate of the feature extractor. Must match the one
expected by the model. Note: The input sound files can have a
different sample rate from this argument.""",
)
parser.add_argument(
"--feature-dim",
type=int,
default=80,
help="Feature dimension. Must match the one expected by the model",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
parser.add_argument(
"--name",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--log-dir",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--label",
type=str,
default=None,
help="wav_base_name label",
)
# Dataset related arguments for loading labels when label file is not provided
parser.add_argument(
"--dataset-name",
type=str,
default="yuekai/seed_tts_cosy2",
help="Huggingface dataset name for loading labels",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="Dataset split name for loading labels",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and can be of type
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples,
which are normalized to the range [-1, 1].
- Sample rate of the wave file.
"""
samples, sample_rate = sf.read(wave_filename, dtype="float32")
assert (
samples.ndim == 1
), f"Expected single channel, but got {samples.ndim} channels."
samples_float32 = samples.astype(np.float32)
return samples_float32, sample_rate
def normalize_text_alimeeting(text: str) -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
import re
text = text.replace('\u00A0', '') # test_hard
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = remove_punctuation(text)
return text
def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
print("Started!")
start_time = time.time()
streams, results = [], []
total_duration = 0
for i, wave_filename in enumerate(args.sound_files):
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
streams.append(s)
if i % 10 == 0:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
streams = []
print(f"Processed {i} files")
# process the last batch
if streams:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
end_time = time.time()
print("Done!")
results_dict = {}
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
wave_basename = Path(wave_filename).stem
results_dict[wave_basename] = result
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
# Load labels either from file or from dataset
labels_dict = {}
if args.label:
# Load labels from file (original functionality)
print(f"Loading labels from file: {args.label}")
with open(args.label, "r") as f:
for line in f:
                # Each label line is pipe-separated:
                #   audio_path|prompt_text|prompt_audio|text
fields = line.strip().split("|")
fields = [item for item in fields if item]
assert len(fields) == 4
audio_path, prompt_text, prompt_audio, text = fields
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
else:
# Load labels from dataset (new functionality)
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
if 'zero' in args.split_name:
dataset_name = "yuekai/CV3-Eval"
else:
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
for item in dataset:
audio_id = item["id"]
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
print(f"Loaded {len(labels_dict)} labels from dataset")
# Perform evaluation if labels are available
if labels_dict:
final_results = []
for key, value in results_dict.items():
if key in labels_dict:
final_results.append((key, labels_dict[key], value))
else:
print(f"Warning: No label found for {key}, skipping...")
if final_results:
store_transcripts(
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
)
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
write_error_stats(f, "test-set", final_results, enable_log=True)
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
print(f.readline()) # WER
print(f.readline()) # Detailed errors
else:
print("No matching labels found for evaluation")
else:
print("No labels available for evaluation")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,346 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytriton server for token2wav conversion and ASR"""
from datasets import load_dataset
from cosyvoice.cli.cosyvoice import CosyVoice2
from omnisense.models import OmniSenseVoiceSmall
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.decorators import batch
import argparse
import io
import logging
from typing import Any, List
import numpy as np
import torch
from scipy.signal import resample
import sys
import random
import re
from jiwer import wer
from pypinyin import lazy_pinyin, Style
from tn.chinese.normalizer import Normalizer as ZhNormalizer
# Chinese text normalizer (cached globally)
zh_tn_model = ZhNormalizer(
cache_dir="./cache",
remove_erhua=False,
remove_interjections=False,
remove_puncts=True,
overwrite_cache=True,
)
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
logger = logging.getLogger("token2wav_asr_server")
class _ASR_Server:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
def __init__(self, device_id: int):
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
@batch
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
results = self._model.transcribe_single_batch(
wavs,
language="zh",
textnorm="woitn",
)
texts = [result.text for result in results]
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
return {"TRANSCRIPTS": transcripts}
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def get_random_prompt_from_dataset(dataset):
"""
Get random prompt text and speech from the pre-loaded dataset.
Returns (prompt_text, prompt_speech_16k)
"""
random_idx = random.randint(0, len(dataset) - 1)
sample = dataset[random_idx]
# Extract audio data
audio_data = sample["audio"]
audio_array = audio_data["array"]
sample_rate = audio_data["sampling_rate"]
# Convert audio to 16kHz if needed
if sample_rate != 16000:
num_samples = int(len(audio_array) * (16000 / sample_rate))
audio_array = resample(audio_array, num_samples)
# Convert to torch tensor
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
prompt_text = sample["text"]
# remove space in prompt_text
prompt_text = prompt_text.replace(" ", "")
return prompt_text, prompt_speech_16k
class _Token2Wav_ASR:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
def __init__(self, device_id: int):
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
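        # The AISHELL test split supplies random (prompt_text, prompt_speech) pairs
        # for zero-shot synthesis during reward computation.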
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
# current CUDA context to the desired card before the object is created.
# Afterwards, all parameters loaded with the generic "cuda" device will
# reside on this GPU. We keep the selected id in `self.device_id` and
# will set the context again for every forward call to avoid race
# conditions when several instances are used in the same process.
self.device_id = device_id
# Construct the TTS codec decoder under the correct CUDA device context
with torch.cuda.device(self.device_id):
self.codec_decoder = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
)
@batch
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
# Ensure the default CUDA device is set correctly for this invocation
torch.cuda.set_device(self.device_id)
if self.device_id == 0:
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
# Decode ground-truth text strings (BYTES → str)
if GT_TEXT.ndim == 2:
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
else:
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
wavs = []
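        # Token-to-wav round trip: synthesize each hypothesis with a random zero-shot
        # prompt, then transcribe it with SenseVoice for scoring.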
for tokens in tokens_list:
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
prompt_text,
prompt_speech_16k,
self.codec_decoder,
)
            # CosyVoice2 outputs 24 kHz audio; downsample to 16 kHz for SenseVoice
            # with scipy.signal.resample.
audio_hat = audio_hat.squeeze(0).float().cpu()
audio_hat = audio_hat.numpy()
num_samples = int(len(audio_hat) * (16000 / 24000))
audio_hat = resample(audio_hat, num_samples)
wavs.append(audio_hat)
results = self.asr_model.transcribe_single_batch(
wavs,
language="zh",
textnorm="woitn",
)
texts = [result.text for result in results]
# ---------------- Reward computation ----------------
rewards = []
for gt_text, hyp_text in zip(gt_texts, texts):
gt_norm = zh_tn_model.normalize(gt_text).lower()
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
gt_pinyin = lazy_pinyin(
gt_norm,
style=Style.TONE3,
tone_sandhi=True,
neutral_tone_with_five=True,
)
hyp_pinyin = lazy_pinyin(
hyp_norm,
style=Style.TONE3,
tone_sandhi=True,
neutral_tone_with_five=True,
)
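            # Pinyin-level error rate mapped to a bounded reward:
            # reward = 1 - tanh(3 * CER), so CER 0 -> 1.0, 0.1 -> ~0.71, >= 1 -> ~0.005.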
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
reward_val = 1.0 - np.tanh(3.0 * c)
reward_val = max(0.0, min(1.0, reward_val))
rewards.append(reward_val)
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
def _infer_function_factory(device_ids: List[int], model_name: str):
"""Creates a list of inference functions, one for each requested device ID."""
infer_funcs = []
for device_id in device_ids:
if model_name == "sensevoice":
infer_funcs.append(_ASR_Server(device_id=device_id))
else:
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
return infer_funcs
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--max-batch-size",
type=int,
default=32,
help="Batch size of request.",
required=False,
)
parser.add_argument(
"--verbose",
action="store_true",
default=False,
)
parser.add_argument(
"--number-of-instances-per-device",
type=int,
default=1,
help="Number of model instances to load.",
required=False,
)
parser.add_argument(
"--number-of-devices",
type=int,
default=8,
help="Number of devices to use.",
)
parser.add_argument(
"--model-name",
type=str,
default="token2wav_asr",
choices=["token2wav_asr", "sensevoice"],
help="Model name.",
)
args = parser.parse_args()
log_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
triton_config = TritonConfig(
http_port=8000,
grpc_port=8001,
metrics_port=8002,
)
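    # Ports match the recipe's run.sh: HTTP 8000, gRPC 8001, metrics 8002.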
device_ids = [i for i in range(args.number_of_devices)]
device_ids = device_ids * args.number_of_instances_per_device
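    # One inference callable is created per entry; repeating the ids places
    # several model instances on each GPU.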
with Triton(config=triton_config) as triton:
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
if args.model_name == "sensevoice":
triton.bind(
model_name="sensevoice",
infer_func=_infer_function_factory(device_ids, args.model_name),
inputs=[
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
],
outputs=[
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
],
config=ModelConfig(
max_batch_size=args.max_batch_size,
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
),
strict=True,
)
else:
triton.bind(
model_name="token2wav_asr",
infer_func=_infer_function_factory(device_ids, args.model_name),
inputs=[
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
],
outputs=[
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
],
config=ModelConfig(
max_batch_size=args.max_batch_size,
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
),
strict=True,
)
logger.info("Serving inference")
triton.serve()
if __name__ == "__main__":
main()