Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
fix lint
@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
@@ -182,14 +183,13 @@ def get_args():
     return args
 
-
 
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
         for i in range(len(codes)):
             prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             for input_ids in input_ids_list
         ]
         input_ids = torch.stack(input_ids_list)
-
+
     ids = [item["id"] for item in batch]
 
     return {
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
-
+
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
-
+
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
-
+
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
+
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
-
+
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
-
+
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
-
+
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
-
+
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
-
+
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
-
+
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
 
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
-
 
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))
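For orientation, here is a minimal sketch of the <|s_NNNN|> round trip that the two helpers touched by this diff appear to implement, inferred only from the docstring and the "Unexpected token" branch shown above; the function bodies below are illustrative, not the repository's exact code.

# Illustrative only: reconstructed from the docstring and the visible diff
# context, not copied from the repository.
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Render integer CosyVoice2 codec tokens as a string like <|s_23456|>."""
    speech_id_str = ""
    for token in cosy2_tokens:
        speech_id_str += f"<|s_{int(token)}|>"
    return speech_id_str


def extract_speech_ids(speech_tokens_str):
    """Recover integer IDs from decoded token strings, skipping anything else."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids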
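The batch_size=1 shortcut in data_collator (the hunks around original lines 244-256) can likewise be read as the padding sketch below; the pad side and pad value are assumptions, since only the single-sample branch and the final torch.stack are visible in the diff.

# Illustrative sketch, assuming left padding with the tokenizer's pad id.
import torch


def pad_and_stack(input_ids_list, pad_token_id):
    if len(input_ids_list) == 1:
        # For batch_size=1, no need to pad
        return input_ids_list[0].unsqueeze(0)
    max_len = max(ids.shape[0] for ids in input_ids_list)
    input_ids_list = [
        torch.cat([torch.full((max_len - ids.shape[0],), pad_token_id, dtype=ids.dtype), ids])
        for ids in input_ids_list
    ]
    return torch.stack(input_ids_list)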