From cc1991870bdf3b57f80e64507cc071c8d14350da Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 8 Sep 2025 17:37:33 +0800
Subject: [PATCH] add cosyvoice2 offline inference

---
 README.md                                  |  11 +
 runtime/triton_trtllm/README.md            |  38 +-
 runtime/triton_trtllm/offline_inference.py | 605 +++++++++++++++++++++
 runtime/triton_trtllm/run.sh               |  25 +
 runtime/triton_trtllm/token2wav.py         | 336 ++++++++++++
 5 files changed, 1011 insertions(+), 4 deletions(-)
 create mode 100644 runtime/triton_trtllm/offline_inference.py
 create mode 100644 runtime/triton_trtllm/token2wav.py

diff --git a/README.md b/README.md
index 5e3cfd5..214042f 100644
--- a/README.md
+++ b/README.md
@@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
 cd fastapi && python3 client.py --port 50000 --mode 
 ```
 
+#### Using NVIDIA TensorRT-LLM for deployment
+
+Using TensorRT-LLM to accelerate the CosyVoice2 LLM gives roughly a 4x speedup compared with the Hugging Face Transformers implementation.
+To get started quickly:
+
+``` sh
+cd runtime/triton_trtllm
+docker compose up -d
+```
+For more details, see [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm).
+
 ## Discussion & Communication
 
 You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
diff --git a/runtime/triton_trtllm/README.md b/runtime/triton_trtllm/README.md
index 8d4892c..420990e 100644
--- a/runtime/triton_trtllm/README.md
+++ b/runtime/triton_trtllm/README.md
@@ -1,4 +1,4 @@
-## Serving CosyVoice with NVIDIA Triton Inference Server
+## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
 
 Contributed by Yuekai Zhang (NVIDIA).
 
@@ -41,6 +41,7 @@ bash run.sh [service_type]
 - **Stage 3**: Launches the Triton Inference Server.
 - **Stage 4**: Runs the single-utterance HTTP client for testing.
 - **Stage 5**: Runs the gRPC benchmark client.
+- **Stage 6**: Runs the offline inference benchmark.
 
 ### Export Models and Launch Server
 
@@ -59,7 +60,7 @@ Sends a single HTTP inference request. This is intended for testing the offline
 ```sh
 bash run.sh 4 4
 ```
 
-### Benchmark with a Dataset
+### Benchmark in Client-Server Mode
 To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
 
 ```sh
@@ -71,10 +72,26 @@ bash run.sh 5 5 # [streaming|offline]
 > [!TIP]
 > It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
 
+### Benchmark in Offline Inference Mode
+To benchmark the offline inference mode, run the following commands:
+```sh
+# install FlashCosyVoice for token2wav batching
+# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
+# cd /workspace/FlashCosyVoice
+# pip install -e .
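+# (note: FlashCosyVoice provides the batched flow-matching and HiFT modules imported by token2wav.py;
+#  the flow decoder ONNX downloaded below is converted to a TensorRT engine on the first run of stage 6)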
+
+# cd -
+# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
+
+bash run.sh 6 6
+
+# You can also switch to the Hugging Face backend by setting backend=hf in run.sh
+```
+
+
 ### Benchmark Results
 The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
 
-**Streaming TTS (First Chunk Latency)**
+**Client-Server Mode: Streaming TTS (First Chunk Latency)**
 | Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|
 | Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
@@ -86,13 +103,26 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
 > If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
 
-**Offline TTS (Full Sentence Latency)**
+**Client-Server Mode: Offline TTS (Full Sentence Latency)**
 | Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|---|
 | Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
 | Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
 | Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
 
+**Offline Inference Mode: Hugging Face LLM vs. TensorRT-LLM**
+| Backend | Batch Size | LLM Time (s) | Total Time (s) | RTF |
+|---|---|---|---|---|
+| HF | 1 | 39.26 | 44.31 | 0.2494 |
+| HF | 2 | 30.54 | 35.62 | 0.2064 |
+| HF | 4 | 18.63 | 23.90 | 0.1421 |
+| HF | 8 | 11.22 | 16.45 | 0.0947 |
+| HF | 16 | 8.42 | 13.78 | 0.0821 |
+| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
+| TRTLLM | 2 | 7.64 | 12.65 | 0.0739 |
+| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
+| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
+| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
 
 ### OpenAI-Compatible Server
 To launch an OpenAI-compatible API service, run the following commands:
diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py
new file mode 100644
index 0000000..523cd56
--- /dev/null
+++ b/runtime/triton_trtllm/offline_inference.py
@@ -0,0 +1,605 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
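+#
+# Overview: offline batch inference for CosyVoice2. An LLM backend (HuggingFace
+# transformers, TensorRT-LLM, or vLLM) generates speech tokens, which are then
+# converted to waveforms in batches by CosyVoice2_Token2Wav (flow matching + HiFT
+# vocoder, see token2wav.py). Per-stage timings are written to log.txt in --output-dir.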
+""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 offline_inference.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $huggingface_model_local_dir \ + --token2wav-path $model_scope_model_local_dir \ + --backend $backend \ + --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ + --engine-dir $trt_engines_dir \ + --split-name ${dataset} || exit 1 +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from cosyvoice.utils.file_utils import load_wav +from datasets import load_dataset +from transformers import AutoTokenizer +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +import soundfile as sf +import s3tokenizer +from functools import partial +import time + +from token2wav import CosyVoice2_Token2Wav + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") +try: + torch.multiprocessing.set_start_method("spawn") +except RuntimeError: + pass + + +def extract_speech_ids(speech_tokens_str): + """Extract speech IDs from token strings like <|s_23456|>""" + speech_ids = [] + for token_str in speech_tokens_str: + if token_str.startswith('<|s_') and token_str.endswith('|>'): + num_str = token_str[4:-2] + num = int(num_str) + speech_ids.append(num) + else: + print(f"Unexpected token: {token_str}") + return speech_ids + +def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens): + """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>""" + speech_id_str = "" + for token in cosy2_tokens: + speech_id_str += f"<|s_{token}|>" + return speech_id_str + + +def get_args(): + parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2") + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2", + ) + parser.add_argument( + "--output-dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch-size", + default=1, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--token2wav-batch-size", + default=1, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num-workers", type=int, default=0, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=None, help="prefetch for dataloader" + ) + parser.add_argument( + "--llm-model-name-or-path", + required=True, + type=str, + help="LLM model path (includes both model and tokenizer)", + ) + parser.add_argument( + "--token2wav-path", + required=True, + type=str, + help="CosyVoice2 token2wav model path", + ) + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="The prompt text for CosyVoice2", + ) + parser.add_argument( + "--prompt-speech-path", + type=str, + default=None, + help="The path to the prompt speech for CosyVoice2", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + parser.add_argument( + "--backend", + type=str, + default="hf", + choices=["hf", "trtllm", "vllm"], + help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM", + ) + parser.add_argument( + 
"--engine-dir", + type=str, + default=None, + help="TensorRT-LLM engine directory (required when backend is 'trtllm')", + ) + parser.add_argument( + "--kv-cache-free-gpu-memory-fraction", + type=float, + default=0.6, + help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)", + ) + args = parser.parse_args() + return args + + + +def data_collator(batch, tokenizer, s3_tokenizer): + """Simplified data collator for batch_size=1 processing""" + collator_start_time = time.time() + total_audio_processing_time = 0 + total_speech_tokenization_time = 0 + total_text_tokenization_time = 0 + + target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio + device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu") + input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] + prompt_text_after_apply_template_list = [] + mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], [] + for i, item in enumerate(batch): + audio_processing_start_time = time.time() + prompt_text, target_text = ( + item["prompt_text"], + item["target_text"], + ) + prompt_text_list.append(prompt_text) + full_text = prompt_text + target_text + full_text_list.append(full_text) + # remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset + puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\''] + for p in puncts: + if p in full_text: + full_text = full_text.replace(p, '') + print(f"removed {p} from {full_text}") + + # get prompt audio for CosyVoice2 (convert to 16kHz) + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0) + # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True) + print(ref_audio_org.shape) + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + prompt_audio_list.append(ref_audio) + audio_processing_end_time = time.time() + total_audio_processing_time += audio_processing_end_time - audio_processing_start_time + + speech_tokenization_start_time = time.time() + if "prompt_audio_cosy2_tokens" in item: + prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"] + prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens) + else: + # convert to float first + mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0))) + + if len(mels) > 0: + mels, mels_lens = s3tokenizer.padding(mels) + codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device)) + for i in range(len(codes)): + prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()]) + speech_tokenization_end_time = time.time() + total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time + + for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list): + text_tokenization_start_time = time.time() + prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens) + # Create chat template for LLM generation + chat = [ + {"role": "user", "content": full_text_list[i]}, + {"role": "assistant", "content": prompt_audio_cosy2_id_str} + ] + + assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template" + + input_ids = tokenizer.apply_chat_template( + chat, + tokenize=True, + return_tensors='pt', + continue_final_message=True + ) + input_ids_list.append(input_ids.squeeze(0)) + + 
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}" + + prompt_text_after_apply_template_list.append(prompt_text_after_apply_template) + text_tokenization_end_time = time.time() + total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time + + ids = [item["id"] for item in batch] + + return { + "input_ids": input_ids_list, + "ids": ids, + "prompt_text": prompt_text_list, + "prompt_audio_list": prompt_audio_list, + "prompt_text_after_apply_template": prompt_text_after_apply_template_list, + "audio_processing_time": total_audio_processing_time, + "speech_tokenization_time": total_speech_tokenization_time, + "text_tokenization_time": total_text_tokenization_time, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def main(args): + os.makedirs(args.output_dir, exist_ok=True) + + assert torch.cuda.is_available() + # world_size, local_rank, rank = init_distributed() + local_rank, world_size, rank = 0, 1, 0 + device = torch.device(f"cuda:{local_rank}") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path) + + # model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4) + # Initialize backend based on argument + if args.backend == "hf": + # Load HuggingFace model + model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path) + model.eval() + model.to(device) + runner = None + elif args.backend == "trtllm": + # Validate engine_dir is provided + if args.engine_dir is None: + raise ValueError("--engine-dir is required when backend is 'trtllm'") + # import tensorrt_llm + #from tensorrt_llm.runtime import ModelRunnerCpp + + # Initialize TensorRT-LLM runner + runtime_rank = tensorrt_llm.mpi_rank() + model = None + + # Prepare input for runner initialization + runner_kwargs = dict( + engine_dir=args.engine_dir, + rank=runtime_rank, + max_output_len=2048, + enable_context_fmha_fp32_acc=False, + max_batch_size=args.batch_size, + max_input_len=512, + kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction, + cuda_graph_mode=False, + gather_generation_logits=False, + ) + + runner = ModelRunnerCpp.from_dir(**runner_kwargs) + elif args.backend == "vllm": + # from vllm import LLM, SamplingParams + model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4) + runner = None + else: + raise ValueError(f"Unsupported backend: {args.backend}") + + token2wav_model = CosyVoice2_Token2Wav( + model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank + ) + if args.prompt_speech_path: + prompt_speech_16k = load_wav(args.prompt_speech_path, 16000) + else: + prompt_speech_16k = None + s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None + dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2" + dataset = load_dataset( + dataset_name, + split=args.split_name, + trust_remote_code=True, + ) + + # sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + sampler = None + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + 
sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer), + ) + for _ in range(3): + print(f"Running {_} times") + total_llm_time = 0 + total_token2wav_time = 0 + total_data_load_time = 0 + total_llm_post_processing_time = 0 + total_audio_save_time = 0 + total_audio_processing_time_in_collator = 0 + total_speech_tokenization_time_in_collator = 0 + total_text_tokenization_time_in_collator = 0 + total_audio_samples = 0 + start_time = time.time() + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + last_batch_end_time = time.time() + for batch in dataloader: + data_loaded_time = time.time() + total_data_load_time += data_loaded_time - last_batch_end_time + total_audio_processing_time_in_collator += batch["audio_processing_time"] + total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"] + total_text_tokenization_time_in_collator += batch["text_tokenization_time"] + with torch.no_grad(): + # Generate speech tokens using LLM + llm_start_time = time.time() + if args.backend == "hf": + input_ids_list = batch["input_ids"] + if len(input_ids_list) == 1: + input_ids = input_ids_list[0].unsqueeze(0) + attention_mask = torch.ones_like(input_ids) + else: + # Handle batch > 1 if needed + max_len = max([len(input_ids) for input_ids in input_ids_list]) + # input_ids_list_new = [ + # torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids]) + # for input_ids in input_ids_list + # ] + input_ids_list_new = [ + torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)]) + for input_ids in input_ids_list + ] + input_ids = torch.stack(input_ids_list_new) + # compute attention mask + attention_mask = torch.zeros_like(input_ids) + for i in range(len(input_ids_list)): + attention_mask[i, :len(input_ids_list[i])] = 1 + + # breakpoint() + + + input_ids = input_ids.to(device) + + outputs = model.generate( + input_ids=input_ids.to(device), + attention_mask=attention_mask.to(device), + max_new_tokens=2048, # Max length for generation + do_sample=True, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=1.1, + top_k=args.top_k, + ) + torch.cuda.synchronize() + elif args.backend == "trtllm": + # Convert input_ids to list of tensors for TensorRT-LLM + batch_input_ids = [ids for ids in batch["input_ids"]] + input_lengths = [x.size(0) for x in batch_input_ids] + + # Get end_id from tokenizer + end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id + print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================") + # random_seed=42, repetition_penalty=1.1, + outputs = runner.generate( + batch_input_ids=batch_input_ids, + max_new_tokens=2048, + end_id=end_id, + pad_id=end_id, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=1.1, + num_return_sequences=1, + streaming=False, + output_sequence_lengths=True, + output_generation_logits=False, + return_dict=True, + return_all_generated_tokens=False + ) + torch.cuda.synchronize() + # Extract output_ids from TensorRT-LLM output + output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"] + num_output_sents, num_beams, _ = output_ids.size() + assert num_beams == 1 + beam = 0 + batch_size = len(batch["input_ids"]) + 
num_return_sequences = num_output_sents // batch_size + assert num_return_sequences == 1 + outputs = [] + for i in range(batch_size * num_return_sequences): + batch_idx = i // num_return_sequences + seq_idx = i % num_return_sequences + # inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist() + # input_text = tokenizer.decode(inputs) + # print(f'Input [Text {batch_idx}]: \"{input_text}\"') + output_begin = input_lengths[batch_idx] + output_end = sequence_lengths[i][beam] + # outputs_i = output_ids[i][beam][output_begin:output_end].tolist() + outputs_i = output_ids[i][beam][:output_end].tolist() + outputs.append(outputs_i) + elif args.backend == "vllm": + input_ids_list = [ids.tolist() for ids in batch["input_ids"]] + # prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))] + # print(prompts) + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=1.1, + max_tokens=2048, + ) + outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params) + # outputs = model.generate(prompts, sampling_params) + print(outputs) + # breakpoint() + for j, output in enumerate(outputs): + outputs[j] = input_ids_list[j] + output.outputs[0].token_ids + + llm_end_time = time.time() + total_llm_time += (llm_end_time - llm_start_time) + + items_for_token2wav = [] + for i in range(len(batch["ids"])): + llm_post_processing_start_time = time.time() + # Extract generated tokens (excluding input) + input_length = len(batch["input_ids"][i]) + generated_ids = outputs[i][input_length:] # Remove last token if needed + speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + # Extract speech IDs from token strings like <|s_23456|> + speech_ids = extract_speech_ids(speech_tokens_str) + print(i, speech_ids) + # breakpoint() + if len(speech_ids) == 0: + print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping") + continue + + if args.prompt_text is not None: + current_prompt_text = args.prompt_text + current_prompt_audio = prompt_speech_16k + else: + current_prompt_text = batch["prompt_text"][i] + current_prompt_audio = batch["prompt_audio_list"][i] + + llm_post_processing_end_time = time.time() + total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time + if current_prompt_audio is not None: + items_for_token2wav.append({ + "speech_ids": speech_ids, + "prompt_audio": current_prompt_audio.squeeze(0), + "id": batch["ids"][i] + }) + else: + print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping") + + for i in range(0, len(items_for_token2wav), args.token2wav_batch_size): + t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size] + if not t2w_batch: + continue + + t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch] + t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch] + t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch) + t2w_ids = [item["id"] for item in t2w_batch] + + # Generate audio using CosyVoice2 + token2wav_start_time = time.time() + generated_wavs = token2wav_model( + t2w_generated_speech_tokens_list, + t2w_prompt_audios_list, + t2w_prompt_audios_sample_rate, + ) + torch.cuda.synchronize() + token2wav_end_time = time.time() + total_token2wav_time += (token2wav_end_time - token2wav_start_time) + + audio_save_start_time = time.time() + # Convert to numpy and save + for j, audio_hat 
in enumerate(generated_wavs): + generated_wave = audio_hat.squeeze().cpu().numpy() + total_audio_samples += len(generated_wave) + target_sample_rate = 24000 + + utt = t2w_ids[j] + sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate) + print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens") + audio_save_end_time = time.time() + total_audio_save_time += audio_save_end_time - audio_save_start_time + + if rank == 0: + progress_bar.update(world_size * len(batch["ids"])) + + last_batch_end_time = time.time() + if rank == 0: + progress_bar.close() + end_time = time.time() + target_sample_rate = 24000 + total_audio_duration_seconds = total_audio_samples / target_sample_rate + + log_file_path = os.path.join(args.output_dir, "log.txt") + with open(log_file_path, 'w') as f: + # Convert Namespace to dict for JSON serialization + args_dict = vars(args) + log_data = { + "args": args_dict, + "data_load_time_seconds": total_data_load_time, + "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator, + "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator, + "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator, + "llm_time_seconds": total_llm_time, + "llm_post_processing_time_seconds": total_llm_post_processing_time, + "token2wav_time_seconds": total_token2wav_time, + "audio_save_time_seconds": total_audio_save_time, + "total_audio_duration_seconds": total_audio_duration_seconds, + "pipeline_time_seconds": end_time - start_time, + } + print(log_data) + f.write(json.dumps(log_data, indent=4)) + print(f"Metrics logged to {log_file_path}") + + +if __name__ == "__main__": + args = get_args() + if args.backend == "vllm": + from vllm import LLM, SamplingParams + elif args.backend == "trtllm": + import tensorrt_llm + from tensorrt_llm.runtime import ModelRunnerCpp + elif args.backend == "hf": + from transformers import AutoModelForCausalLM + else: + raise ValueError(f"Unsupported backend: {args.backend}") + main(args) \ No newline at end of file diff --git a/runtime/triton_trtllm/run.sh b/runtime/triton_trtllm/run.sh index a60f4a3..7c7f3cd 100644 --- a/runtime/triton_trtllm/run.sh +++ b/runtime/triton_trtllm/run.sh @@ -27,6 +27,7 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then echo "Downloading CosyVoice2-0.5B" + # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings @@ -115,3 +116,27 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --huggingface-dataset yuekai/seed_tts_cosy2 \ --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache} fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "stage 6: Offline inference benchmark" + n_gpus=1 + datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh + backend=trtllm # hf, trtllm, vllm + + batch_sizes=(16 8 4 2 1) + token2wav_batch_size=1 + for batch_size in ${batch_sizes[@]}; do + for dataset in ${datasets[@]}; do + output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} + CUDA_VISIBLE_DEVICES=0 \ + python3 offline_inference.py \ + --output-dir 
$output_dir \ + --llm-model-name-or-path $huggingface_model_local_dir \ + --token2wav-path $model_scope_model_local_dir \ + --backend $backend \ + --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ + --engine-dir $trt_engines_dir \ + --split-name ${dataset} || exit 1 + done + done +fi diff --git a/runtime/triton_trtllm/token2wav.py b/runtime/triton_trtllm/token2wav.py new file mode 100644 index 0000000..786c582 --- /dev/null +++ b/runtime/triton_trtllm/token2wav.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py --enable-trt || exit 1 +""" +import torch +from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec +from flashcosyvoice.modules.hifigan import HiFTGenerator +from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time + + +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for 
_ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + +class CosyVoice2_Token2Wav(torch.nn.Module): + def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + + self.flow = CausalMaskedDiffWithXvec() + self.flow.half() + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.to(self.device).eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) + + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval() + + gpu="l20" + if enable_trt: + self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + True) + self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False) + + + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + import tensorrt as trt + with 
open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + 
prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, prompt_mels_lens_for_flow + + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), + streaming=False, finalize=True + ) + + return generated_mels, generated_mels_lens + + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + + @torch.inference_mode() + def forward( + self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) + + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) + + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + + generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + + return generated_wavs + + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + 
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B") + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") + return parser.parse_args() + +if __name__ == "__main__": + args = get_args() + model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) + # mkdir output_dir if not exists + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + dataset_name = "yuekai/seed_tts_cosy2" + + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + + + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + + for epoch in range(args.warmup): + start_time = time.time() + + for batch in data_loader: + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch + + generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) + + + for id, wav in zip(ids, generated_wavs): + torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) + + end_time = time.time() + epoch_time = end_time - start_time + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file