diff --git a/README.md b/README.md index ef7f6ab..1d32e44 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ - [x] 2025/08 - - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support + - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support - [x] 2025/07 @@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o cd fastapi && python3 client.py --port 50000 --mode ``` +#### Using Nvidia TensorRT-LLM for deployment + +Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation. +To quick start: + +``` sh +cd runtime/triton_trtllm +docker compose up -d +``` +For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm) + ## Discussion & Communication You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues). diff --git a/docker/Dockerfile b/docker/Dockerfile index d7faf03..8cefdba 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5 RUN conda activate ${VENV} && cd CosyVoice && \ - pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com + pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir WORKDIR /workspace/CosyVoice diff --git a/examples/grpo/cosyvoice2/Dockerfile b/examples/grpo/cosyvoice2/Dockerfile new file mode 100644 index 0000000..17d80ed --- /dev/null +++ b/examples/grpo/cosyvoice2/Dockerfile @@ -0,0 +1,6 @@ +FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2 +COPY requirements.txt /myworkspace/requirements.txt +RUN pip install -r /myworkspace/requirements.txt +RUN pip install -U nvidia-pytriton +RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e . +RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e . \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/README.md b/examples/grpo/cosyvoice2/README.md new file mode 100644 index 0000000..8783aa1 --- /dev/null +++ b/examples/grpo/cosyvoice2/README.md @@ -0,0 +1,125 @@ +# CosyVoice2 LLM Reinforcement Learning Recipe + +This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%. + +## Table of Contents + +- [Environment Setup](#environment-setup) +- [Data Preparation](#data-preparation) +- [Reward Function & ASR Server](#reward-function--asr-server) +- [Training](#training) +- [Evaluation](#evaluation) +- [Export Model](#export-model) +- [Results](#results) +- [Acknowledgement](#acknowledgement) + +## Environment Setup +We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile. +```bash +docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2 +``` +If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally. + +## Data Preparation + +`prepare_data.py` expects a JSON/JSONL file with at least the following schema: + +```jsonc +{ + "text": "An example sentence to be synthesized." +} +``` +You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face. + +Stage `0` converts raw JSONL files into the parquet format expected by veRL: + +```bash +bash run.sh 0 0 +``` +Create two JSONL files—`train.jsonl` and `test.jsonl`. +The script will then generate two Parquet files: + +``` +data/parquet_tiny/train.parquet +data/parquet_tiny/test.parquet +``` + +Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens. + + +## Reward Function & ASR Server + +To compute rewards, we run a lightweight server that: + +1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model. +2. Transcribes the waveform with **SenseVoice** ASR. +3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1. + +Start the server (stage `1`) in a dedicated terminal or on a separate GPU: + +```bash +bash run.sh 1 1 +# Triton server listens on ports 8000/8001/8002 +``` + +The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score. + +## Training + +Run stage `2` to start GRPO training: + +```bash +bash run.sh 2 2 +``` + +Key CLI arguments passed to `verl.trainer.main_ppo`: + +* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO. +* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet` +* `custom_reward_function.path=reward_tts.py` – custom reward function described above. + +Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware. +> [!TIP] +> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model. + +## Evaluation + +After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`): + +```bash +bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model +``` + +You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`): + +```bash +bash run.sh 4 4 +``` + +This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`. + +> [!TIP] +> The script also supports the Seed-TTS test set by setting `dataset=test_zh`. + +## Export Model + +To use the RL-trained model with the official CosyVoice repository: + +```bash +bash run.sh 5 5 +``` + +The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository. +> [!TIP] +> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. + +## Results + +| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment | +|-------|------------------------|------------------------------|---------| +| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) | +| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model | + +## Acknowledgement + +This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo). diff --git a/examples/grpo/cosyvoice2/huggingface_to_pretrained.py b/examples/grpo/cosyvoice2/huggingface_to_pretrained.py new file mode 100644 index 0000000..ca49fc3 --- /dev/null +++ b/examples/grpo/cosyvoice2/huggingface_to_pretrained.py @@ -0,0 +1,71 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt +""" +from argparse import ArgumentParser +import torch +from safetensors import safe_open +from transformers import AutoTokenizer + + +def get_args(): + parser = ArgumentParser() + + parser.add_argument( + "--hf-cosyvoice2-llm-path", + type=str, + default=None, + help="The RL trained CosyVoice2 model path in HuggingFace format", + ) + parser.add_argument( + "--output-path", + type=str, + default="./llm.pt", + help="The path to save the llm.pt", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + + tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path) + speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>") + cosyvoice2_token_size = 6561 + 3 + llm_embedding_vocab_size = 2 + + hf_tensors = {} + with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f: + for k in f.keys(): + if k.startswith("lm_head.bias"): + # RL trained model disable bias for lm_head + continue + new_k = "llm.model." + k + hf_tensors[new_k] = f.get_tensor(k) + if k.startswith("lm_head"): + hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size] + hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0]) + if k.startswith("model.embed_tokens"): + hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size] + hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size] + + # use tie_word_embeddings=True + hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936] + hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"] + + torch.save(hf_tensors, args.output_path) diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py new file mode 100644 index 0000000..4dcbc96 --- /dev/null +++ b/examples/grpo/cosyvoice2/infer_dataset.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage +dataset=zero_shot_zh +output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts + +token2wav_path=/workspace/CosyVoice2-0.5B +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +torchrun --nproc_per_node=8 \ + infer_dataset.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $llm_path/merged_hf_model \ + --token2wav-path $token2wav_path \ + --split-name ${dataset} || exit 1 +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from cosyvoice.cli.cosyvoice import CosyVoice2 +from cosyvoice.utils.file_utils import load_wav +from datasets import load_dataset +from transformers import AutoTokenizer, AutoModelForCausalLM +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +import soundfile as sf +import s3tokenizer +from functools import partial + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") +try: + torch.multiprocessing.set_start_method("spawn") +except RuntimeError: + pass + + +TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" + + +def audio_decode_cosyvoice2( + audio_tokens, prompt_text, prompt_speech_16k, codec_decoder +): + """ + Generate audio from tokens with optional tone and prompt embedding. + """ + model_inputs_dict = codec_decoder.frontend.frontend_zero_shot( + "empty", prompt_text, prompt_speech_16k, 24000 + ) + tts_mel, _ = codec_decoder.model.flow.inference( + token=audio_tokens.to(codec_decoder.model.device), + token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( + codec_decoder.model.device + ), + prompt_token=model_inputs_dict["flow_prompt_speech_token"].to( + codec_decoder.model.device + ), + prompt_token_len=torch.tensor( + [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32 + ).to(codec_decoder.model.device), + prompt_feat=model_inputs_dict["prompt_speech_feat"].to( + codec_decoder.model.device + ), + prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to( + codec_decoder.model.device + ), + embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device), + finalize=True, + ) + + audio_hat, _ = codec_decoder.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + return audio_hat + + +def extract_speech_ids(speech_tokens_str): + """Extract speech IDs from token strings like <|s_23456|>""" + speech_ids = [] + for token_str in speech_tokens_str: + if token_str.startswith('<|s_') and token_str.endswith('|>'): + num_str = token_str[4:-2] + num = int(num_str) + speech_ids.append(num) + else: + print(f"Unexpected token: {token_str}") + return speech_ids + + +def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens): + """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>""" + speech_id_str = "" + for token in cosy2_tokens: + speech_id_str += f"<|s_{token}|>" + return speech_id_str + + +def get_args(): + parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2") + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2", + ) + parser.add_argument( + "--output-dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch-size", + default=1, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num-workers", type=int, default=1, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=5, help="prefetch for dataloader" + ) + parser.add_argument( + "--llm-model-name-or-path", + required=True, + type=str, + help="LLM model path (includes both model and tokenizer)", + ) + parser.add_argument( + "--token2wav-path", + required=True, + type=str, + help="CosyVoice2 token2wav model path", + ) + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="The prompt text for CosyVoice2", + ) + parser.add_argument( + "--prompt-speech-path", + type=str, + default=None, + help="The path to the prompt speech for CosyVoice2", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + args = parser.parse_args() + return args + + +def data_collator(batch, tokenizer, s3_tokenizer): + """Simplified data collator for batch_size=1 processing""" + target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio + device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu") + input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] + mels, prompt_audio_cosy2tokens_list = [], [] + for item in batch: + prompt_text, target_text = ( + item["prompt_text"], + item["target_text"], + ) + prompt_text_list.append(prompt_text) + # Combine prompt and target text + full_text = prompt_text + target_text + + # get prompt audio for CosyVoice2 (convert to 16kHz) + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0) + # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True) + print(ref_audio_org.shape) + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + prompt_audio_list.append(ref_audio) + + if "prompt_audio_cosy2_tokens" in item: + prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"] + prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens) + else: + # convert to float first + mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0))) + + if len(mels) > 0: + mels, mels_lens = s3tokenizer.padding(mels) + codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device)) + for i in range(len(codes)): + prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()]) + for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list: + prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens) + # Create chat template for LLM generation + chat = [ + {"role": "user", "content": full_text}, + {"role": "assistant", "content": prompt_audio_cosy2_id_str} + ] + if 'system' in tokenizer.chat_template: + tokenizer.chat_template = TEMPLATE + input_ids = tokenizer.apply_chat_template( + chat, + tokenize=True, + return_tensors='pt', + continue_final_message=True + ) + input_ids_list.append(input_ids.squeeze(0)) + + # For batch_size=1, no need to pad + if len(input_ids_list) == 1: + input_ids = input_ids_list[0].unsqueeze(0) + else: + # Handle batch > 1 if needed + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids]) + for input_ids in input_ids_list + ] + input_ids = torch.stack(input_ids_list) + + ids = [item["id"] for item in batch] + + return { + "input_ids": input_ids, + "ids": ids, + "prompt_text": prompt_text_list, + "prompt_audio_list": prompt_audio_list, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def main(): + args = get_args() + os.makedirs(args.output_dir, exist_ok=True) + + assert torch.cuda.is_available() + world_size, local_rank, rank = init_distributed() + device = torch.device(f"cuda:{local_rank}") + + # Load LLM model and tokenizer directly + tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path) + model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path) + model.eval() + model.to(device) + + cosyvoice_codec = CosyVoice2( + args.token2wav_path, load_jit=True, load_trt=True, fp16=True + ) + if args.prompt_speech_path: + prompt_speech_16k = load_wav(args.prompt_speech_path, 16000) + else: + prompt_speech_16k = None + s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None + dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2" + dataset = load_dataset( + dataset_name, + split=args.split_name, + trust_remote_code=True, + ) + + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer), + ) + + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + for batch in dataloader: + with torch.no_grad(): + input_ids = batch["input_ids"].to(device) + + # Generate speech tokens using LLM + outputs = model.generate( + input_ids, + max_new_tokens=2048, # Max length for generation + do_sample=True, + top_p=args.top_p, + temperature=args.temperature, + top_k=args.top_k, + ) + + # Process each sample in the batch + for i in range(len(batch["ids"])): + # Extract generated tokens (excluding input) + input_length = input_ids[i].shape[0] + generated_ids = outputs[i][input_length:-1] # Remove last token if needed + speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # Extract speech IDs from token strings like <|s_23456|> + speech_ids = extract_speech_ids(speech_tokens_str) + + if len(speech_ids) == 0: + print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping") + continue + + # Convert to tensor for CosyVoice2 + audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0) + + if args.prompt_text is not None: + current_prompt_text = args.prompt_text + current_prompt_audio = prompt_speech_16k + else: + current_prompt_text = batch["prompt_text"][i] + current_prompt_audio = batch["prompt_audio_list"][i] + + if current_prompt_audio is not None: + # Generate audio using CosyVoice2 + audio_hat = audio_decode_cosyvoice2( + audio_tokens, + current_prompt_text, + current_prompt_audio, + cosyvoice_codec, + ) + + # Convert to numpy and save + generated_wave = audio_hat.squeeze(0).cpu().numpy() + target_sample_rate = 24000 + + utt = batch["ids"][i] + sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate) + + print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens") + else: + print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping") + + if rank == 0: + progress_bar.update(world_size * len(batch["ids"])) + + if rank == 0: + progress_bar.close() + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/grpo/cosyvoice2/prepare_data.py b/examples/grpo/cosyvoice2/prepare_data.py new file mode 100644 index 0000000..46c3c09 --- /dev/null +++ b/examples/grpo/cosyvoice2/prepare_data.py @@ -0,0 +1,86 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the Text to Speech dataset to parquet format +""" + +import argparse +import os +import re + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file") + parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file") + parser.add_argument("--local_dir", default=None, required=True) + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + # Load datasets from local JSON files + train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train'] + test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train'] + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + text = example.pop("text") + + # use cosyvoice2 official huggingface compatible checkpoint template + question = text + answer = "" + + data = { + "data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source + "prompt": [ + { + "role": "user", + "content": question, + }, + { + "role": "assistant", + "content": answer, + }, + ], + "ability": "text-to-speech", + "reward_model": {"style": "rule", "ground_truth": text}, + "extra_info": { + "split": split, + "index": idx, + "text": text, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + print(train_dataset) + print(test_dataset) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py new file mode 100644 index 0000000..161a11f --- /dev/null +++ b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: Instruct TTS + python3 infer.py \ + --token2wav-path /workspace/CosyVoice2-0.5B \ + --prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ + --prompt-speech-path ./assets/prompt_audio.wav \ + --model-path ./transformers_cosyvoice2_llm \ + --input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。" +""" +from cosyvoice.cli.cosyvoice import CosyVoice2 +import sys +from argparse import ArgumentParser +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + + +def get_args(): + parser = ArgumentParser() + + parser.add_argument( + "--pretrained-cosyvoice2-path", + type=str, + default="/workspace/CosyVoice2-0.5B", + help="Token2Wav path, default to %(default)r", + ) + parser.add_argument( + "--save-path", + type=str, + default='./transformers_cosyvoice2_llm', + help="The path to save the model", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + cosy2_model = CosyVoice2( + args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False + ) + + llm = cosy2_model.model.llm.llm.model + + speech_embedding = cosy2_model.model.llm.speech_embedding + llm_decoder = cosy2_model.model.llm.llm_decoder + llm_embedding = cosy2_model.model.llm.llm_embedding + + tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN") + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + "", "", + "[hissing]", "[sigh]", "[vocalized-noise]", + "[lipsmack]", "[mn]" + ] + } + tokenizer.add_special_tokens(special_tokens) + + original_tokenizer_vocab_size = len(tokenizer) + cosyvoice2_token_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [ + "<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>" + ] + num_added_tokens = tokenizer.add_tokens(new_tokens) + + llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128) + vocab_size = llm.get_input_embeddings().weight.shape[0] + + feature_size = speech_embedding.embedding_dim + new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True) + + with torch.no_grad(): + # set the weight and bias of the new lm_head to 0 + new_lm_head.weight.data.zero_() + # make bias value -inf + new_lm_head.bias.data.fill_(-float('inf')) + new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight + new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias + + llm.lm_head = new_lm_head + input_embeddings = llm.get_input_embeddings() + + with torch.no_grad(): + input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight + input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight + + eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, + original_tokenizer_vocab_size + cosyvoice2_token_size + 1, + original_tokenizer_vocab_size + cosyvoice2_token_size + 2] + llm.generation_config.eos_token_id = eos_token_ids + llm.generation_config.temperature = 1.0 + llm.generation_config.top_p = 0.8 + llm.generation_config.top_k = 25 + + llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size + llm.config.vocab_size = vocab_size + llm.config.tie_word_embeddings = False + llm.config.use_bias = True + llm.to(torch.bfloat16) + llm.save_pretrained(args.save_path) + + TEMPLATE = ( + "{%- for message in messages %}" + "{%- if message['role'] == 'user' %}" + "{{- '<|sos|>' + message['content'] + '<|task_id|>' }}" + "{%- elif message['role'] == 'assistant' %}" + "{{- message['content']}}" + "{%- endif %}" + "{%- endfor %}" + ) + tokenizer.chat_template = TEMPLATE + tokenizer.save_pretrained(args.save_path) diff --git a/examples/grpo/cosyvoice2/requirements.txt b/examples/grpo/cosyvoice2/requirements.txt new file mode 100644 index 0000000..50f4edd --- /dev/null +++ b/examples/grpo/cosyvoice2/requirements.txt @@ -0,0 +1,31 @@ +conformer==0.3.2 +diffusers==0.29.0 +gdown==5.1.0 +gradio +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 +lightning==2.2.4 +matplotlib==3.7.5 +modelscope==1.15.0 +networkx==3.1 +omegaconf==2.3.0 +onnx==1.16.0 +onnxruntime-gpu==1.18.0 +protobuf==4.25 +pydantic==2.7.0 +pyworld==0.3.4 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +wget==3.2 +WeTextProcessing==1.0.3 +s3tokenizer +tensorrt +sherpa_onnx +jiwer +zhon +numpy==1.25.2 +pypinyin +openai-whisper \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/reward_tts.py b/examples/grpo/cosyvoice2/reward_tts.py new file mode 100644 index 0000000..4c40761 --- /dev/null +++ b/examples/grpo/cosyvoice2/reward_tts.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Reward calculation for CosyVoice2-0.5B. +""" + +from __future__ import annotations + +import re +import json +import time +import argparse +from typing import List + +import numpy as np +import requests + + +REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer" + + +def _parse_ids(token_str: str) -> List[int]: + return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)] + + +def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float: + """Send token IDs and ground-truth text to the Triton server and get reward.""" + + tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1) + lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32) + + gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object) + + payload = { + "inputs": [ + { + "name": "TOKENS", + "shape": list(tokens_arr.shape), + "datatype": "INT32", + "data": tokens_arr.tolist(), + }, + { + "name": "TOKEN_LENS", + "shape": list(lens_arr.shape), + "datatype": "INT32", + "data": lens_arr.tolist(), + }, + { + "name": "GT_TEXT", + "shape": [1, 1], + "datatype": "BYTES", + "data": [ground_truth], + }, + ] + } + rsp = requests.post( + REWARD_SERVER_URL, + headers={"Content-Type": "application/json"}, + json=payload, + timeout=timeout, + verify=False, + params={"request_id": "0"}, + ) + rsp.raise_for_status() + result = rsp.json() + + try: + # Reward is returned as the first output + return float(result["outputs"][0]["data"][0]) + except (KeyError, IndexError, TypeError): + return 0.0 + + +def compute_score( + data_source: str, + solution_str: str, + ground_truth: str, + extra_info: dict | None = None, + *, + debug_dump: bool = False, +) -> float: + """Return reward in [0, 1] using the Triton ASR service. + + The reward is based on the pinyin-level WER between the ASR transcript + produced from *solution_str* and the provided *ground_truth* text. + """ + + # Decode token IDs + ids = _parse_ids(solution_str) + + # Query remote server for reward + try: + reward = _remote_reward(ids, ground_truth) + except Exception as e: + reward = 0.0 + + if debug_dump: + print( + f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m" + ) + + return reward + + +# CLI quick test +if __name__ == "__main__": + import sys + + def get_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Test TTS CER scoring with data from JSONL file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input", "-i", + type=str, + default="data/emilia_zh-cosy-tiny-test.jsonl", + help="Path to input JSONL file" + ) + + parser.add_argument( + "--max-samples", "-n", + type=int, + default=None, + help="Maximum number of samples to process (default: all)" + ) + + parser.add_argument( + "--no-interactive", + action="store_true", + help="Run in non-interactive mode (process all samples without prompts)" + ) + + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug mode" + ) + + return parser.parse_args() + + def load_jsonl(file_path: str): + """Load data from jsonl file.""" + data = [] + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + data.append(json.loads(line.strip())) + return data + + def code_to_solution_str(code_list: List[int]) -> str: + """Convert code list to solution string format.""" + return ''.join([f"<|s_{code}|>" for code in code_list]) + + # Parse command line arguments + args = get_args() + + try: + # Load data from jsonl file + print(f"Loading data from: {args.input}") + data_list = load_jsonl(args.input) + print(f"Loaded {len(data_list)} samples") + + # Limit samples if specified + if args.max_samples is not None: + data_list = data_list[:args.max_samples] + print(f"Processing first {len(data_list)} samples (limited by --max-samples)") + + # Process each sample + begin_time = time.time() + for i, sample in enumerate(data_list): + print(f"\n--- Sample {i+1}/{len(data_list)} ---") + print(f"Index: {sample.get('index', 'unknown')}") + print(f"Text: {sample['text']}") + + # Extract required fields + code_list = sample['code'] + ground_truth = sample['text'] + data_source = sample.get('index', f'sample_{i}') # Use index as data_source + + # Convert code list to solution string + solution_str = code_to_solution_str(code_list) + print(f"Solution tokens: {len(code_list)} tokens") + if args.debug: + print(f"Solution string: {solution_str}") + else: + print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}") + + # Call compute_score function + try: + score = compute_score( + data_source=data_source, + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=None, + debug_dump=args.debug + ) + print(f"Final Score: {score:.4f}") + except Exception as e: + print(f"Error computing score: {e}") + + # Ask user if they want to continue (for interactive mode) + if not args.no_interactive and i < len(data_list) - 1: + try: + response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower() + if response == 'q': + break + except KeyboardInterrupt: + print("\nStopped by user") + break + + print(f"\nProcessed {min(i+1, len(data_list))} samples") + end_time = time.time() + print(f"Time taken: {end_time - begin_time} seconds") + except FileNotFoundError: + print(f"Error: File not found - {args.input}") + print("Please check the file path or use --input to specify correct path") + print("Run with --help for usage information") + except Exception as e: + print(f"Error: {e}") diff --git a/examples/grpo/cosyvoice2/run.sh b/examples/grpo/cosyvoice2/run.sh new file mode 100644 index 0000000..ce97ab3 --- /dev/null +++ b/examples/grpo/cosyvoice2/run.sh @@ -0,0 +1,159 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=-1 +stop_stage=4 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +export PYTHONPATH=/workspace/CosyVoice +model_scope_model_path=./CosyVoice2-0.5B +sft_model_path=./transformers_cosyvoice2_llm + +if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then + log "stage -2: install dependencies locally if pre-built docker image is not available" + conda create -n cosyvoice2 python=3.10 -y + conda activate cosyvoice2 + # install verl + git clone https://github.com/yuekaizhang/verl.git -b thread + cd verl + USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh + pip install --no-deps -e . + cd - + # install requirements + pip install -r requirements.txt + pip install -U nvidia-pytriton + git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e . +fi + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint" + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path + python3 pretrained_to_huggingface.py \ + --pretrained-cosyvoice2-path $model_scope_model_path \ + --save-path $sft_model_path + + # Or, you could use the following command to download the huggingface compatible checkpoint + # huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm + + # Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers. +fi + +data_dir=data/parquet_aishell3 +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "stage 0: prepare data into verl format" + mkdir -p $data_dir + wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl + # total 88035 samples + head -n 80000 data/aishell-3.jsonl > data/train.jsonl + tail -n 100 data/aishell-3.jsonl > data/test.jsonl + python prepare_data.py \ + --train_file data/train.jsonl \ + --test_file data/test.jsonl \ + --local_dir $data_dir +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "stage 1: start token2wav asr server for reward function" + python3 token2wav_asr_server.py --number-of-devices 8 +fi + +exp_name=official_llm_aishell3_grpo +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "stage 2: grpo train" + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + export MKL_SERVICE_FORCE_INTEL=TRUE + n_gpus_per_node=8 + micro_batch_size=4 + train_batch_size=32 + python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$data_dir/train.parquet \ + data.val_files=$data_dir/test.parquet \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=1024 \ + data.max_response_length=512 \ + data.truncation='error' \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.model.path=$sft_model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.do_sample=true \ + actor_rollout_ref.rollout.temperature=0.8 \ + actor_rollout_ref.rollout.top_p=0.95 \ + actor_rollout_ref.rollout.top_k=25 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=true \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \ + actor_rollout_ref.rollout.val_kwargs.top_k=25 \ + reward_model.reward_manager=prime \ + custom_reward_function.path=reward_tts.py \ + custom_reward_function.name=compute_score \ + trainer.project_name='cosyvoice2_grpo' \ + trainer.experiment_name=$exp_name \ + trainer.logger=['console','wandb'] \ + trainer.n_gpus_per_node=$n_gpus_per_node \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=100 \ + trainer.resume_mode='auto' \ + trainer.total_epochs=1 \ + trainer.val_before_train=False +fi + +steps=(100 200 300 400 500) +for step in ${steps[@]}; do +llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step} +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "stage 3: merge the model" + python -m verl.model_merger merge \ + --backend fsdp \ + --local_dir $llm_path/actor \ + --target_dir $llm_path/merged_hf_model || exit 1 +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "stage 4: Test the model" + dataset=zero_shot_zh # from CosyVoice3 test set + # dataset=test_zh # from seed_tts test set + output_dir=./outputs_${exp_name}_${step}_${dataset} + + token2wav_path=/workspace/CosyVoice2-0.5B + model_path=$llm_path/merged_hf_model + + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc_per_node=8 \ + infer_dataset.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $model_path \ + --token2wav-path $token2wav_path \ + --split-name ${dataset} || exit 1 + + bash scripts/compute_wer.sh $output_dir ${dataset} +fi +done + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "stage 5: Convert the RL trained model to CosyVoice repo format" + python3 huggingface_to_pretrained.py \ + --hf-cosyvoice2-llm-path $llm_path/merged_hf_model \ + --output-path /workspace/CosyVoice2-0.5B/llm-new.pt + # You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt + # However, we found that the RL trained model accuracy would slightly drop after this conversion. + # Please be careful or use the huggingface format inference code. +fi \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/scripts/compute_wer.sh b/examples/grpo/cosyvoice2/scripts/compute_wer.sh new file mode 100644 index 0000000..43a6872 --- /dev/null +++ b/examples/grpo/cosyvoice2/scripts/compute_wer.sh @@ -0,0 +1,33 @@ +wav_dir=$1 +wav_files=$(ls $wav_dir/*.wav) +# if wav_files is empty, then exit +if [ -z "$wav_files" ]; then + exit 1 +fi +split_name=$2 +model_path=models/sherpa-onnx-paraformer-zh-2023-09-14 + +if [ ! -d $model_path ]; then + pip install sherpa-onnx + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + mkdir models + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models +fi + +python3 scripts/offline-decode-files.py \ + --tokens=$model_path/tokens.txt \ + --paraformer=$model_path/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=24000 \ + --log-dir $wav_dir \ + --feature-dim=80 \ + --split-name $split_name \ + --name sherpa_onnx \ + $wav_files + +# python3 scripts/paraformer-pytriton-client.py \ +# --log-dir $wav_dir \ +# --split-name $split_name \ +# $wav_files \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py new file mode 100644 index 0000000..847d434 --- /dev/null +++ b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py @@ -0,0 +1,756 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. + +(1) For paraformer + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(2) For transducer models from icefall + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(3) For CTC models from NeMo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple, Dict, Iterable, TextIO, Union + +import numpy as np +import sherpa_onnx +import soundfile as sf +from datasets import load_dataset +import logging +from collections import defaultdict +import kaldialign +from zhon.hanzi import punctuation +import string +punctuation_all = punctuation + string.punctuation +Pathlike = Union[str, Path] + + +def remove_punctuation(text: str) -> str: + for x in punctuation_all: + if x == '\'': + continue + text = text.replace(x, '') + return text + + +def store_transcripts( + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False +) -> None: + """Save predicted results and reference transcripts to a file. + + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + If it is a multi-talker ASR system, the ref and hyp may also be lists of + strings. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf8") as f: + for cut_id, ref, hyp in texts: + if char_level: + ref = list("".join(ref)) + hyp = list("".join(hyp)) + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, + compute_CER: bool = False, + sclite_mode: bool = False, +) -> float: + """Write statistics based on predicted results and reference transcripts. + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model. Note: The input sound files can have a + different sample rate from this argument.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + parser.add_argument( + "--name", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--log-dir", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--label", + type=str, + default=None, + help="wav_base_name label", + ) + + # Dataset related arguments for loading labels when label file is not provided + parser.add_argument( + "--dataset-name", + type=str, + default="yuekai/seed_tts_cosy2", + help="Huggingface dataset name for loading labels", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + help="Dataset split name for loading labels", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and can be of type + 32-bit floating point PCM. Its sample rate does not need to be 24kHz. + + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, + which are normalized to the range [-1, 1]. + - Sample rate of the wave file. + """ + + samples, sample_rate = sf.read(wave_filename, dtype="float32") + assert ( + samples.ndim == 1 + ), f"Expected single channel, but got {samples.ndim} channels." + + samples_float32 = samples.astype(np.float32) + + return samples_float32, sample_rate + + +def normalize_text_alimeeting(text: str) -> str: + """ + Text normalization similar to M2MeT challenge baseline. + See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + import re + text = text.replace('\u00A0', '') # test_hard + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + text = remove_punctuation(text) + return text + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert args.num_threads > 0, args.num_threads + + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.paraformer) + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + + print("Started!") + start_time = time.time() + + streams, results = [], [] + total_duration = 0 + + for i, wave_filename in enumerate(args.sound_files): + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + + streams.append(s) + if i % 10 == 0: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + streams = [] + print(f"Processed {i} files") + # process the last batch + if streams: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + end_time = time.time() + print("Done!") + + results_dict = {} + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + wave_basename = Path(wave_filename).stem + results_dict[wave_basename] = result + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + # Load labels either from file or from dataset + labels_dict = {} + + if args.label: + # Load labels from file (original functionality) + print(f"Loading labels from file: {args.label}") + with open(args.label, "r") as f: + for line in f: + # fields = line.strip().split(" ") + # fields = [item for item in fields if item] + # assert len(fields) == 4 + # prompt_text, prompt_audio, text, audio_path = fields + + fields = line.strip().split("|") + fields = [item for item in fields if item] + assert len(fields) == 4 + audio_path, prompt_text, prompt_audio, text = fields + labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) + else: + # Load labels from dataset (new functionality) + print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}") + if 'zero' in args.split_name: + dataset_name = "yuekai/CV3-Eval" + else: + dataset_name = "yuekai/seed_tts_cosy2" + dataset = load_dataset( + dataset_name, + split=args.split_name, + trust_remote_code=True, + ) + + for item in dataset: + audio_id = item["id"] + labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"]) + + print(f"Loaded {len(labels_dict)} labels from dataset") + + # Perform evaluation if labels are available + if labels_dict: + + final_results = [] + for key, value in results_dict.items(): + if key in labels_dict: + final_results.append((key, labels_dict[key], value)) + else: + print(f"Warning: No label found for {key}, skipping...") + + if final_results: + store_transcripts( + filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results + ) + with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f: + write_error_stats(f, "test-set", final_results, enable_log=True) + + with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f: + print(f.readline()) # WER + print(f.readline()) # Detailed errors + else: + print("No matching labels found for evaluation") + else: + print("No labels available for evaluation") + + +if __name__ == "__main__": + main() diff --git a/examples/grpo/cosyvoice2/token2wav_asr_server.py b/examples/grpo/cosyvoice2/token2wav_asr_server.py new file mode 100644 index 0000000..8a6cb6e --- /dev/null +++ b/examples/grpo/cosyvoice2/token2wav_asr_server.py @@ -0,0 +1,346 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytriton server for token2wav conversion and ASR""" + +from datasets import load_dataset +from cosyvoice.cli.cosyvoice import CosyVoice2 +from omnisense.models import OmniSenseVoiceSmall +from pytriton.proxy.types import Request +from pytriton.triton import Triton, TritonConfig +from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor +from pytriton.decorators import batch +import argparse +import io +import logging +from typing import Any, List +import numpy as np +import torch +from scipy.signal import resample +import sys +import random +import re +from jiwer import wer +from pypinyin import lazy_pinyin, Style +from tn.chinese.normalizer import Normalizer as ZhNormalizer + +# Chinese text normalizer (cached globally) +zh_tn_model = ZhNormalizer( + cache_dir="./cache", + remove_erhua=False, + remove_interjections=False, + remove_puncts=True, + overwrite_cache=True, +) + + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + +logger = logging.getLogger("token2wav_asr_server") + + +class _ASR_Server: + """Wraps a single OmniSenseVoiceSmall model instance for Triton.""" + + def __init__(self, device_id: int): + self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id) + + @batch + def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray): + """ + WAV: np.ndarray, WAV_LENS: np.ndarray + LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used + See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu + """ + logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape) + wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))] + + results = self._model.transcribe_single_batch( + wavs, + language="zh", + textnorm="woitn", + ) + texts = [result.text for result in results] + transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") + return {"TRANSCRIPTS": transcripts} + + +def audio_decode_cosyvoice2( + audio_tokens, prompt_text, prompt_speech_16k, codec_decoder +): + """ + Generate audio from tokens with optional tone and prompt embedding. + """ + model_inputs_dict = codec_decoder.frontend.frontend_zero_shot( + "empty", prompt_text, prompt_speech_16k, 24000 + ) + tts_mel, _ = codec_decoder.model.flow.inference( + token=audio_tokens.to(codec_decoder.model.device), + token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( + codec_decoder.model.device + ), + prompt_token=model_inputs_dict["flow_prompt_speech_token"].to( + codec_decoder.model.device + ), + prompt_token_len=torch.tensor( + [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32 + ).to(codec_decoder.model.device), + prompt_feat=model_inputs_dict["prompt_speech_feat"].to( + codec_decoder.model.device + ), + prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to( + codec_decoder.model.device + ), + embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device), + finalize=True, + ) + + audio_hat, _ = codec_decoder.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + return audio_hat + + +def get_random_prompt_from_dataset(dataset): + """ + Get random prompt text and speech from the pre-loaded dataset. + Returns (prompt_text, prompt_speech_16k) + """ + random_idx = random.randint(0, len(dataset) - 1) + sample = dataset[random_idx] + + # Extract audio data + audio_data = sample["audio"] + audio_array = audio_data["array"] + sample_rate = audio_data["sampling_rate"] + + # Convert audio to 16kHz if needed + if sample_rate != 16000: + num_samples = int(len(audio_array) * (16000 / sample_rate)) + audio_array = resample(audio_array, num_samples) + + # Convert to torch tensor + prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0) + prompt_text = sample["text"] + # remove space in prompt_text + prompt_text = prompt_text.replace(" ", "") + return prompt_text, prompt_speech_16k + + +class _Token2Wav_ASR: + """Wraps a single OmniSenseVoiceSmall model instance for Triton.""" + + def __init__(self, device_id: int): + self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id) + self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"] + + # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model + # CosyVoice2 internally uses generic "cuda" device, so we first switch the + # current CUDA context to the desired card before the object is created. + # Afterwards, all parameters loaded with the generic "cuda" device will + # reside on this GPU. We keep the selected id in `self.device_id` and + # will set the context again for every forward call to avoid race + # conditions when several instances are used in the same process. + + self.device_id = device_id + + # Construct the TTS codec decoder under the correct CUDA device context + with torch.cuda.device(self.device_id): + self.codec_decoder = CosyVoice2( + "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True + ) + + @batch + def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray): + """ + WAV: np.ndarray, WAV_LENS: np.ndarray + LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used + See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu + """ + # Ensure the default CUDA device is set correctly for this invocation + torch.cuda.set_device(self.device_id) + + if self.device_id == 0: + print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}") + + tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))] + + # Decode ground-truth text strings (BYTES → str) + if GT_TEXT.ndim == 2: + gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))] + else: + gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))] + + wavs = [] + for tokens in tokens_list: + prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset) + audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0) + audio_hat = audio_decode_cosyvoice2( + audio_tokens, + prompt_text, + prompt_speech_16k, + self.codec_decoder, + ) + # resample to 16000 using soundfile + audio_hat = audio_hat.squeeze(0).float().cpu() + audio_hat = audio_hat.numpy() + num_samples = int(len(audio_hat) * (16000 / 24000)) + audio_hat = resample(audio_hat, num_samples) + wavs.append(audio_hat) + + results = self.asr_model.transcribe_single_batch( + wavs, + language="zh", + textnorm="woitn", + ) + texts = [result.text for result in results] + + # ---------------- Reward computation ---------------- + rewards = [] + for gt_text, hyp_text in zip(gt_texts, texts): + gt_norm = zh_tn_model.normalize(gt_text).lower() + hyp_norm = zh_tn_model.normalize(hyp_text).lower() + + gt_pinyin = lazy_pinyin( + gt_norm, + style=Style.TONE3, + tone_sandhi=True, + neutral_tone_with_five=True, + ) + hyp_pinyin = lazy_pinyin( + hyp_norm, + style=Style.TONE3, + tone_sandhi=True, + neutral_tone_with_five=True, + ) + + c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin))) + reward_val = 1.0 - np.tanh(3.0 * c) + reward_val = max(0.0, min(1.0, reward_val)) + rewards.append(reward_val) + print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}") + + transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") + rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1) + + return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts} + + +def _infer_function_factory(device_ids: List[int], model_name: str): + """Creates a list of inference functions, one for each requested device ID.""" + infer_funcs = [] + for device_id in device_ids: + if model_name == "sensevoice": + infer_funcs.append(_ASR_Server(device_id=device_id)) + else: + infer_funcs.append(_Token2Wav_ASR(device_id=device_id)) + return infer_funcs + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--max-batch-size", + type=int, + default=32, + help="Batch size of request.", + required=False, + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + ) + parser.add_argument( + "--number-of-instances-per-device", + type=int, + default=1, + help="Number of model instances to load.", + required=False, + ) + parser.add_argument( + "--number-of-devices", + type=int, + default=8, + help="Number of devices to use.", + ) + parser.add_argument( + "--model-name", + type=str, + default="token2wav_asr", + choices=["token2wav_asr", "sensevoice"], + help="Model name.", + ) + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s") + + triton_config = TritonConfig( + http_port=8000, + grpc_port=8001, + metrics_port=8002, + ) + + device_ids = [i for i in range(args.number_of_devices)] + device_ids = device_ids * args.number_of_instances_per_device + + with Triton(config=triton_config) as triton: + logger.info("Loading SenseVoice model on device ids: %s", device_ids) + if args.model_name == "sensevoice": + triton.bind( + model_name="sensevoice", + infer_func=_infer_function_factory(device_ids, args.model_name), + inputs=[ + Tensor(name="WAV", dtype=np.float32, shape=(-1,)), + Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)), + Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)), + Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)), + ], + outputs=[ + Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)), + ], + config=ModelConfig( + max_batch_size=args.max_batch_size, + batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms + ), + strict=True, + ) + else: + triton.bind( + model_name="token2wav_asr", + infer_func=_infer_function_factory(device_ids, args.model_name), + inputs=[ + Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)), + Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)), + Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)), + ], + outputs=[ + Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)), + Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)), + ], + config=ModelConfig( + max_batch_size=args.max_batch_size, + batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms + ), + strict=True, + ) + logger.info("Serving inference") + triton.serve() + + +if __name__ == "__main__": + main() diff --git a/runtime/python/Dockerfile b/runtime/python/Dockerfile index ae7e01f..e2d2012 100644 --- a/runtime/python/Dockerfile +++ b/runtime/python/Dockerfile @@ -9,5 +9,5 @@ RUN apt-get -y install git unzip git-lfs g++ RUN git lfs install RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git # here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed -RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com +RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto \ No newline at end of file diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md index b1e091c..3765038 100644 --- a/runtime/triton_trtllm/README.md +++ b/runtime/triton_trtllm/README.md @@ -1,15 +1,17 @@ -## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server +## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM -Thanks to the contribution from NVIDIA Yuekai Zhang. +Contributed by Yuekai Zhang (NVIDIA). ### Quick Start + Launch the service directly with Docker Compose: ```sh docker compose up ``` ### Build the Docker Image -Build the image from scratch: + +To build the image from scratch: ```sh docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06 ``` @@ -21,71 +23,124 @@ docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_di ``` ### Understanding `run.sh` + The `run.sh` script orchestrates the entire workflow through numbered stages. -Run a subset of stages with: +You can run a subset of stages with: ```sh bash run.sh [service_type] ``` -- `` – stage to start from (0-5). -- `` – stage to stop after (0-5). +- ``: The stage to start from (0-5). +- ``: The stage to stop after (0-5). -Stages: -- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace. -- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines. -- **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later). -- **Stage 3** – Launch the Triton Inference Server. -- **Stage 4** – Run the single-utterance HTTP client. -- **Stage 5** – Run the gRPC benchmark client. +**Stages:** + +- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace. +- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines. +- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used. +- **Stage 3**: Launches the Triton Inference Server. +- **Stage 4**: Runs the single-utterance HTTP client for testing. +- **Stage 5**: Runs the gRPC benchmark client. +- **Stage 6**: Runs the offline inference benchmark test. + +### Export Models and Launch Server -### Export Models to TensorRT-LLM and Launch the Server Inside the Docker container, prepare the models and start the Triton server by running stages 0-3: ```sh -# Runs stages 0, 1, 2, and 3 +# This command runs stages 0, 1, 2, and 3 bash run.sh 0 3 ``` -*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.* +> [!TIP] +> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes. ### Single-Utterance HTTP Client -Send a single HTTP inference request: + +Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`): ```sh bash run.sh 4 4 ``` -### Benchmark with a Dataset -Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument. -```sh -bash run.sh 5 5 +### Benchmark with client-server mode -# You can also customise parameters such as num_task and dataset split directly: +To benchmark the running Triton server, pass `streaming` or `offline` as the third argument: +```sh +bash run.sh 5 5 # [streaming|offline] + +# You can also customize parameters such as the number of tasks and the dataset split: # python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline] ``` > [!TIP] -> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready. +> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up. -### Benchmark Results -Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio): - -| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | -|------|------|-------------|------------------|------------------|-----| -| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 | -| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 | -| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 | -| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 | -| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 | -| Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 | - -### OpenAI-Compatible Server -To launch an OpenAI-compatible service, run: +### Benchmark with offline inference mode +For offline inference mode benchmark, please check the below command: ```sh -git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git -pip install -r requirements.txt -# After the Triton service is up, start the FastAPI bridge: -python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000 -# Test with curl -bash test/test_cosyvoice.sh +# install FlashCosyVoice for token2wav batching +# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt +# cd /workspace/FlashCosyVoice +# pip install -e . +# cd - +# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx + +bash run.sh 6 6 + +# You can also switch to huggingface backend by setting backend=hf ``` -### Acknowledgements -This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details. + +### Benchmark Results +The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio): + +**Client-Server Mode: Streaming TTS (First Chunk Latency)** +| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | +|---|---|---|---|---| +| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 | +| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 | +| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 | +| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 | +| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 | +| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 | + +> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E). + +**Client-Server Mode: Offline TTS (Full Sentence Latency)** +| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | +|---|---|---|---|---|---| +| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 | +| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 | +| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 | + +**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM** +| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF | +|---------|------------|------------------|-----------------------|--| +| HF | 1 | 39.26 | 44.31 | 0.2494 | +| HF | 2 | 30.54 | 35.62 | 0.2064 | +| HF | 4 | 18.63 | 23.90 | 0.1421 | +| HF | 8 | 11.22 | 16.45 | 0.0947 | +| HF | 16 | 8.42 | 13.78 | 0.0821 | +| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 | +| TRTLLM | 2 | 7.64 |12.65 | 0.0739 | +| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 | +| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 | +| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 | +### OpenAI-Compatible Server + +To launch an OpenAI-compatible API service, run the following commands: +```sh +git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git +cd Triton-OpenAI-Speech +pip install -r requirements.txt + +# After the Triton service is running, start the FastAPI bridge: +python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000 + +# Test the service with curl: +bash test/test_cosyvoice.sh +``` +> [!NOTE] +> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server. + +### Acknowledgements + +This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub). diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 881b519..994a401 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -257,7 +257,13 @@ def get_args(): default=0.1, help="Chunk overlap duration for streaming reconstruction (in seconds)." ) - # --- End Added arguments --- + + parser.add_argument( + "--use-spk2info-cache", + type=bool, + default=False, + help="Use spk2info cache for reference audio.", + ) return parser.parse_args() @@ -283,7 +289,8 @@ def prepare_request_input_output( reference_text, target_text, sample_rate=16000, - padding_duration: int = None # Optional padding for offline mode + padding_duration: int = None, # Optional padding for offline mode + use_spk2info_cache: bool = False ): """Prepares inputs for Triton inference (offline or streaming).""" assert len(waveform.shape) == 1, "waveform should be 1D" @@ -330,7 +337,8 @@ def prepare_request_input_output( inputs[3].set_data_from_numpy(input_data_numpy) outputs = [protocol_client.InferRequestedOutput("waveform")] - + if use_spk2info_cache: + inputs = inputs[-1:] return inputs, outputs @@ -395,38 +403,45 @@ def run_sync_streaming_inference( # Reconstruct audio using cross-fade (from client_grpc_streaming.py) actual_duration = 0 if audios: - cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) - fade_out = np.linspace(1, 0, cross_fade_samples) - fade_in = np.linspace(0, 1, cross_fade_samples) - reconstructed_audio = None + # Only spark_tts model uses cross-fade + if model_name == "spark_tts": + cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + reconstructed_audio = None - # Simplified reconstruction based on client_grpc_streaming.py - if not audios: - print("Warning: No audio chunks received.") - reconstructed_audio = np.array([], dtype=np.float32) # Empty array - elif len(audios) == 1: - reconstructed_audio = audios[0] + # Simplified reconstruction based on client_grpc_streaming.py + if not audios: + print("Warning: No audio chunks received.") + reconstructed_audio = np.array([], dtype=np.float32) # Empty array + elif len(audios) == 1: + reconstructed_audio = audios[0] + else: + reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + for i in range(1, len(audios)): + # Cross-fade section + cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + audios[i - 1][-cross_fade_samples:] * fade_out) + # Middle section of the current chunk + middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # Concatenate + reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # Add the last part of the final chunk + reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) + + if reconstructed_audio is not None and reconstructed_audio.size > 0: + actual_duration = len(reconstructed_audio) / save_sample_rate + # Save reconstructed audio + sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") + else: + print("Warning: No audio chunks received or reconstructed.") + actual_duration = 0 # Set duration to 0 if no audio else: - reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap - for i in range(1, len(audios)): - # Cross-fade section - cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + - audios[i - 1][-cross_fade_samples:] * fade_out) - # Middle section of the current chunk - middle_part = audios[i][cross_fade_samples:-cross_fade_samples] - # Concatenate - reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) - # Add the last part of the final chunk - reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) - - if reconstructed_audio is not None and reconstructed_audio.size > 0: + reconstructed_audio = np.concatenate(audios) + print(f"reconstructed_audio: {reconstructed_audio.shape}") actual_duration = len(reconstructed_audio) / save_sample_rate # Save reconstructed audio - os.makedirs(os.path.dirname(audio_save_path), exist_ok=True) sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") - else: - print("Warning: No audio chunks received or reconstructed.") - actual_duration = 0 # Set duration to 0 if no audio else: print("Warning: No audio chunks received.") @@ -446,6 +461,7 @@ async def send_streaming( save_sample_rate: int = 16000, chunk_overlap_duration: float = 0.1, padding_duration: int = None, + use_spk2info_cache: bool = False, ): total_duration = 0.0 latency_data = [] @@ -471,7 +487,8 @@ async def send_streaming( reference_text, target_text, sample_rate, - padding_duration=padding_duration + padding_duration=padding_duration, + use_spk2info_cache=use_spk2info_cache ) request_id = str(uuid.uuid4()) user_data = UserData() @@ -527,6 +544,7 @@ async def send( padding_duration: int = None, audio_save_dir: str = "./", save_sample_rate: int = 16000, + use_spk2info_cache: bool = False, ): total_duration = 0.0 latency_data = [] @@ -545,7 +563,8 @@ async def send( reference_text, target_text, sample_rate, - padding_duration=padding_duration + padding_duration=padding_duration, + use_spk2info_cache=use_spk2info_cache ) sequence_id = 100000000 + i + task_id * 10 start = time.time() @@ -667,6 +686,7 @@ async def main(): manifest_item_list = split_data(manifest_item_list, num_tasks) os.makedirs(args.log_dir, exist_ok=True) + tasks = [] start_time = time.time() for i in range(num_tasks): @@ -683,6 +703,7 @@ async def main(): audio_save_dir=args.log_dir, padding_duration=1, save_sample_rate=16000 if args.model_name == "spark_tts" else 24000, + use_spk2info_cache=args.use_spk2info_cache, ) ) elif args.mode == "streaming": @@ -698,6 +719,7 @@ async def main(): padding_duration=10, save_sample_rate=16000 if args.model_name == "spark_tts" else 24000, chunk_overlap_duration=args.chunk_overlap_duration, + use_spk2info_cache=args.use_spk2info_cache, ) ) # --- End Task Creation --- diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py index 47383e2..339cb0e 100644 --- a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -32,7 +32,7 @@ import triton_python_backend_utils as pb_utils import os import numpy as np import s3tokenizer - +torch.set_num_threads(1) ORIGINAL_VOCAB_SIZE = 151663 diff --git a/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt b/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt index 6d8bd0c..4bb972c 100644 --- a/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt @@ -20,7 +20,7 @@ dynamic_batching { } parameters [ { - key: "model_dir", + key: "model_dir", value: {string_value:"${model_dir}"} } ] diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index 77a440b..97659ad 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -28,6 +28,8 @@ import json import math import os import re +import threading +import time from typing import Dict, List, Tuple, Optional, Union import numpy as np @@ -35,13 +37,15 @@ import torch from torch.utils.dlpack import from_dlpack, to_dlpack import triton_python_backend_utils as pb_utils from transformers import AutoTokenizer -import torchaudio.compliance.kaldi as kaldi + import torchaudio -import onnxruntime from matcha.utils.audio import mel_spectrogram +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + class TritonPythonModel: """Triton Python model for Spark TTS. @@ -62,6 +66,8 @@ class TritonPythonModel: parameters = self.model_config['parameters'] model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") + self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer llm_tokenizer_dir = model_params["llm_tokenizer_dir"] @@ -72,11 +78,15 @@ class TritonPythonModel: self.device = torch.device("cuda") self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) - campplus_model = f'{model_params["model_dir"]}/campplus.onnx' - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) + self.token_frame_rate = 25 + self.flow_pre_lookahead_len = 3 + self.token_hop_len = 15 + + spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + self.default_spk_info = spk_info["001"] def forward_llm(self, input_ids): """ @@ -105,7 +115,7 @@ class TritonPythonModel: """ # convert input_ids to numpy, with shape [1, sequence_length] input_ids = input_ids.cpu().numpy() - max_tokens = 1024 + max_tokens = 750 input_dict = { "request_output_len": np.array([[max_tokens]], dtype=np.int32), "end_id": np.array([[self.eos_token_id]], dtype=np.int32), @@ -114,6 +124,8 @@ class TritonPythonModel: "runtime_top_p": np.array([[0.95]], dtype=np.float32), "runtime_top_k": np.array([[50]], dtype=np.int32), "temperature": np.array([[0.8]], dtype=np.float32), + "repetition_penalty": np.array([[1.1]], dtype=np.float32), + "random_seed": np.array([[42]], dtype=np.uint64), "input_ids": input_ids, "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), } @@ -188,12 +200,40 @@ class TritonPythonModel: return prompt_speech_tokens + def forward_speaker_embedding(self, wav): + """Forward pass through the speaker embedding component. + + Args: + wav: Input waveform tensor + + Returns: + Prompt speaker embedding tensor + """ + inference_request = pb_utils.InferenceRequest( + model_name='speaker_embedding', + requested_output_names=['prompt_spk_embedding'], + inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') + prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) + + return prompt_spk_embedding + def forward_token2wav( self, - prompt_speech_tokens: torch.Tensor, - prompt_speech_feat: torch.Tensor, - prompt_spk_embedding: torch.Tensor, - target_speech_tokens: torch.Tensor) -> torch.Tensor: + target_speech_tokens: torch.Tensor, + request_id: str, + prompt_speech_tokens: torch.Tensor = None, + prompt_speech_feat: torch.Tensor = None, + prompt_spk_embedding: torch.Tensor = None, + token_offset: int = None, + finalize: bool = None) -> torch.Tensor: """Forward pass through the vocoder component. Args: @@ -205,16 +245,30 @@ class TritonPythonModel: Returns: Generated waveform tensor """ - prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) - prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) - prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + inputs_tensor = [target_speech_tokens_tensor] + + if token_offset is not None: + assert finalize is not None + token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor.append(token_offset_tensor) + inputs_tensor.append(finalize_tensor) + + if prompt_spk_embedding is not None: + assert prompt_speech_feat is not None + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) + prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) + inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor]) + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav', requested_output_names=['waveform'], - inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] + inputs=inputs_tensor, + request_id=request_id, ) inference_response = inference_request.exec() @@ -235,17 +289,6 @@ class TritonPythonModel: input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) return input_ids - def _extract_spk_embedding(self, speech): - feat = kaldi.fbank(speech, - num_mel_bins=80, - dither=0, - sample_frequency=16000) - feat = feat - feat.mean(dim=0, keepdim=True) - embedding = self.campplus_session.run(None, - {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() - embedding = torch.tensor([embedding]).to(self.device).half() - return embedding - def _extract_speech_feat(self, speech): speech_feat = mel_spectrogram( speech, @@ -263,6 +306,14 @@ class TritonPythonModel: speech_feat = speech_feat.unsqueeze(dim=0) return speech_feat + def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag): + for generated_ids in generated_ids_iter: + generated_ids = generated_ids.tolist() + if len(generated_ids) == 0: + break + semantic_token_ids_arr.extend(generated_ids) + llm_is_done_flag[0] = True + def execute(self, requests): """Execute inference on the batched requests. @@ -275,25 +326,33 @@ class TritonPythonModel: responses = [] for request in requests: + request_id = request.request_id() # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) - prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - wav_tensor = wav.as_numpy() - wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] - prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) - speech_feat = self._extract_speech_feat(prompt_speech_resample) - token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) - prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() - prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - - reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + else: + # using pre-cached reference text + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') @@ -310,22 +369,73 @@ class TritonPythonModel: if self.decoupled: response_sender = request.get_response_sender() - request_id = request.request_id() - generated_ids = [] - for generated_id in generated_ids_iter: - # convert the numpy array into a int32 tensor - generated_id = generated_id.tolist() - if len(generated_id) > 0: - assert len(generated_id) == 1, "Generated ID is not a single integer" - generated_ids.append(generated_id[0]) - generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device) - prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) - audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) - # Prepare response - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + semantic_token_ids_arr = [] + llm_is_done_flag = [False] + + llm_thread = threading.Thread( + target=self._llm_gen_thread, + args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag) + ) + + llm_thread.start() + + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + + while True: + pending_num = len(semantic_token_ids_arr) - token_offset + + if llm_is_done_flag[0]: + break + + if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: + this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) + + sub_tts_speech = self.forward_token2wav( + this_tts_speech_token, request_id, prompt_speech_tokens, + prompt_speech_feat, prompt_spk_embedding, token_offset, False + ) + + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + token_offset += this_token_hop_len + self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") + + if self.dynamic_chunk_strategy == "exponential": + this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "time_based": + # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 + cost_time = time.time() - start_time + duration = token_offset / self.token_frame_rate + if chunk_index > 0 and cost_time > 0: + avg_chunk_processing_time = cost_time / (chunk_index + 1) + if avg_chunk_processing_time > 0: + multiples = (duration - cost_time) / avg_chunk_processing_time + self.logger.log_info(f"multiples: {multiples}") + next_pending_num = len(semantic_token_ids_arr) - token_offset + if multiples > 4: + this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len + elif multiples > 2: + this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len + else: + this_token_hop_len = self.token_hop_len + this_token_hop_len = max(self.token_hop_len, this_token_hop_len) + chunk_index += 1 + else: + time.sleep(0.02) + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) + + llm_thread.join() response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: @@ -334,8 +444,7 @@ class TritonPythonModel: if generated_ids is None or len(generated_ids) == 0: raise pb_utils.TritonModelException("Generated IDs is None or empty") - prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) - audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) + audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt index c370336..73a9a05 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt @@ -23,11 +23,11 @@ model_transaction_policy { } parameters [ { - key: "llm_tokenizer_dir", + key: "llm_tokenizer_dir", value: {string_value:"${llm_tokenizer_dir}"} }, { - key: "model_dir", + key: "model_dir", value: {string_value:"${model_dir}"} } ] @@ -37,16 +37,19 @@ input [ name: "reference_wav" data_type: TYPE_FP32 dims: [-1] + optional: true }, { name: "reference_wav_len" data_type: TYPE_INT32 dims: [1] + optional: true }, { name: "reference_text" data_type: TYPE_STRING dims: [1] + optional: true }, { name: "target_text" diff --git a/runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py b/runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py new file mode 100644 index 0000000..1a7293a --- /dev/null +++ b/runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py @@ -0,0 +1,153 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +import os +import numpy as np +import torchaudio.compliance.kaldi as kaldi +from cosyvoice.utils.file_utils import convert_onnx_to_trt +from cosyvoice.utils.common import TrtContextWrapper +import onnxruntime + + +class TritonPythonModel: + """Triton Python model for audio tokenization. + + This model takes reference audio input and extracts semantic tokens + using s3tokenizer. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + + self.device = torch.device("cuda") + + model_dir = model_params["model_dir"] + gpu = "l20" + enable_trt = True + if enable_trt: + self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False) + else: + campplus_model = f'{model_dir}/campplus.onnx' + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"]) + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def _extract_spk_embedding(self, speech): + feat = kaldi.fbank(speech, + num_mel_bins=80, + dither=0, + sample_frequency=16000) + spk_feat = feat - feat.mean(dim=0, keepdim=True) + + if isinstance(self.spk_model, onnxruntime.InferenceSession): + embedding = self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + embedding = torch.tensor([embedding]).to(self.device) + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + embedding = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + embedding.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return embedding.half() + + def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing tokenized outputs + """ + responses = [] + # Process each request in batch + for request in requests: + # Extract input tensors + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_array = torch.from_numpy(wav_array).to(self.device) + + embedding = self._extract_spk_embedding(wav_array) + + prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack( + "prompt_spk_embedding", to_dlpack(embedding)) + inference_response = pb_utils.InferenceResponse( + output_tensors=[prompt_spk_embedding_tensor]) + + responses.append(inference_response) + + return responses diff --git a/runtime/triton_trtllm/model_repo/speaker_embedding/config.pbtxt b/runtime/triton_trtllm/model_repo/speaker_embedding/config.pbtxt new file mode 100644 index 0000000..fd91bd6 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/speaker_embedding/config.pbtxt @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "speaker_embedding" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + } +] +output [ + { + name: "prompt_spk_embedding" + data_type: TYPE_FP16 + dims: [-1] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index d38f8a4..10bc272 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -32,22 +32,27 @@ from typing import List, Dict import torch from torch.utils.dlpack import to_dlpack +from torch.nn import functional as F import triton_python_backend_utils as pb_utils from hyperpyyaml import load_hyperpyyaml +from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm from cosyvoice.utils.common import TrtContextWrapper +from collections import defaultdict +import numpy as np logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) class CosyVoice2: - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'): self.model_dir = model_dir self.fp16 = fp16 @@ -57,7 +62,7 @@ class CosyVoice2: raise ValueError('{} not found!'.format(hyper_yaml_path)) with open(hyper_yaml_path, 'r') as f: configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) - self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16) + self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device) self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) @@ -73,14 +78,22 @@ class CosyVoice2Model: def __init__(self, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool = False): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + fp16: bool = False, + device: str = 'cuda'): + self.device = device self.flow = flow self.hift = hift self.fp16 = fp16 if self.fp16 is True: self.flow.half() + # streaming tts config + self.token_hop_len = 25 + self.mel_cache_len = 8 + self.source_cache_len = int(self.mel_cache_len * 480) + self.speech_window = np.hamming(2 * self.source_cache_len) + self.hift_cache_dict = defaultdict(lambda: None) + def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder @@ -111,6 +124,42 @@ class CosyVoice2Model: input_names = ["x", "mask", "mu", "cond"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): + with torch.cuda.amp.autocast(self.fp16): + tts_mel, _ = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + streaming=stream, + finalize=finalize) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] + # append hift cache + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + # keep overlap mel and hift cache + if finalize is False: + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], + 'source': tts_source[:, :, -self.source_cache_len:], + 'speech': tts_speech[:, -self.source_cache_len:]} + tts_speech = tts_speech[:, :-self.source_cache_len] + else: + if speed != 1.0: + assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' + tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + return tts_speech + class TritonPythonModel: """Triton Python model for vocoder. @@ -131,13 +180,19 @@ class TritonPythonModel: model_dir = model_params["model_dir"] # Initialize device and vocoder - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") logger.info(f"Initializing vocoder from {model_dir} on {self.device}") self.token2wav_model = CosyVoice2( - model_dir, load_jit=True, load_trt=True, fp16=True + model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device ) + spk_info_path = os.path.join(model_dir, "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_dir}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + self.default_spk_info = spk_info["001"] + logger.info("Token2Wav initialized successfully") def execute(self, requests): @@ -153,38 +208,66 @@ class TritonPythonModel: # Process each request in batch for request in requests: target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() - prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy() - prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() - prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() - target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) - prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) - prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) - prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) + + prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens") + if prompt_speech_tokens_tensor is not None: + prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy() + prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() + prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() + prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) + prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) + prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) + prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE + else: + prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device) + prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device) + prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device) # shift the speech tokens according to the original vocab size - prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE - tts_mel, _ = self.token2wav_model.model.flow.inference( - token=target_speech_tokens, - token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( - self.device - ), - prompt_token=prompt_speech_tokens, - prompt_token_len=torch.tensor( - [prompt_speech_tokens.shape[1]], dtype=torch.int32 - ).to(self.device), - prompt_feat=prompt_speech_feat, - prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=prompt_spk_embedding, - streaming=False, - finalize=True, - ) + # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. + token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset") + if token_offset is not None: + token_offset = token_offset.as_numpy().item() + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() + if not finalize: + stream = True + else: + stream = False + request_id = request.request_id() + audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, + prompt_token=prompt_speech_tokens, + prompt_feat=prompt_speech_feat, + embedding=prompt_spk_embedding, + token_offset=token_offset, + uuid=request_id, + stream=stream, + finalize=finalize) + if finalize: + self.token2wav_model.model.hift_cache_dict.pop(request_id) - audio_hat, _ = self.token2wav_model.model.hift.inference( - speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) - ) + else: + tts_mel, _ = self.token2wav_model.model.flow.inference( + token=target_speech_tokens, + token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_token=prompt_speech_tokens, + prompt_token_len=torch.tensor( + [prompt_speech_tokens.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_speech_feat, + prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=prompt_spk_embedding, + streaming=False, + finalize=True, + ) + + audio_hat, _ = self.token2wav_model.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) generated_wave = audio_hat.squeeze(0).cpu().numpy() diff --git a/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt index 36489ff..c33a85f 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt @@ -20,7 +20,7 @@ dynamic_batching { } parameters [ { - key: "model_dir", + key: "model_dir", value: {string_value:"${model_dir}"} } ] @@ -35,16 +35,33 @@ input [ name: "prompt_speech_tokens" data_type: TYPE_INT32 dims: [-1] + optional: true }, { name: "prompt_speech_feat" data_type: TYPE_FP16 dims: [-1, 80] + optional: true }, { name: "prompt_spk_embedding" data_type: TYPE_FP16 dims: [-1] + optional: true + }, + { + name: "token_offset" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "finalize" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true } ] output [ diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py new file mode 100644 index 0000000..6f1a836 --- /dev/null +++ b/runtime/triton_trtllm/offline_inference.py @@ -0,0 +1,563 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 offline_inference.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $huggingface_model_local_dir \ + --token2wav-path $model_scope_model_local_dir \ + --backend $backend \ + --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ + --engine-dir $trt_engines_dir \ + --split-name ${dataset} || exit 1 +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from cosyvoice.utils.file_utils import load_wav +from datasets import load_dataset +from transformers import AutoTokenizer +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +import soundfile as sf +import s3tokenizer +from functools import partial +import time + +from token2wav import CosyVoice2_Token2Wav + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") +try: + torch.multiprocessing.set_start_method("spawn") +except RuntimeError: + pass + + +def extract_speech_ids(speech_tokens_str): + """Extract speech IDs from token strings like <|s_23456|>""" + speech_ids = [] + for token_str in speech_tokens_str: + if token_str.startswith('<|s_') and token_str.endswith('|>'): + num_str = token_str[4:-2] + num = int(num_str) + speech_ids.append(num) + else: + print(f"Unexpected token: {token_str}") + return speech_ids + + +def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens): + """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>""" + speech_id_str = "" + for token in cosy2_tokens: + speech_id_str += f"<|s_{token}|>" + return speech_id_str + + +def get_args(): + parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2") + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2", + ) + parser.add_argument( + "--output-dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch-size", + default=1, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--token2wav-batch-size", + default=1, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num-workers", type=int, default=0, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=None, help="prefetch for dataloader" + ) + parser.add_argument( + "--llm-model-name-or-path", + required=True, + type=str, + help="LLM model path (includes both model and tokenizer)", + ) + parser.add_argument( + "--token2wav-path", + required=True, + type=str, + help="CosyVoice2 token2wav model path", + ) + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="The prompt text for CosyVoice2", + ) + parser.add_argument( + "--prompt-speech-path", + type=str, + default=None, + help="The path to the prompt speech for CosyVoice2", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + parser.add_argument( + "--backend", + type=str, + default="hf", + choices=["hf", "trtllm", "vllm"], + help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM", + ) + parser.add_argument( + "--engine-dir", + type=str, + default=None, + help="TensorRT-LLM engine directory (required when backend is 'trtllm')", + ) + parser.add_argument( + "--kv-cache-free-gpu-memory-fraction", + type=float, + default=0.6, + help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)", + ) + args = parser.parse_args() + return args + + +def data_collator(batch, tokenizer, s3_tokenizer): + """Simplified data collator for batch_size=1 processing""" + collator_start_time = time.time() + total_audio_processing_time = 0 + total_speech_tokenization_time = 0 + total_text_tokenization_time = 0 + + target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio + device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu") + input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] + prompt_text_after_apply_template_list = [] + mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], [] + for _, item in enumerate(batch): + audio_processing_start_time = time.time() + prompt_text, target_text = ( + item["prompt_text"], + item["target_text"], + ) + prompt_text_list.append(prompt_text) + full_text = prompt_text + target_text + full_text_list.append(full_text) + # remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset + puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\''] + for p in puncts: + if p in full_text: + full_text = full_text.replace(p, '') + print(f"removed {p} from {full_text}") + + # get prompt audio for CosyVoice2 (convert to 16kHz) + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0) + print(ref_audio_org.shape) + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + prompt_audio_list.append(ref_audio) + audio_processing_end_time = time.time() + total_audio_processing_time += audio_processing_end_time - audio_processing_start_time + + speech_tokenization_start_time = time.time() + if "prompt_audio_cosy2_tokens" in item: + prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"] + prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens) + else: + mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0))) + + if len(mels) > 0: + mels, mels_lens = s3tokenizer.padding(mels) + codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device)) + for i in range(len(codes)): + prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()]) + speech_tokenization_end_time = time.time() + total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time + + for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list): + text_tokenization_start_time = time.time() + prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens) + # Create chat template for LLM generation + chat = [ + {"role": "user", "content": full_text_list[i]}, + {"role": "assistant", "content": prompt_audio_cosy2_id_str} + ] + + assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template" + + input_ids = tokenizer.apply_chat_template( + chat, + tokenize=True, + return_tensors='pt', + continue_final_message=True + ) + input_ids_list.append(input_ids.squeeze(0)) + + prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}" + + prompt_text_after_apply_template_list.append(prompt_text_after_apply_template) + text_tokenization_end_time = time.time() + total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time + + ids = [item["id"] for item in batch] + + return { + "input_ids": input_ids_list, + "ids": ids, + "prompt_text": prompt_text_list, + "prompt_audio_list": prompt_audio_list, + "prompt_text_after_apply_template": prompt_text_after_apply_template_list, + "audio_processing_time": total_audio_processing_time, + "speech_tokenization_time": total_speech_tokenization_time, + "text_tokenization_time": total_text_tokenization_time, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def main(args): + os.makedirs(args.output_dir, exist_ok=True) + + assert torch.cuda.is_available() + local_rank, world_size, rank = 0, 1, 0 + device = torch.device(f"cuda:{local_rank}") + + tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path) + + if args.backend == "hf": + model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path) + model.eval() + model.to(device) + runner = None + elif args.backend == "trtllm": + if args.engine_dir is None: + raise ValueError("--engine-dir is required when backend is 'trtllm'") + + runtime_rank = tensorrt_llm.mpi_rank() + model = None + + runner_kwargs = dict( + engine_dir=args.engine_dir, + rank=runtime_rank, + max_output_len=2048, + enable_context_fmha_fp32_acc=False, + max_batch_size=args.batch_size, + max_input_len=512, + kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction, + cuda_graph_mode=False, + gather_generation_logits=False, + ) + + runner = ModelRunnerCpp.from_dir(**runner_kwargs) + elif args.backend == "vllm": + model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4) + runner = None + else: + raise ValueError(f"Unsupported backend: {args.backend}") + + token2wav_model = CosyVoice2_Token2Wav( + model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank + ) + if args.prompt_speech_path: + prompt_speech_16k = load_wav(args.prompt_speech_path, 16000) + else: + prompt_speech_16k = None + s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None + dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2" + dataset = load_dataset( + dataset_name, + split=args.split_name, + trust_remote_code=True, + ) + + sampler = None + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer), + ) + for _ in range(3): + print(f"Running {_} times") + total_llm_time = 0 + total_token2wav_time = 0 + total_data_load_time = 0 + total_llm_post_processing_time = 0 + total_audio_save_time = 0 + total_audio_processing_time_in_collator = 0 + total_speech_tokenization_time_in_collator = 0 + total_text_tokenization_time_in_collator = 0 + total_audio_samples = 0 + start_time = time.time() + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + last_batch_end_time = time.time() + for batch in dataloader: + data_loaded_time = time.time() + total_data_load_time += data_loaded_time - last_batch_end_time + total_audio_processing_time_in_collator += batch["audio_processing_time"] + total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"] + total_text_tokenization_time_in_collator += batch["text_tokenization_time"] + with torch.no_grad(): + llm_start_time = time.time() + if args.backend == "hf": + input_ids_list = batch["input_ids"] + if len(input_ids_list) == 1: + input_ids = input_ids_list[0].unsqueeze(0) + attention_mask = torch.ones_like(input_ids) + else: + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list_new = [ + torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)]) + for input_ids in input_ids_list + ] + input_ids = torch.stack(input_ids_list_new) + attention_mask = torch.zeros_like(input_ids) + for i in range(len(input_ids_list)): + attention_mask[i, :len(input_ids_list[i])] = 1 + + input_ids = input_ids.to(device) + + outputs = model.generate( + input_ids=input_ids.to(device), + attention_mask=attention_mask.to(device), + max_new_tokens=2048, + do_sample=True, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=1.1, + top_k=args.top_k, + ) + torch.cuda.synchronize() + elif args.backend == "trtllm": + batch_input_ids = list(batch["input_ids"]) + input_lengths = [x.size(0) for x in batch_input_ids] + + end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id + print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================") + outputs = runner.generate( + batch_input_ids=batch_input_ids, + max_new_tokens=2048, + end_id=end_id, + pad_id=end_id, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=1.1, + num_return_sequences=1, + streaming=False, + output_sequence_lengths=True, + output_generation_logits=False, + return_dict=True, + return_all_generated_tokens=False + ) + torch.cuda.synchronize() + output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"] + num_output_sents, num_beams, _ = output_ids.size() + assert num_beams == 1 + beam = 0 + batch_size = len(batch["input_ids"]) + num_return_sequences = num_output_sents // batch_size + assert num_return_sequences == 1 + outputs = [] + for i in range(batch_size * num_return_sequences): + batch_idx = i // num_return_sequences + seq_idx = i % num_return_sequences + output_begin = input_lengths[batch_idx] + output_end = sequence_lengths[i][beam] + outputs_i = output_ids[i][beam][:output_end].tolist() + outputs.append(outputs_i) + elif args.backend == "vllm": + input_ids_list = [ids.tolist() for ids in batch["input_ids"]] + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=1.1, + max_tokens=2048, + ) + outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params) + print(outputs) + for j, output in enumerate(outputs): + outputs[j] = input_ids_list[j] + output.outputs[0].token_ids + + llm_end_time = time.time() + total_llm_time += (llm_end_time - llm_start_time) + + items_for_token_2wav = [] + for i in range(len(batch["ids"])): + llm_post_processing_start_time = time.time() + input_length = len(batch["input_ids"][i]) + generated_ids = outputs[i][input_length:] + speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + speech_ids = extract_speech_ids(speech_tokens_str) + print(i, speech_ids) + if len(speech_ids) == 0: + print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping") + continue + + if args.prompt_text is not None: + current_prompt_text = args.prompt_text + current_prompt_audio = prompt_speech_16k + else: + current_prompt_text = batch["prompt_text"][i] + current_prompt_audio = batch["prompt_audio_list"][i] + + llm_post_processing_end_time = time.time() + total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time + if current_prompt_audio is not None: + items_for_token_2wav.append({ + "speech_ids": speech_ids, + "prompt_audio": current_prompt_audio.squeeze(0), + "id": batch["ids"][i] + }) + else: + print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping") + + for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size): + t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size] + if not t2w_batch: + continue + + t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch] + t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch] + t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch) + t2w_ids = [item["id"] for item in t2w_batch] + + token2wav_start_time = time.time() + generated_wavs = token2wav_model( + t2w_generated_speech_tokens_list, + t2w_prompt_audios_list, + t2w_prompt_audios_sample_rate, + ) + torch.cuda.synchronize() + token2wav_end_time = time.time() + total_token2wav_time += (token2wav_end_time - token2wav_start_time) + + audio_save_start_time = time.time() + for j, audio_hat in enumerate(generated_wavs): + generated_wave = audio_hat.squeeze().cpu().numpy() + total_audio_samples += len(generated_wave) + target_sample_rate = 24000 + + utt = t2w_ids[j] + sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate) + print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens") + audio_save_end_time = time.time() + total_audio_save_time += audio_save_end_time - audio_save_start_time + + if rank == 0: + progress_bar.update(world_size * len(batch["ids"])) + + last_batch_end_time = time.time() + if rank == 0: + progress_bar.close() + end_time = time.time() + target_sample_rate = 24000 + total_audio_duration_seconds = total_audio_samples / target_sample_rate + + log_file_path = os.path.join(args.output_dir, "log.txt") + with open(log_file_path, 'w') as f: + args_dict = vars(args) + log_data = { + "args": args_dict, + "data_load_time_seconds": total_data_load_time, + "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator, + "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator, + "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator, + "llm_time_seconds": total_llm_time, + "llm_post_processing_time_seconds": total_llm_post_processing_time, + "token2wav_time_seconds": total_token2wav_time, + "audio_save_time_seconds": total_audio_save_time, + "total_audio_duration_seconds": total_audio_duration_seconds, + "pipeline_time_seconds": end_time - start_time, + } + print(log_data) + f.write(json.dumps(log_data, indent=4)) + print(f"Metrics logged to {log_file_path}") + + +if __name__ == "__main__": + args = get_args() + if args.backend == "vllm": + from vllm import LLM, SamplingParams + elif args.backend == "trtllm": + import tensorrt_llm + from tensorrt_llm.runtime import ModelRunnerCpp + elif args.backend == "hf": + from transformers import AutoModelForCausalLM + else: + raise ValueError(f"Unsupported backend: {args.backend}") + main(args) diff --git a/runtime/triton_trtllm/run.sh b/runtime/triton_trtllm/run.sh index 2e81896..7c7f3cd 100644 --- a/runtime/triton_trtllm/run.sh +++ b/runtime/triton_trtllm/run.sh @@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype} model_repo=./model_repo_cosyvoice2 +use_spk2info_cache=False + if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then echo "Cloning CosyVoice" git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path @@ -25,8 +27,11 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then echo "Downloading CosyVoice2-0.5B" + # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir + # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings + wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt fi @@ -57,9 +62,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then cosyvoice2_dir="cosyvoice2" cp -r ./model_repo/${cosyvoice2_dir} $model_repo - cp -r ./model_repo/audio_tokenizer $model_repo cp -r ./model_repo/tensorrt_llm $model_repo cp -r ./model_repo/token2wav $model_repo + if [ $use_spk2info_cache == "False" ]; then + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/speaker_embedding $model_repo + fi ENGINE_PATH=$trt_engines_dir MAX_QUEUE_DELAY_MICROSECONDS=0 @@ -67,13 +75,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then LLM_TOKENIZER_DIR=$huggingface_model_local_dir BLS_INSTANCE_NUM=4 TRITON_MAX_BATCH_SIZE=16 - DECOUPLED_MODE=False + DECOUPLED_MODE=True # True for streaming, False for offline python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} - python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 - + if [ $use_spk2info_cache == "False" ]; then + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then @@ -82,7 +92,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - echo "Single request test http" + echo "Single request test http, only work for offline TTS mode" python3 client_http.py \ --reference-audio ./assets/prompt_audio.wav \ --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ @@ -93,14 +103,40 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then echo "Running benchmark client grpc" num_task=4 - # set mode=streaming, when decoupled=True - # set mode=offline, when decoupled=False - mode=offline + + mode=streaming + BLS_INSTANCE_NUM=4 + python3 client_grpc.py \ --server-addr localhost \ --model-name cosyvoice2 \ --num-tasks $num_task \ --mode $mode \ + --use-spk2info-cache $use_spk2info_cache \ --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype} -fi \ No newline at end of file + --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache} +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "stage 6: Offline inference benchmark" + n_gpus=1 + datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh + backend=trtllm # hf, trtllm, vllm + + batch_sizes=(16 8 4 2 1) + token2wav_batch_size=1 + for batch_size in ${batch_sizes[@]}; do + for dataset in ${datasets[@]}; do + output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} + CUDA_VISIBLE_DEVICES=0 \ + python3 offline_inference.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $huggingface_model_local_dir \ + --token2wav-path $model_scope_model_local_dir \ + --backend $backend \ + --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ + --engine-dir $trt_engines_dir \ + --split-name ${dataset} || exit 1 + done + done +fi diff --git a/runtime/triton_trtllm/token2wav.py b/runtime/triton_trtllm/token2wav.py new file mode 100644 index 0000000..09b6db6 --- /dev/null +++ b/runtime/triton_trtllm/token2wav.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py --enable-trt || exit 1 +""" +import torch +from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec +from flashcosyvoice.modules.hifigan import HiFTGenerator +from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time + + +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + + +class CosyVoice2_Token2Wav(torch.nn.Module): + def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + + self.flow = CausalMaskedDiffWithXvec() + self.flow.half() + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.to(self.device).eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"]) + + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval() + + gpu = "l20" + if enable_trt: + self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + True) + self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False) + + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)] + max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), + (max_batch_size * 2, 80)] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, prompt_mels_lens_for_flow + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, + prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), + streaming=False, finalize=True + ) + + return generated_mels, generated_mels_lens + + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + @torch.inference_mode() + def forward( + self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) + + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) + + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + + generated_mels, generated_mels_lens = self.forward_flow( + prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + + return generated_wavs + + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + for _, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B") + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) + # mkdir output_dir if not exists + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + dataset_name = "yuekai/seed_tts_cosy2" + + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + for _ in range(args.warmup): + start_time = time.time() + + for batch in data_loader: + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch + + generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) + + for id, wav in zip(ids, generated_wavs): + torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) + + end_time = time.time() + epoch_time = end_time - start_time + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")