Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 01:49:25 +08:00)
fix lint
@@ -78,7 +78,7 @@ For offline inference mode benchmark, please check the below command:
 # install FlashCosyVoice for token2wav batching
 # git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
 # cd /workspace/FlashCosyVoice
 # pip install -e .
 # cd -
 # wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
 
@@ -116,7 +116,7 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
 | HF | 1 | 39.26 | 44.31 | 0.2494 |
 | HF | 2 | 30.54 | 35.62 | 0.2064 |
 | HF | 4 | 18.63 | 23.90 | 0.1421 |
 | HF | 8 | 11.22 | 16.45 | 0.0947 |
 | HF | 16 | 8.42 | 13.78 | 0.0821 |
 | TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
 | TRTLLM | 2 | 7.64 | 12.65 | 0.0739 |
@@ -65,6 +65,7 @@ def extract_speech_ids(speech_tokens_str):
 print(f"Unexpected token: {token_str}")
 return speech_ids
 
 
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
 """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
 speech_id_str = ""
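For context, the `<|s_NNN|>` round trip that `extract_speech_ids` and `convert_cosy2_tokens_to_speech_id_str` implement can be sketched as a small stand-alone snippet (a simplified illustration, not the exact code in this diff):

    import re

    def tokens_to_speech_id_str(cosy2_tokens):
        # e.g. [23456, 17] -> "<|s_23456|><|s_17|>"
        return "".join(f"<|s_{t}|>" for t in cosy2_tokens)

    def speech_id_strs_to_tokens(token_strs):
        # e.g. ["<|s_23456|>", "<|s_17|>", "foo"] -> [23456, 17]; non-matching strings are skipped
        ids = []
        for s in token_strs:
            m = re.fullmatch(r"<\|s_(\d+)\|>", s)
            if m:
                ids.append(int(m.group(1)))
        return ids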
@@ -167,7 +168,6 @@ def get_args():
 return args
 
 
 
 def data_collator(batch, tokenizer, s3_tokenizer):
 """Simplified data collator for batch_size=1 processing"""
 collator_start_time = time.time()
@@ -202,7 +202,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
 item["prompt_audio"]["sampling_rate"],
 )
 ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
-# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
 print(ref_audio_org.shape)
 
 if ref_sr != target_sample_rate:
@@ -220,7 +219,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
 prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
 prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
 else:
-# convert to float first
 mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
 
 if len(mels) > 0:
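As background for the `s3tokenizer.log_mel_spectrogram` call above, prompt-audio tokenization with the s3tokenizer package typically follows the sketch below; the model name and shapes here are assumptions for illustration, not taken from this diff:

    import s3tokenizer
    import torch

    # Assumed model name; CosyVoice2 uses the v2 25 Hz speech tokenizer.
    tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()

    waveforms = [torch.randn(16000 * 3), torch.randn(16000 * 5)]      # stand-ins for 16 kHz prompts
    mels = [s3tokenizer.log_mel_spectrogram(w) for w in waveforms]    # [80, T] per utterance
    mels, mel_lens = s3tokenizer.padding(mels)                        # [B, 80, T_max], [B]

    with torch.no_grad():
        codes, code_lens = tokenizer.quantize(mels.cuda(), mel_lens.cuda())
    prompt_tokens = [codes[i, :code_lens[i]].tolist() for i in range(codes.size(0))]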
@@ -287,33 +285,23 @@ def main(args):
 os.makedirs(args.output_dir, exist_ok=True)
 
 assert torch.cuda.is_available()
-# world_size, local_rank, rank = init_distributed()
 local_rank, world_size, rank = 0, 1, 0
 device = torch.device(f"cuda:{local_rank}")
 
-# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
 
-# model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
-# Initialize backend based on argument
 if args.backend == "hf":
-# Load HuggingFace model
 model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
 model.eval()
 model.to(device)
 runner = None
 elif args.backend == "trtllm":
-# Validate engine_dir is provided
 if args.engine_dir is None:
 raise ValueError("--engine-dir is required when backend is 'trtllm'")
-# import tensorrt_llm
-#from tensorrt_llm.runtime import ModelRunnerCpp
 
-# Initialize TensorRT-LLM runner
 runtime_rank = tensorrt_llm.mpi_rank()
 model = None
 
-# Prepare input for runner initialization
 runner_kwargs = dict(
 engine_dir=args.engine_dir,
 rank=runtime_rank,
@@ -328,7 +316,6 @@ def main(args):
 
 runner = ModelRunnerCpp.from_dir(**runner_kwargs)
 elif args.backend == "vllm":
-# from vllm import LLM, SamplingParams
 model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
 runner = None
 else:
@@ -349,7 +336,6 @@ def main(args):
 trust_remote_code=True,
 )
 
-# sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
 sampler = None
 dataloader = DataLoader(
 dataset,
@@ -385,7 +371,6 @@ def main(args):
 total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
 total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
 with torch.no_grad():
-# Generate speech tokens using LLM
 llm_start_time = time.time()
 if args.backend == "hf":
 input_ids_list = batch["input_ids"]
@@ -393,31 +378,22 @@ def main(args):
 input_ids = input_ids_list[0].unsqueeze(0)
 attention_mask = torch.ones_like(input_ids)
 else:
-# Handle batch > 1 if needed
 max_len = max([len(input_ids) for input_ids in input_ids_list])
-# input_ids_list_new = [
-# torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
-# for input_ids in input_ids_list
-# ]
 input_ids_list_new = [
 torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
 for input_ids in input_ids_list
 ]
 input_ids = torch.stack(input_ids_list_new)
-# compute attention mask
 attention_mask = torch.zeros_like(input_ids)
 for i in range(len(input_ids_list)):
 attention_mask[i, :len(input_ids_list[i])] = 1
 
-# breakpoint()
 
 
 input_ids = input_ids.to(device)
 
 outputs = model.generate(
 input_ids=input_ids.to(device),
 attention_mask=attention_mask.to(device),
-max_new_tokens=2048, # Max length for generation
+max_new_tokens=2048,
 do_sample=True,
 top_p=args.top_p,
 temperature=args.temperature,
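The right-padding logic above can be seen in isolation in this small sketch (hypothetical pad id and toy sequences, not tied to the actual tokenizer):

    import torch

    pad_token_id = 0  # hypothetical; in the script this comes from the tokenizer
    input_ids_list = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

    max_len = max(len(ids) for ids in input_ids_list)
    # right-pad every sequence to max_len
    padded = torch.stack([
        torch.cat([ids, torch.full((max_len - len(ids),), pad_token_id)])
        for ids in input_ids_list
    ])
    # mark real tokens with 1, padding with 0
    attention_mask = torch.zeros_like(padded)
    for i, ids in enumerate(input_ids_list):
        attention_mask[i, :len(ids)] = 1
    # padded -> [[5, 6, 7], [8, 9, 0]], attention_mask -> [[1, 1, 1], [1, 1, 0]]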
@@ -426,14 +402,11 @@ def main(args):
 )
 torch.cuda.synchronize()
 elif args.backend == "trtllm":
-# Convert input_ids to list of tensors for TensorRT-LLM
 batch_input_ids = [ids for ids in batch["input_ids"]]
 input_lengths = [x.size(0) for x in batch_input_ids]
 
-# Get end_id from tokenizer
 end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
 print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
-# random_seed=42, repetition_penalty=1.1,
 outputs = runner.generate(
 batch_input_ids=batch_input_ids,
 max_new_tokens=2048,
@@ -451,7 +424,6 @@ def main(args):
 return_all_generated_tokens=False
 )
 torch.cuda.synchronize()
-# Extract output_ids from TensorRT-LLM output
 output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
 num_output_sents, num_beams, _ = output_ids.size()
 assert num_beams == 1
@@ -463,18 +435,12 @@ def main(args):
 for i in range(batch_size * num_return_sequences):
 batch_idx = i // num_return_sequences
 seq_idx = i % num_return_sequences
-# inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
-# input_text = tokenizer.decode(inputs)
-# print(f'Input [Text {batch_idx}]: \"{input_text}\"')
 output_begin = input_lengths[batch_idx]
 output_end = sequence_lengths[i][beam]
-# outputs_i = output_ids[i][beam][output_begin:output_end].tolist()
 outputs_i = output_ids[i][beam][:output_end].tolist()
 outputs.append(outputs_i)
 elif args.backend == "vllm":
 input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
-# prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))]
-# print(prompts)
 sampling_params = SamplingParams(
 temperature=args.temperature,
 top_p=args.top_p,
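For readers unfamiliar with the runner output layout handled above: `output_ids` is shaped [num_sequences, num_beams, max_len] and `sequence_lengths` gives the valid length per sequence, so each sample's tokens are recovered by slicing and the echoed prompt is stripped afterwards. A rough illustration with dummy tensors (shapes assumed from the code in this diff):

    import torch

    # Dummy stand-ins for runner.generate(...) outputs with batch=2, beam=1.
    output_ids = torch.tensor([[[1, 2, 3, 7, 8, 0]],    # sample 0: 3 prompt tokens, 2 generated, 1 pad
                               [[4, 5, 9, 9, 9, 9]]])   # sample 1: 2 prompt tokens, 4 generated
    sequence_lengths = torch.tensor([[5], [6]])          # valid length per sequence (prompt + generated)
    input_lengths = [3, 2]
    beam = 0

    outputs = []
    for i in range(output_ids.size(0)):
        output_end = sequence_lengths[i][beam]
        outputs.append(output_ids[i][beam][:output_end].tolist())  # prompt + generated, padding dropped

    # later the prompt part is stripped, as in the main loop of the script
    generated = [outputs[i][input_lengths[i]:] for i in range(len(outputs))]
    # generated -> [[7, 8], [9, 9, 9, 9]]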
@@ -483,26 +449,21 @@ def main(args):
 max_tokens=2048,
 )
 outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
-# outputs = model.generate(prompts, sampling_params)
 print(outputs)
-# breakpoint()
 for j, output in enumerate(outputs):
 outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
 
 llm_end_time = time.time()
 total_llm_time += (llm_end_time - llm_start_time)
 
-items_for_token2wav = []
+items_for_token_2wav = []
 for i in range(len(batch["ids"])):
 llm_post_processing_start_time = time.time()
-# Extract generated tokens (excluding input)
 input_length = len(batch["input_ids"][i])
-generated_ids = outputs[i][input_length:] # Remove last token if needed
+generated_ids = outputs[i][input_length:]
 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-# Extract speech IDs from token strings like <|s_23456|>
 speech_ids = extract_speech_ids(speech_tokens_str)
 print(i, speech_ids)
-# breakpoint()
 if len(speech_ids) == 0:
 print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
 continue
@@ -517,7 +478,7 @@ def main(args):
 llm_post_processing_end_time = time.time()
 total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
 if current_prompt_audio is not None:
-items_for_token2wav.append({
+items_for_token_2wav.append({
 "speech_ids": speech_ids,
 "prompt_audio": current_prompt_audio.squeeze(0),
 "id": batch["ids"][i]
@@ -525,8 +486,8 @@ def main(args):
 else:
 print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
 
-for i in range(0, len(items_for_token2wav), args.token2wav_batch_size):
+for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
-t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size]
+t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
 if not t2w_batch:
 continue
 
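The token2wav batching above is plain list slicing; a self-contained sketch of the same pattern (hypothetical items and batch size):

    items = [{"id": i} for i in range(7)]   # stand-ins for the collected token2wav work items
    token2wav_batch_size = 3

    for start in range(0, len(items), token2wav_batch_size):
        batch = items[start:start + token2wav_batch_size]
        if not batch:
            continue
        print([item["id"] for item in batch])
    # prints [0, 1, 2], then [3, 4, 5], then [6]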
@@ -535,7 +496,6 @@ def main(args):
 t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
 t2w_ids = [item["id"] for item in t2w_batch]
 
-# Generate audio using CosyVoice2
 token2wav_start_time = time.time()
 generated_wavs = token2wav_model(
 t2w_generated_speech_tokens_list,
@@ -547,7 +507,6 @@ def main(args):
 total_token2wav_time += (token2wav_end_time - token2wav_start_time)
 
 audio_save_start_time = time.time()
-# Convert to numpy and save
 for j, audio_hat in enumerate(generated_wavs):
 generated_wave = audio_hat.squeeze().cpu().numpy()
 total_audio_samples += len(generated_wave)
@@ -571,7 +530,6 @@ def main(args):
 
 log_file_path = os.path.join(args.output_dir, "log.txt")
 with open(log_file_path, 'w') as f:
-# Convert Namespace to dict for JSON serialization
 args_dict = vars(args)
 log_data = {
 "args": args_dict,
@@ -602,4 +560,4 @@ if __name__ == "__main__":
 from transformers import AutoModelForCausalLM
 else:
 raise ValueError(f"Unsupported backend: {args.backend}")
 main(args)
@@ -70,6 +70,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
 f.write(engine_bytes)
 logging.info("Succesfully convert onnx to trt...")
 
 
 class TrtContextWrapper:
 def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
 self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -88,12 +89,13 @@ class TrtContextWrapper:
 def release_estimator(self, context, stream):
 self.trt_context_pool.put([context, stream])
 
 
 class CosyVoice2_Token2Wav(torch.nn.Module):
 def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
 super().__init__()
 self.device_id = device_id
 self.device = f"cuda:{device_id}"
 
 self.flow = CausalMaskedDiffWithXvec()
 self.flow.half()
 self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
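The `TrtContextWrapper` above hands out execution contexts through a `queue.Queue`; the acquire/release pattern it relies on looks roughly like this generic sketch (not the repo's exact class):

    import queue

    class ContextPool:
        """Blocking pool of reusable (context, stream) pairs for a fixed concurrency."""

        def __init__(self, contexts):
            self.pool = queue.Queue(maxsize=len(contexts))
            for ctx in contexts:
                self.pool.put(ctx)

        def acquire(self):
            # blocks until another thread releases a context
            return self.pool.get()

        def release(self, ctx):
            self.pool.put(ctx)

    pool = ContextPool(["ctx0", "ctx1"])   # strings stand in for TRT contexts/streams
    ctx = pool.acquire()
    try:
        pass  # run inference with ctx here
    finally:
        pool.release(ctx)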
@@ -107,22 +109,20 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 option = onnxruntime.SessionOptions()
 option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
 option.intra_op_num_threads = 1
-self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-providers=["CPUExecutionProvider"])
+self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
 
 self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
 
-gpu="l20"
+gpu = "l20"
 if enable_trt:
 self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
 f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
 1,
 True)
 self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
 f'{model_dir}/campplus.onnx',
 1,
 False)
 
 
 def forward_spk_embedding(self, spk_feat):
 if isinstance(self.spk_model, onnxruntime.InferenceSession):
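For reference, running the campplus.onnx speaker-embedding model through onnxruntime on CPU, as configured above, follows the usual InferenceSession pattern; the input name, feature shape, and embedding size below are assumptions for illustration:

    import numpy as np
    import onnxruntime

    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    sess = onnxruntime.InferenceSession("campplus.onnx", sess_options=option,
                                        providers=["CPUExecutionProvider"])

    # assumed: campplus takes mean-normalized fbank features of shape [1, T, 80]
    feat = np.random.randn(1, 200, 80).astype(np.float32)
    input_name = sess.get_inputs()[0].name
    (spk_emb,) = sess.run(None, {input_name: feat})   # -> e.g. a [1, 192] embedding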
@@ -173,7 +173,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
 assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
 if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
-trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16)
+trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
 convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
 del self.flow.decoder.estimator
 import tensorrt as trt
@@ -182,10 +182,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
 self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
 
-def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64):
+def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
 min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
+opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
-max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
+(max_batch_size * 2, 80)]
 input_names = ["x", "mask", "mu", "cond", "t", "spks"]
 return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
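The min/opt/max shape triples returned by `get_trt_kwargs_dynamic_batch` are what a TensorRT optimization profile consumes when the ONNX flow estimator is built into an engine. A condensed sketch of that conversion path, using the standard TensorRT Python API and simplified relative to `convert_onnx_to_trt` in this repo:

    import tensorrt as trt

    def build_engine(onnx_path, plan_path, trt_kwargs, fp16=True):
        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, logger)
        with open(onnx_path, "rb") as f:
            assert parser.parse(f.read()), "failed to parse ONNX model"

        config = builder.create_builder_config()
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)

        # one dynamic-shape profile: (min, opt, max) per named input
        profile = builder.create_optimization_profile()
        for name, mn, op, mx in zip(trt_kwargs["input_names"], trt_kwargs["min_shape"],
                                    trt_kwargs["opt_shape"], trt_kwargs["max_shape"]):
            profile.set_shape(name, mn, op, mx)
        config.add_optimization_profile(profile)

        engine_bytes = builder.build_serialized_network(network, config)
        with open(plan_path, "wb") as f:
            f.write(engine_bytes)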
@@ -203,7 +204,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
 prompt_speech_tokens_list.append(speech_tokens_i)
 return prompt_speech_tokens_list
 
 def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
 spk_emb_for_flow = []
 for audio in prompt_audios_list:
@@ -213,9 +214,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 spk_emb = self.forward_spk_embedding(spk_feat)
 
 spk_emb_for_flow.append(spk_emb)
 spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
 return spk_emb_for_flow
 
 def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
 prompt_mels_for_flow = []
 prompt_mels_lens_for_flow = []
@@ -231,9 +232,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
 prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
 return prompt_mels_for_flow, prompt_mels_lens_for_flow
 
 
-def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
+prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
 batch_size = prompt_mels_for_flow.shape[0]
 flow_inputs = []
 flow_inputs_lens = []
@@ -262,14 +263,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 generated_wavs.append(wav)
 return generated_wavs
 
 
 @torch.inference_mode()
 def forward(
 self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
 ):
 # assert all item in prompt_audios_sample_rate is 16000
 assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
 
 
 prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
 
@@ -277,10 +276,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
 spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
 
-generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+generated_mels, generated_mels_lens = self.forward_flow(
+prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
 
 generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
 
 return generated_wavs
 
 
@@ -288,13 +288,14 @@ def collate_fn(batch):
 ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
 for i, item in enumerate(batch):
 generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
 audio = torch.from_numpy(item['prompt_audio']['array']).float()
 prompt_audios_list.append(audio)
 prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
 ids.append(item['id'])
 
 return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
 
 
 def get_args():
 parser = argparse.ArgumentParser()
 parser.add_argument("--enable-trt", action="store_true")
@@ -305,6 +306,7 @@ def get_args():
 parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
 return parser.parse_args()
 
 
 if __name__ == "__main__":
 args = get_args()
 model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -315,22 +317,19 @@ if __name__ == "__main__":
 
 dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
 
 
 data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
 
 
 for epoch in range(args.warmup):
 start_time = time.time()
 
 for batch in data_loader:
 ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
 
 generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
 
 
 for id, wav in zip(ids, generated_wavs):
 torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
 
 end_time = time.time()
 epoch_time = end_time - start_time
 print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")