add cosyvoice2 offline inference

This commit is contained in:
root
2025-09-08 17:37:33 +08:00
parent 8ded65e611
commit cc1991870b
5 changed files with 1011 additions and 4 deletions

View File

@@ -0,0 +1,605 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
--output-dir $output_dir \
--llm-model-name-or-path $huggingface_model_local_dir \
--token2wav-path $model_scope_model_local_dir \
--backend $backend \
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
--engine-dir $trt_engines_dir \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
import time
from token2wav import CosyVoice2_Token2Wav
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
def extract_speech_ids(speech_tokens_str):
"""Extract speech IDs from token strings like <|s_23456|>"""
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
for token in cosy2_tokens:
speech_id_str += f"<|s_{token}|>"
return speech_id_str
def get_args():
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--token2wav-batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=0, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=None, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="LLM model path (includes both model and tokenizer)",
)
parser.add_argument(
"--token2wav-path",
required=True,
type=str,
help="CosyVoice2 token2wav model path",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The prompt text for CosyVoice2",
)
parser.add_argument(
"--prompt-speech-path",
type=str,
default=None,
help="The path to the prompt speech for CosyVoice2",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
parser.add_argument(
"--backend",
type=str,
default="hf",
choices=["hf", "trtllm", "vllm"],
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
)
parser.add_argument(
"--engine-dir",
type=str,
default=None,
help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
)
parser.add_argument(
"--kv-cache-free-gpu-memory-fraction",
type=float,
default=0.6,
help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
)
args = parser.parse_args()
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
collator_start_time = time.time()
total_audio_processing_time = 0
total_speech_tokenization_time = 0
total_text_tokenization_time = 0
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
prompt_text_after_apply_template_list = []
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
for i, item in enumerate(batch):
audio_processing_start_time = time.time()
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
)
prompt_text_list.append(prompt_text)
full_text = prompt_text + target_text
full_text_list.append(full_text)
# remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
puncts = ['"', '(', ')', '', '', '', '', '', '\'']
for p in puncts:
if p in full_text:
full_text = full_text.replace(p, '')
print(f"removed {p} from {full_text}")
# get prompt audio for CosyVoice2 (convert to 16kHz)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
print(ref_audio_org.shape)
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
prompt_audio_list.append(ref_audio)
audio_processing_end_time = time.time()
total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
speech_tokenization_start_time = time.time()
if "prompt_audio_cosy2_tokens" in item:
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
else:
# convert to float first
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
if len(mels) > 0:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
speech_tokenization_end_time = time.time()
total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
text_tokenization_start_time = time.time()
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
{"role": "user", "content": full_text_list[i]},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids_list.append(input_ids.squeeze(0))
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
text_tokenization_end_time = time.time()
total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
ids = [item["id"] for item in batch]
return {
"input_ids": input_ids_list,
"ids": ids,
"prompt_text": prompt_text_list,
"prompt_audio_list": prompt_audio_list,
"prompt_text_after_apply_template": prompt_text_after_apply_template_list,
"audio_processing_time": total_audio_processing_time,
"speech_tokenization_time": total_speech_tokenization_time,
"text_tokenization_time": total_text_tokenization_time,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main(args):
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
# world_size, local_rank, rank = init_distributed()
local_rank, world_size, rank = 0, 1, 0
device = torch.device(f"cuda:{local_rank}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
# model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
# Initialize backend based on argument
if args.backend == "hf":
# Load HuggingFace model
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
runner = None
elif args.backend == "trtllm":
# Validate engine_dir is provided
if args.engine_dir is None:
raise ValueError("--engine-dir is required when backend is 'trtllm'")
# import tensorrt_llm
#from tensorrt_llm.runtime import ModelRunnerCpp
# Initialize TensorRT-LLM runner
runtime_rank = tensorrt_llm.mpi_rank()
model = None
# Prepare input for runner initialization
runner_kwargs = dict(
engine_dir=args.engine_dir,
rank=runtime_rank,
max_output_len=2048,
enable_context_fmha_fp32_acc=False,
max_batch_size=args.batch_size,
max_input_len=512,
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
cuda_graph_mode=False,
gather_generation_logits=False,
)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
elif args.backend == "vllm":
# from vllm import LLM, SamplingParams
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
runner = None
else:
raise ValueError(f"Unsupported backend: {args.backend}")
token2wav_model = CosyVoice2_Token2Wav(
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
)
if args.prompt_speech_path:
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
else:
prompt_speech_16k = None
s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
# sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
)
for _ in range(3):
print(f"Running {_} times")
total_llm_time = 0
total_token2wav_time = 0
total_data_load_time = 0
total_llm_post_processing_time = 0
total_audio_save_time = 0
total_audio_processing_time_in_collator = 0
total_speech_tokenization_time_in_collator = 0
total_text_tokenization_time_in_collator = 0
total_audio_samples = 0
start_time = time.time()
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
last_batch_end_time = time.time()
for batch in dataloader:
data_loaded_time = time.time()
total_data_load_time += data_loaded_time - last_batch_end_time
total_audio_processing_time_in_collator += batch["audio_processing_time"]
total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
with torch.no_grad():
# Generate speech tokens using LLM
llm_start_time = time.time()
if args.backend == "hf":
input_ids_list = batch["input_ids"]
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
attention_mask = torch.ones_like(input_ids)
else:
# Handle batch > 1 if needed
max_len = max([len(input_ids) for input_ids in input_ids_list])
# input_ids_list_new = [
# torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
# for input_ids in input_ids_list
# ]
input_ids_list_new = [
torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list_new)
# compute attention mask
attention_mask = torch.zeros_like(input_ids)
for i in range(len(input_ids_list)):
attention_mask[i, :len(input_ids_list[i])] = 1
# breakpoint()
input_ids = input_ids.to(device)
outputs = model.generate(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
max_new_tokens=2048, # Max length for generation
do_sample=True,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=1.1,
top_k=args.top_k,
)
torch.cuda.synchronize()
elif args.backend == "trtllm":
# Convert input_ids to list of tensors for TensorRT-LLM
batch_input_ids = [ids for ids in batch["input_ids"]]
input_lengths = [x.size(0) for x in batch_input_ids]
# Get end_id from tokenizer
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
# random_seed=42, repetition_penalty=1.1,
outputs = runner.generate(
batch_input_ids=batch_input_ids,
max_new_tokens=2048,
end_id=end_id,
pad_id=end_id,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=1.1,
num_return_sequences=1,
streaming=False,
output_sequence_lengths=True,
output_generation_logits=False,
return_dict=True,
return_all_generated_tokens=False
)
torch.cuda.synchronize()
# Extract output_ids from TensorRT-LLM output
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
num_output_sents, num_beams, _ = output_ids.size()
assert num_beams == 1
beam = 0
batch_size = len(batch["input_ids"])
num_return_sequences = num_output_sents // batch_size
assert num_return_sequences == 1
outputs = []
for i in range(batch_size * num_return_sequences):
batch_idx = i // num_return_sequences
seq_idx = i % num_return_sequences
# inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
# input_text = tokenizer.decode(inputs)
# print(f'Input [Text {batch_idx}]: \"{input_text}\"')
output_begin = input_lengths[batch_idx]
output_end = sequence_lengths[i][beam]
# outputs_i = output_ids[i][beam][output_begin:output_end].tolist()
outputs_i = output_ids[i][beam][:output_end].tolist()
outputs.append(outputs_i)
elif args.backend == "vllm":
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
# prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))]
# print(prompts)
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=1.1,
max_tokens=2048,
)
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
# outputs = model.generate(prompts, sampling_params)
print(outputs)
# breakpoint()
for j, output in enumerate(outputs):
outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
llm_end_time = time.time()
total_llm_time += (llm_end_time - llm_start_time)
items_for_token2wav = []
for i in range(len(batch["ids"])):
llm_post_processing_start_time = time.time()
# Extract generated tokens (excluding input)
input_length = len(batch["input_ids"][i])
generated_ids = outputs[i][input_length:] # Remove last token if needed
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Extract speech IDs from token strings like <|s_23456|>
speech_ids = extract_speech_ids(speech_tokens_str)
print(i, speech_ids)
# breakpoint()
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
llm_post_processing_end_time = time.time()
total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
if current_prompt_audio is not None:
items_for_token2wav.append({
"speech_ids": speech_ids,
"prompt_audio": current_prompt_audio.squeeze(0),
"id": batch["ids"][i]
})
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
for i in range(0, len(items_for_token2wav), args.token2wav_batch_size):
t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size]
if not t2w_batch:
continue
t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
t2w_ids = [item["id"] for item in t2w_batch]
# Generate audio using CosyVoice2
token2wav_start_time = time.time()
generated_wavs = token2wav_model(
t2w_generated_speech_tokens_list,
t2w_prompt_audios_list,
t2w_prompt_audios_sample_rate,
)
torch.cuda.synchronize()
token2wav_end_time = time.time()
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
audio_save_start_time = time.time()
# Convert to numpy and save
for j, audio_hat in enumerate(generated_wavs):
generated_wave = audio_hat.squeeze().cpu().numpy()
total_audio_samples += len(generated_wave)
target_sample_rate = 24000
utt = t2w_ids[j]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
audio_save_end_time = time.time()
total_audio_save_time += audio_save_end_time - audio_save_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
last_batch_end_time = time.time()
if rank == 0:
progress_bar.close()
end_time = time.time()
target_sample_rate = 24000
total_audio_duration_seconds = total_audio_samples / target_sample_rate
log_file_path = os.path.join(args.output_dir, "log.txt")
with open(log_file_path, 'w') as f:
# Convert Namespace to dict for JSON serialization
args_dict = vars(args)
log_data = {
"args": args_dict,
"data_load_time_seconds": total_data_load_time,
"audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
"speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
"text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
"llm_time_seconds": total_llm_time,
"llm_post_processing_time_seconds": total_llm_post_processing_time,
"token2wav_time_seconds": total_token2wav_time,
"audio_save_time_seconds": total_audio_save_time,
"total_audio_duration_seconds": total_audio_duration_seconds,
"pipeline_time_seconds": end_time - start_time,
}
print(log_data)
f.write(json.dumps(log_data, indent=4))
print(f"Metrics logged to {log_file_path}")
if __name__ == "__main__":
args = get_args()
if args.backend == "vllm":
from vllm import LLM, SamplingParams
elif args.backend == "trtllm":
import tensorrt_llm
from tensorrt_llm.runtime import ModelRunnerCpp
elif args.backend == "hf":
from transformers import AutoModelForCausalLM
else:
raise ValueError(f"Unsupported backend: {args.backend}")
main(args)