# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
import time
from functools import partial
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
import soundfile as sf
import s3tokenizer
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Matcha-TTS must be on sys.path before token2wav (and its Matcha-TTS
# dependencies) can be imported.
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
from token2wav import CosyVoice2_Token2Wav
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass


def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>"""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Convert CosyVoice2 tokens to a speech-ID string like <|s_23456|>"""
    speech_id_str = ""
    for token in cosy2_tokens:
        speech_id_str += f"<|s_{token}|>"
    return speech_id_str


def get_args():
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--token2wav-batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for token2wav inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=None, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM model path (includes both model and tokenizer)",
    )
    parser.add_argument(
        "--token2wav-path",
        required=True,
        type=str,
        help="CosyVoice2 token2wav model path",
    )
    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The prompt text for CosyVoice2",
    )
    parser.add_argument(
        "--prompt-speech-path",
        type=str,
        default=None,
        help="The path to the prompt speech for CosyVoice2",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="top p for sampling",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="top k for sampling",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="hf",
        choices=["hf", "trtllm", "vllm"],
        help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for vLLM",
    )
    parser.add_argument(
        "--engine-dir",
        type=str,
        default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction",
        type=float,
        default=0.6,
        help="Fraction of free GPU memory reserved for the KV cache (TensorRT-LLM only)",
    )
    args = parser.parse_args()
    return args


def data_collator(batch, tokenizer, s3_tokenizer):
    """Simplified data collator for batch_size=1 processing"""
    collator_start_time = time.time()
    total_audio_processing_time = 0
    total_speech_tokenization_time = 0
    total_text_tokenization_time = 0
    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    prompt_text_after_apply_template_list = []
    mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
    for i, item in enumerate(batch):
        audio_processing_start_time = time.time()
        prompt_text, target_text = (
            item["prompt_text"],
            item["target_text"],
        )
        prompt_text_list.append(prompt_text)
        full_text = prompt_text + target_text
        full_text_list.append(full_text)
        # remove unnecessary punctuation for the cosyvoice3 zero_shot_zh dataset
        # NOTE: the full-width CJK characters in this list were stripped by the
        # export's "ambiguous Unicode" filtering; the quotes, parentheses, and
        # enumeration comma below are a best-effort reconstruction.
        puncts = ['"', '(', ')', '“', '”', '（', '）', '、', '\'']
        for p in puncts:
            if p in full_text:
                full_text = full_text.replace(p, '')
                print(f"removed {p} from {full_text}")
        # get prompt audio for CosyVoice2 (convert to 16kHz)
        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
        # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
        print(ref_audio_org.shape)
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org
        prompt_audio_list.append(ref_audio)
        audio_processing_end_time = time.time()
        total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
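        # Prompt-audio speech tokens: use the dataset's precomputed CosyVoice2
        # tokens when available; otherwise compute a log-mel spectrogram here
        # and quantize all collected mels with the S3 tokenizer in one batched
        # call after the loop.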
        speech_tokenization_start_time = time.time()
        if "prompt_audio_cosy2_tokens" in item:
            prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
            prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
        else:
            # convert to float first
            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
    if len(mels) > 0:
        mels, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
        for i in range(len(codes)):
            prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
    speech_tokenization_end_time = time.time()
    total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
        text_tokenization_start_time = time.time()
        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
        # Create chat template for LLM generation
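        # The assistant turn is pre-filled with the prompt audio's speech
        # tokens, and continue_final_message=True leaves that turn open, so
        # the LLM continues the utterance with speech tokens for the target
        # text instead of starting a new turn.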
        chat = [
            {"role": "user", "content": full_text_list[i]},
            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
        ]
        assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        )
        input_ids_list.append(input_ids.squeeze(0))
        prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
        prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
        text_tokenization_end_time = time.time()
        total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
    ids = [item["id"] for item in batch]
    return {
        "input_ids": input_ids_list,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
        "audio_processing_time": total_audio_processing_time,
        "speech_tokenization_time": total_speech_tokenization_time,
        "text_tokenization_time": total_text_tokenization_time,
    }


def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
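
# NOTE: init_distributed is kept for multi-GPU runs (e.g. via torchrun) but is
# not called below; main() hard-codes a single-process, single-GPU setup, and
# the DistributedSampler in main() is commented out accordingly.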


def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    assert torch.cuda.is_available()
    # world_size, local_rank, rank = init_distributed()
    local_rank, world_size, rank = 0, 1, 0
    device = torch.device(f"cuda:{local_rank}")
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
    # Initialize backend based on argument
if args.backend == "hf":
# Load HuggingFace model
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
runner = None
elif args.backend == "trtllm":
# Validate engine_dir is provided
if args.engine_dir is None:
raise ValueError("--engine-dir is required when backend is 'trtllm'")
# import tensorrt_llm
#from tensorrt_llm.runtime import ModelRunnerCpp
# Initialize TensorRT-LLM runner
runtime_rank = tensorrt_llm.mpi_rank()
model = None
# Prepare input for runner initialization
runner_kwargs = dict(
engine_dir=args.engine_dir,
rank=runtime_rank,
max_output_len=2048,
enable_context_fmha_fp32_acc=False,
max_batch_size=args.batch_size,
max_input_len=512,
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
cuda_graph_mode=False,
gather_generation_logits=False,
)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
elif args.backend == "vllm":
# from vllm import LLM, SamplingParams
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
runner = None
else:
raise ValueError(f"Unsupported backend: {args.backend}")
    token2wav_model = CosyVoice2_Token2Wav(
        model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
    )
    if args.prompt_speech_path:
        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
    else:
        prompt_speech_16k = None
    s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
    dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
    dataset = load_dataset(
        dataset_name,
        split=args.split_name,
        trust_remote_code=True,
    )
    # sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )
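
    # Run the full pipeline three times over the dataset: the first pass
    # doubles as a warm-up of CUDA kernels and TRT engines, so timings from
    # the later passes are more representative.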
    for run_idx in range(3):
        print(f"Running pass {run_idx}")
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_processing_time_in_collator = 0
        total_speech_tokenization_time_in_collator = 0
        total_text_tokenization_time_in_collator = 0
        total_audio_samples = 0
        start_time = time.time()
        total_steps = len(dataset)
        if rank == 0:
            progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
        last_batch_end_time = time.time()
        for batch in dataloader:
            data_loaded_time = time.time()
            total_data_load_time += data_loaded_time - last_batch_end_time
            total_audio_processing_time_in_collator += batch["audio_processing_time"]
            total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
            total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
            with torch.no_grad():
                # Generate speech tokens using LLM
                llm_start_time = time.time()
                if args.backend == "hf":
                    input_ids_list = batch["input_ids"]
                    if len(input_ids_list) == 1:
                        input_ids = input_ids_list[0].unsqueeze(0)
                        attention_mask = torch.ones_like(input_ids)
                    else:
                        # Handle batch > 1: right-pad every sequence to the longest one
                        max_len = max([len(input_ids) for input_ids in input_ids_list])
                        input_ids_list_new = [
                            torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
                            for input_ids in input_ids_list
                        ]
                        input_ids = torch.stack(input_ids_list_new)
                        # compute attention mask
                        attention_mask = torch.zeros_like(input_ids)
                        for i in range(len(input_ids_list)):
                            attention_mask[i, :len(input_ids_list[i])] = 1
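                        # NOTE: with right padding, the outputs[i][input_length:]
                        # slice below also picks up the pad ids between the real
                        # prompt and the generated tokens; they decode to empty
                        # strings (skip_special_tokens) and are then discarded by
                        # extract_speech_ids, so results stay correct.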
                    input_ids = input_ids.to(device)
                    outputs = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask.to(device),
                        max_new_tokens=2048,  # Max length for generation
                        do_sample=True,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        repetition_penalty=1.1,
                        top_k=args.top_k,
                    )
                    torch.cuda.synchronize()
elif args.backend == "trtllm":
# Convert input_ids to list of tensors for TensorRT-LLM
batch_input_ids = [ids for ids in batch["input_ids"]]
input_lengths = [x.size(0) for x in batch_input_ids]
# Get end_id from tokenizer
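                    # The CosyVoice2 LLM terminates speech-token generation with
                    # <|eos1|> when that token exists in the vocabulary; otherwise
                    # fall back to the tokenizer's default EOS token.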
                    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
                    print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id}")
                    outputs = runner.generate(
                        batch_input_ids=batch_input_ids,
                        max_new_tokens=2048,
                        end_id=end_id,
                        pad_id=end_id,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=1.1,
                        num_return_sequences=1,
                        streaming=False,
                        output_sequence_lengths=True,
                        output_generation_logits=False,
                        return_dict=True,
                        return_all_generated_tokens=False
                    )
                    torch.cuda.synchronize()
                    # Extract output_ids from TensorRT-LLM output
                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                    num_output_sents, num_beams, _ = output_ids.size()
                    assert num_beams == 1
                    beam = 0
                    batch_size = len(batch["input_ids"])
                    num_return_sequences = num_output_sents // batch_size
                    assert num_return_sequences == 1
                    outputs = []
                    for i in range(batch_size * num_return_sequences):
                        batch_idx = i // num_return_sequences
                        seq_idx = i % num_return_sequences
                        output_begin = input_lengths[batch_idx]
                        output_end = sequence_lengths[i][beam]
                        # outputs_i = output_ids[i][beam][output_begin:output_end].tolist()
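                        # Keep the prompt prefix (slice from 0, not output_begin)
                        # so the uniform outputs[i][input_length:] slice after the
                        # backend branches works the same way for every backend.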
                        outputs_i = output_ids[i][beam][:output_end].tolist()
                        outputs.append(outputs_i)
elif args.backend == "vllm":
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
# prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))]
# print(prompts)
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=1.1,
max_tokens=2048,
)
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
# outputs = model.generate(prompts, sampling_params)
print(outputs)
# breakpoint()
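                    # Prepend the prompt ids to each vLLM completion so that,
                    # as with the other backends, outputs[i][input_length:]
                    # yields exactly the newly generated speech tokens.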
                    for j, output in enumerate(outputs):
                        # token_ids may be a tuple in recent vLLM versions
                        outputs[j] = input_ids_list[j] + list(output.outputs[0].token_ids)
                llm_end_time = time.time()
                total_llm_time += (llm_end_time - llm_start_time)
                items_for_token2wav = []
                for i in range(len(batch["ids"])):
                    llm_post_processing_start_time = time.time()
                    # Extract generated tokens (excluding the prompt prefix)
                    input_length = len(batch["input_ids"][i])
                    generated_ids = outputs[i][input_length:]
                    # decode id-by-id so each <|s_x|> token maps to one string
                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                    # Extract speech IDs from token strings like <|s_23456|>
                    speech_ids = extract_speech_ids(speech_tokens_str)
                    print(i, speech_ids)
                    if len(speech_ids) == 0:
                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                        continue
                    if args.prompt_text is not None:
                        current_prompt_text = args.prompt_text
                        current_prompt_audio = prompt_speech_16k
                    else:
                        current_prompt_text = batch["prompt_text"][i]
                        current_prompt_audio = batch["prompt_audio_list"][i]
                    llm_post_processing_end_time = time.time()
                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                    if current_prompt_audio is not None:
                        items_for_token2wav.append({
                            "speech_ids": speech_ids,
                            "prompt_audio": current_prompt_audio.squeeze(0),
                            "id": batch["ids"][i]
                        })
                    else:
                        print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
                for i in range(0, len(items_for_token2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size]
                    if not t2w_batch:
                        continue
                    t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
                    t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
                    t2w_ids = [item["id"] for item in t2w_batch]
                    # Generate audio using CosyVoice2
                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_generated_speech_tokens_list,
                        t2w_prompt_audios_list,
                        t2w_prompt_audios_sample_rate,
                    )
                    torch.cuda.synchronize()
                    token2wav_end_time = time.time()
                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)
                    audio_save_start_time = time.time()
                    # Convert to numpy and save
                    for j, audio_hat in enumerate(generated_wavs):
                        generated_wave = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(generated_wave)
                        target_sample_rate = 24000
                        utt = t2w_ids[j]
                        sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
                        print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
                    audio_save_end_time = time.time()
                    total_audio_save_time += audio_save_end_time - audio_save_start_time
            if rank == 0:
                progress_bar.update(world_size * len(batch["ids"]))
            last_batch_end_time = time.time()
        if rank == 0:
            progress_bar.close()
        end_time = time.time()
        target_sample_rate = 24000
        total_audio_duration_seconds = total_audio_samples / target_sample_rate
        log_file_path = os.path.join(args.output_dir, "log.txt")
        with open(log_file_path, 'w') as f:
            # Convert Namespace to dict for JSON serialization
            args_dict = vars(args)
            log_data = {
                "args": args_dict,
                "data_load_time_seconds": total_data_load_time,
                "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
                "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
                "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
                "llm_time_seconds": total_llm_time,
                "llm_post_processing_time_seconds": total_llm_post_processing_time,
                "token2wav_time_seconds": total_token2wav_time,
                "audio_save_time_seconds": total_audio_save_time,
                "total_audio_duration_seconds": total_audio_duration_seconds,
                "pipeline_time_seconds": end_time - start_time,
            }
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")
if __name__ == "__main__":
args = get_args()
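    # Backend-specific imports are deferred until after argument parsing so a
    # given backend only needs its own runtime installed; the imported names
    # become module globals that main() uses above.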
if args.backend == "vllm":
from vllm import LLM, SamplingParams
elif args.backend == "trtllm":
import tensorrt_llm
from tensorrt_llm.runtime import ModelRunnerCpp
elif args.backend == "hf":
from transformers import AutoModelForCausalLM
else:
raise ValueError(f"Unsupported backend: {args.backend}")
main(args)