Mirror of https://github.com/FunAudioLLM/CosyVoice.git

Commit: fix lint
@@ -65,6 +65,7 @@ def extract_speech_ids(speech_tokens_str):
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
    speech_id_str = ""
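The helpers touched by this hunk work with token strings of the form `<|s_23456|>`. For readers without the full file, here is a minimal, hypothetical sketch of what the two functions around this hunk plausibly do; the `_sketch` names and the regex approach are illustrative, not the repository's actual implementation.

```python
import re


def extract_speech_ids_sketch(speech_tokens_str):
    """Pull integer IDs out of token strings shaped like <|s_23456|>."""
    speech_ids = []
    for token_str in speech_tokens_str:
        match = re.fullmatch(r"<\|s_(\d+)\|>", token_str)
        if match:
            speech_ids.append(int(match.group(1)))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy2_tokens_to_speech_id_str_sketch(cosy2_tokens):
    """Render a list of CosyVoice2 token IDs as '<|s_1|><|s_2|>...'."""
    return "".join(f"<|s_{token}|>" for token in cosy2_tokens)
```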
@@ -167,7 +168,6 @@ def get_args():
    return args



def data_collator(batch, tokenizer, s3_tokenizer):
    """Simplified data collator for batch_size=1 processing"""
    collator_start_time = time.time()
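The `get_args()` function itself is outside this hunk. As a reference for the options used throughout `main()` (`--backend`, `--engine-dir`, `--output-dir`, `--top-p`, `--temperature`, `--token2wav-batch-size`), here is an illustrative argparse sketch; the defaults are assumptions, not the script's real values.

```python
import argparse


def get_args_sketch():
    """Illustrative argparse setup for the options referenced in main(); defaults are assumptions."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm-model-name-or-path", type=str, required=True)
    parser.add_argument("--backend", type=str, default="hf", choices=["hf", "trtllm", "vllm"])
    parser.add_argument("--engine-dir", type=str, default=None,
                        help="TensorRT-LLM engine directory, required when --backend trtllm")
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--token2wav-batch-size", type=int, default=4)
    return parser.parse_args()
```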
@@ -202,7 +202,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
                item["prompt_audio"]["sampling_rate"],
            )
            ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
            # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
            print(ref_audio_org.shape)

            if ref_sr != target_sample_rate:
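This hunk ends where the collator checks whether the prompt audio needs resampling. A minimal sketch of that step, assuming `torchaudio.functional.resample` is used (the script may use a different resampler):

```python
import torch
import torchaudio


def resample_if_needed(ref_audio: torch.Tensor, ref_sr: int, target_sample_rate: int = 16000) -> torch.Tensor:
    """Resample a (1, T) waveform to target_sample_rate when the source rate differs."""
    if ref_sr != target_sample_rate:
        ref_audio = torchaudio.functional.resample(ref_audio, orig_freq=ref_sr, new_freq=target_sample_rate)
    return ref_audio
```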
@@ -220,7 +219,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
                prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
                prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
            else:
                # convert to float first
                mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))

    if len(mels) > 0:
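When precomputed CosyVoice2 tokens are not available, the collator falls back to tokenizing the prompt audio with `s3tokenizer`. A sketch of what the `if len(mels) > 0:` branch plausibly does, following the usage pattern from the s3tokenizer README (`padding` plus the model's `quantize`); exact signatures may differ across versions.

```python
import s3tokenizer


def tokenize_prompt_audios(ref_audios, s3_tokenizer, device):
    """Turn a list of 16 kHz prompt waveforms into speech token sequences (sketch)."""
    mels = [s3tokenizer.log_mel_spectrogram(audio.squeeze(0)) for audio in ref_audios]
    mels, mels_lens = s3tokenizer.padding(mels)  # pad mel spectrograms to a common frame count
    codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
    # trim each code sequence to its true length
    return [codes[i, :codes_lens[i].item()] for i in range(codes.size(0))]
```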
@@ -287,33 +285,23 @@ def main(args):
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    # world_size, local_rank, rank = init_distributed()
    local_rank, world_size, rank = 0, 1, 0
    device = torch.device(f"cuda:{local_rank}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)

    # model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
    # Initialize backend based on argument
    if args.backend == "hf":
        # Load HuggingFace model
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
        runner = None
    elif args.backend == "trtllm":
        # Validate engine_dir is provided
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        # import tensorrt_llm
        #from tensorrt_llm.runtime import ModelRunnerCpp

        # Initialize TensorRT-LLM runner
        runtime_rank = tensorrt_llm.mpi_rank()
        model = None

        # Prepare input for runner initialization
        runner_kwargs = dict(
            engine_dir=args.engine_dir,
            rank=runtime_rank,
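Condensed, for orientation: the three-way backend setup in this hunk can be read as one helper. This is a sketch built only from what the diff shows (the `gpu_memory_utilization=0.4` value and the minimal `ModelRunnerCpp.from_dir` arguments come from this hunk); the remaining `runner_kwargs` entries are omitted here.

```python
def init_backend(args, device):
    """Condensed sketch of the hf / trtllm / vllm initialization branch above."""
    model, runner = None, None
    if args.backend == "hf":
        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval().to(device)
    elif args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
        runner = ModelRunnerCpp.from_dir(engine_dir=args.engine_dir, rank=tensorrt_llm.mpi_rank())
    elif args.backend == "vllm":
        from vllm import LLM
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    return model, runner
```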
@@ -328,7 +316,6 @@ def main(args):

        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
    elif args.backend == "vllm":
        # from vllm import LLM, SamplingParams
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
        runner = None
    else:
@@ -349,7 +336,6 @@ def main(args):
        trust_remote_code=True,
    )

    # sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    sampler = None
    dataloader = DataLoader(
        dataset,
@@ -385,7 +371,6 @@ def main(args):
        total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
        total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
        with torch.no_grad():
            # Generate speech tokens using LLM
            llm_start_time = time.time()
            if args.backend == "hf":
                input_ids_list = batch["input_ids"]
@@ -393,31 +378,22 @@ def main(args):
                    input_ids = input_ids_list[0].unsqueeze(0)
                    attention_mask = torch.ones_like(input_ids)
                else:
                    # Handle batch > 1 if needed
                    max_len = max([len(input_ids) for input_ids in input_ids_list])
                    # input_ids_list_new = [
                    #     torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
                    #     for input_ids in input_ids_list
                    # ]
                    input_ids_list_new = [
                        torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
                        for input_ids in input_ids_list
                    ]
                    input_ids = torch.stack(input_ids_list_new)
                    # compute attention mask
                    attention_mask = torch.zeros_like(input_ids)
                    for i in range(len(input_ids_list)):
                        attention_mask[i, :len(input_ids_list[i])] = 1

                # breakpoint()


                input_ids = input_ids.to(device)

                outputs = model.generate(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
-                    max_new_tokens=2048, # Max length for generation
+                    max_new_tokens=2048,
                    do_sample=True,
                    top_p=args.top_p,
                    temperature=args.temperature,
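The manual right-padding and mask construction in this hunk can be reproduced more compactly with `torch.nn.utils.rnn.pad_sequence`. A minimal equivalent sketch, assuming `pad_token_id` is available as in the code above:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def right_pad_batch(input_ids_list, pad_token_id):
    """Right-pad a list of 1-D token tensors and build the matching attention mask."""
    input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_id)
    lengths = torch.tensor([len(ids) for ids in input_ids_list])
    # attention_mask[i, j] = 1 for real tokens, 0 for padding
    attention_mask = (torch.arange(input_ids.size(1))[None, :] < lengths[:, None]).long()
    return input_ids, attention_mask
```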
@@ -426,14 +402,11 @@ def main(args):
                )
                torch.cuda.synchronize()
            elif args.backend == "trtllm":
                # Convert input_ids to list of tensors for TensorRT-LLM
                batch_input_ids = [ids for ids in batch["input_ids"]]
                input_lengths = [x.size(0) for x in batch_input_ids]

                # Get end_id from tokenizer
                end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
                print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
                # random_seed=42, repetition_penalty=1.1,
                outputs = runner.generate(
                    batch_input_ids=batch_input_ids,
                    max_new_tokens=2048,
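The `end_id` lookup above prefers a custom end-of-speech token and falls back to the tokenizer's default EOS. The same logic as a small helper, shown only as an equivalent rewrite of the expression in this hunk:

```python
def resolve_end_id(tokenizer, preferred="<|eos1|>"):
    """Return the id of the preferred end token if the vocab has it, else the default EOS id."""
    vocab = tokenizer.get_vocab()
    return vocab.get(preferred, tokenizer.eos_token_id)
```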
@@ -451,7 +424,6 @@ def main(args):
                    return_all_generated_tokens=False
                )
                torch.cuda.synchronize()
                # Extract output_ids from TensorRT-LLM output
                output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                num_output_sents, num_beams, _ = output_ids.size()
                assert num_beams == 1
@@ -463,18 +435,12 @@ def main(args):
                for i in range(batch_size * num_return_sequences):
                    batch_idx = i // num_return_sequences
                    seq_idx = i % num_return_sequences
                    # inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
                    # input_text = tokenizer.decode(inputs)
                    # print(f'Input [Text {batch_idx}]: \"{input_text}\"')
                    output_begin = input_lengths[batch_idx]
                    output_end = sequence_lengths[i][beam]
                    # outputs_i = output_ids[i][beam][output_begin:output_end].tolist()
                    outputs_i = output_ids[i][beam][:output_end].tolist()
                    outputs.append(outputs_i)
            elif args.backend == "vllm":
                input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
                # prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))]
                # print(prompts)
                sampling_params = SamplingParams(
                    temperature=args.temperature,
                    top_p=args.top_p,
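For the TensorRT-LLM branch above, the kept slice `output_ids[i][beam][:output_end]` returns prompt plus generation, and the prompt is stripped later via `outputs[i][input_length:]`. A sketch of the same bookkeeping, under the assumption (consistent with the loop above) that `output_ids` has shape (num_sequences, num_beams, max_len) and `sequence_lengths` has shape (num_sequences, num_beams):

```python
def split_trtllm_outputs(output_ids, sequence_lengths, input_lengths, beam=0):
    """Split each TensorRT-LLM output sequence into (prompt_tokens, generated_tokens)."""
    results = []
    for i, prompt_len in enumerate(input_lengths):
        end = sequence_lengths[i][beam]
        full = output_ids[i][beam][:end].tolist()  # prompt + generated tokens
        results.append((full[:prompt_len], full[prompt_len:]))
    return results
```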
@@ -483,26 +449,21 @@ def main(args):
                    max_tokens=2048,
                )
                outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
                # outputs = model.generate(prompts, sampling_params)
                print(outputs)
                # breakpoint()
                for j, output in enumerate(outputs):
                    outputs[j] = input_ids_list[j] + output.outputs[0].token_ids

            llm_end_time = time.time()
            total_llm_time += (llm_end_time - llm_start_time)

-            items_for_token2wav = []
+            items_for_token_2wav = []
            for i in range(len(batch["ids"])):
                llm_post_processing_start_time = time.time()
                # Extract generated tokens (excluding input)
                input_length = len(batch["input_ids"][i])
-                generated_ids = outputs[i][input_length:] # Remove last token if needed
+                generated_ids = outputs[i][input_length:]
                speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                # Extract speech IDs from token strings like <|s_23456|>
                speech_ids = extract_speech_ids(speech_tokens_str)
                print(i, speech_ids)
                # breakpoint()
                if len(speech_ids) == 0:
                    print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                    continue
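The vLLM branch prepends the prompt ids because HF's `generate()` (and the TensorRT-LLM slice kept above) return prompt plus generation, while vLLM's `RequestOutput.outputs[0].token_ids` contains only the generated tokens; prepending makes the later `outputs[i][input_length:]` slice work uniformly across backends. A minimal sketch of that normalization, using the documented `RequestOutput` fields:

```python
def normalize_vllm_outputs(request_outputs, input_ids_list):
    """Prepend each prompt to its vLLM generation so all backends yield full sequences."""
    full_sequences = []
    for prompt_ids, request_output in zip(input_ids_list, request_outputs):
        generated = list(request_output.outputs[0].token_ids)
        full_sequences.append(prompt_ids + generated)
    return full_sequences
```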
@@ -517,7 +478,7 @@ def main(args):
                llm_post_processing_end_time = time.time()
                total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                if current_prompt_audio is not None:
-                    items_for_token2wav.append({
+                    items_for_token_2wav.append({
                        "speech_ids": speech_ids,
                        "prompt_audio": current_prompt_audio.squeeze(0),
                        "id": batch["ids"][i]
@@ -525,8 +486,8 @@ def main(args):
                else:
                    print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")

-            for i in range(0, len(items_for_token2wav), args.token2wav_batch_size):
-                t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size]
+            for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
+                t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
                if not t2w_batch:
                    continue
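The renamed loop above is a plain fixed-size batching pattern over the collected items. The same idea as a reusable generator, shown only as an equivalent sketch:

```python
def chunks(items, size):
    """Yield successive batches of at most `size` items, like the token2wav loop above."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


# Usage sketch: for t2w_batch in chunks(items_for_token_2wav, args.token2wav_batch_size): ...
```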
@@ -535,7 +496,6 @@ def main(args):
                t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
                t2w_ids = [item["id"] for item in t2w_batch]

                # Generate audio using CosyVoice2
                token2wav_start_time = time.time()
                generated_wavs = token2wav_model(
                    t2w_generated_speech_tokens_list,
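The lines just above build the per-batch inputs for `token2wav_model` from the dicts collected earlier (speech IDs, prompt audio, sample ID); the 16000 sample rate matches the 16 kHz prompts prepared in the collator. A presumed assembly, with field names taken from this diff; the full `token2wav_model` signature is not shown here.

```python
def build_token2wav_inputs(t2w_batch):
    """Presumed assembly of one token2wav batch from the items collected above."""
    t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
    t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
    t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)  # prompts are 16 kHz after the collator
    t2w_ids = [item["id"] for item in t2w_batch]
    return t2w_generated_speech_tokens_list, t2w_prompt_audios_list, t2w_prompt_audios_sample_rate, t2w_ids
```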
@@ -547,7 +507,6 @@ def main(args):
                total_token2wav_time += (token2wav_end_time - token2wav_start_time)

                audio_save_start_time = time.time()
                # Convert to numpy and save
                for j, audio_hat in enumerate(generated_wavs):
                    generated_wave = audio_hat.squeeze().cpu().numpy()
                    total_audio_samples += len(generated_wave)
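The save step itself falls outside this hunk. A plausible sketch of writing each generated waveform with `soundfile`; the 24 kHz output rate and the filename pattern are assumptions, not taken from the diff.

```python
import os

import soundfile as sf


def save_wave(generated_wave, sample_id, output_dir, sample_rate=24000):
    """Write one generated waveform to <output_dir>/<sample_id>.wav (rate is an assumption)."""
    os.makedirs(output_dir, exist_ok=True)
    sf.write(os.path.join(output_dir, f"{sample_id}.wav"), generated_wave, sample_rate)
```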
@@ -571,7 +530,6 @@ def main(args):

    log_file_path = os.path.join(args.output_dir, "log.txt")
    with open(log_file_path, 'w') as f:
        # Convert Namespace to dict for JSON serialization
        args_dict = vars(args)
        log_data = {
            "args": args_dict,
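A sketch of how such a run log might be completed and serialized; the timing field names here are assumptions based on the counters accumulated earlier in `main()`, not the script's actual keys.

```python
import json


def write_log(log_file_path, args, total_llm_time, total_token2wav_time):
    """Write vars(args) plus timing counters as JSON (field names are assumptions)."""
    log_data = {
        "args": vars(args),
        "total_llm_time_seconds": total_llm_time,
        "total_token2wav_time_seconds": total_token2wav_time,
    }
    with open(log_file_path, "w") as f:
        json.dump(log_data, f, indent=2, default=str)
```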
@@ -602,4 +560,4 @@ if __name__ == "__main__":
        from transformers import AutoModelForCausalLM
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
-    main(args)
+    main(args)