root
2025-07-29 08:40:51 +00:00
parent d1c354eac7
commit 62d082634e
7 changed files with 71 additions and 68 deletions

@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
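Both helpers map between integer CosyVoice2 speech-token ids and the literal <|s_N|> strings the LLM sees; their bodies fall mostly outside these hunks, so the following is only a minimal sketch of that round trip, not the file's exact code:

def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    # integer id N -> the literal text "<|s_N|>"
    return "".join(f"<|s_{int(t)}|>" for t in cosy2_tokens)

def extract_speech_ids(speech_tokens_str):
    # literal text "<|s_N|>" -> integer id N, warning on anything malformed
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids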
@@ -182,14 +183,13 @@ def get_args():
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
mels, prompt_audio_cosy2tokens_list = [], []
- for i, item in enumerate(batch):
+ for item in batch:
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
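# mels / mels_lens above are the batched, padded mel features built earlier in the collator
# (outside these hunks); a hypothetical way to assemble them from per-utterance (T, n_mels)
# tensors, shown only as an illustration:
#     mels_lens = torch.tensor([m.shape[0] for m in mel_list])
#     mels = torch.nn.utils.rnn.pad_sequence(mel_list, batch_first=True)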
- for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+ for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
)
input_ids_list.append(input_ids.squeeze(0))
# For batch_size=1, no need to pad
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
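The tokenization call that turns the chat above into input_ids is split across the hunk boundary; a hedged illustration of the usual transformers pattern (the message roles and contents are placeholders, not the file's actual prompt format):

chat = [
    {"role": "user", "content": prompt_text + target_text},
    {"role": "assistant", "content": prompt_audio_cosy2_id_str},
]
input_ids = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_tensors="pt",
)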
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list)
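# The comprehension above (partially outside the hunk) presumably pads every sequence to the
# longest one before stacking; left padding is the usual choice for decoder-only generation.
# A hypothetical equivalent, assuming tokenizer.pad_token_id is set:
#     max_len = max(ids.shape[0] for ids in input_ids_list)
#     input_ids_list = [
#         torch.nn.functional.pad(ids, (max_len - ids.shape[0], 0), value=tokenizer.pad_token_id)
#         for ids in input_ids_list
#     ]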
ids = [item["id"] for item in batch]
return {
@@ -287,7 +286,7 @@ def main():
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
# Load LLM model and tokenizer directly
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
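init_distributed itself is defined elsewhere in the repo; a minimal sketch of what such a helper typically does under torchrun, given purely as an assumption:

import os

def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl")
    return world_size, local_rank, rank

The loaded model is then presumably moved to that device (model.to(device); model.eval()) before the generation loop below.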
@@ -329,7 +328,7 @@ def main():
for batch in dataloader:
with torch.no_grad():
input_ids = batch["input_ids"].to(device)
# Generate speech tokens using LLM
outputs = model.generate(
input_ids,
@@ -339,31 +338,31 @@ def main():
temperature=args.temperature,
top_k=args.top_k,
)
# Process each sample in the batch
for i in range(len(batch["ids"])):
# Extract generated tokens (excluding input)
input_length = input_ids[i].shape[0]
generated_ids = outputs[i][input_length:-1]  # strip the prompt; [:-1] also drops the trailing token (typically EOS)
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Extract speech IDs from token strings like <|s_23456|>
speech_ids = extract_speech_ids(speech_tokens_str)
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
# Convert to tensor for CosyVoice2
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
if current_prompt_audio is not None:
# Generate audio using CosyVoice2
audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
current_prompt_audio,
cosyvoice_codec,
)
# Convert to numpy and save
generated_wave = audio_hat.squeeze(0).cpu().numpy()
target_sample_rate = 24000
utt = batch["ids"][i]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
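The progress bar itself is created before the loop, outside these hunks; a rank-0-only setup consistent with the update call above would look roughly like this (the dataset name is an assumption):

from tqdm import tqdm

progress_bar = tqdm(total=len(dataset), desc="Generating audio") if rank == 0 else None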