diff --git a/examples/grpo/cosyvoice2/huggingface_to_pretrained.py b/examples/grpo/cosyvoice2/huggingface_to_pretrained.py
index 60692d9..ca49fc3 100644
--- a/examples/grpo/cosyvoice2/huggingface_to_pretrained.py
+++ b/examples/grpo/cosyvoice2/huggingface_to_pretrained.py
@@ -21,6 +21,7 @@ import torch
 from safetensors import safe_open
 from transformers import AutoTokenizer
 
+
 def get_args():
     parser = ArgumentParser()
@@ -39,6 +40,7 @@ def get_args():
     args = parser.parse_args()
     return args
 
+
 if __name__ == "__main__":
     args = get_args()
@@ -67,4 +69,3 @@ if __name__ == "__main__":
     hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
 
     torch.save(hf_tensors, args.output_path)
-
diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py
index 40c968d..4dcbc96 100644
--- a/examples/grpo/cosyvoice2/infer_dataset.py
+++ b/examples/grpo/cosyvoice2/infer_dataset.py
@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
@@ -182,14 +183,13 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
         for i in range(len(codes)):
             prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             for input_ids in input_ids_list
         ]
         input_ids = torch.stack(input_ids_list)
-    
+
     ids = [item["id"] for item in batch]
 
     return {
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
-    
+
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
-            
+
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
-            
+
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-                
+
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
-                
+
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
-                
+
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
-                
+
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
-                
+
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
-                    
+
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
-                    
+
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
 
-
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))
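Reviewer note: the two helpers touched above (`extract_speech_ids` and `convert_cosy2_tokens_to_speech_id_str`) are inverses over the `<|s_N|>` token format. A minimal standalone sketch of that round trip, assuming the token-string form shown in the docstrings (the regex here is illustrative, not the repo's exact implementation):

```python
import re
from typing import List

def ids_to_speech_id_str(speech_ids: List[int]) -> str:
    # Serialize integer codec IDs into the <|s_N|> form used in prompts.
    return "".join(f"<|s_{i}|>" for i in speech_ids)

def speech_id_strs_to_ids(token_strs: List[str]) -> List[int]:
    # Recover integer IDs from decoded token strings, skipping anything else.
    ids = []
    for t in token_strs:
        m = re.fullmatch(r"<\|s_(\d+)\|>", t)
        if m:
            ids.append(int(m.group(1)))
        else:
            print(f"Unexpected token: {t}")
    return ids

assert speech_id_strs_to_ids(["<|s_23456|>", "<|eos|>"]) == [23456]
assert ids_to_speech_id_str([1, 2]) == "<|s_1|><|s_2|>"
```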
diff --git a/examples/grpo/cosyvoice2/prepare_data.py b/examples/grpo/cosyvoice2/prepare_data.py
index e63ae47..46c3c09 100644
--- a/examples/grpo/cosyvoice2/prepare_data.py
+++ b/examples/grpo/cosyvoice2/prepare_data.py
@@ -23,8 +23,6 @@ import datasets
 
 from verl.utils.hdfs_io import copy, makedirs
 
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
diff --git a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py
index 5034f56..e2a9962 100644
--- a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py
+++ b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py
@@ -31,7 +31,6 @@ import torch
 
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
-
 def get_args():
     parser = ArgumentParser()
@@ -96,17 +95,20 @@ if __name__ == "__main__":
     # set the weight and bias of the new lm_head to 0
     new_lm_head.weight.data.zero_()
     new_lm_head.bias.data.zero_()
-    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
-    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
+    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
+    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
 
     llm.lm_head = new_lm_head
     input_embeddings = llm.get_input_embeddings()
 
     with torch.no_grad():
-        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
-        input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
 
-    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
     llm.generation_config.eos_token_id = eos_token_ids
     llm.generation_config.temperature = 1.0
     llm.generation_config.top_p = 0.8
@@ -121,4 +123,4 @@ if __name__ == "__main__":
     TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
     tokenizer.chat_template = TEMPLATE
 
-    tokenizer.save_pretrained(args.save_path)
\ No newline at end of file
+    tokenizer.save_pretrained(args.save_path)
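Reviewer note: the slice arithmetic in `pretrained_to_huggingface.py` encodes a fixed vocabulary layout: `cosyvoice2_token_size + 3` speech rows (codec tokens plus three specials) appended after the original text vocab, then two rows copied from `llm_embedding`. A toy sketch with invented sizes, just to make the offsets visible:

```python
import torch

orig_vocab = 10      # stand-in for original_tokenizer_vocab_size
n_speech = 6         # stand-in for cosyvoice2_token_size
hidden = 4           # toy hidden size

new_vocab = orig_vocab + n_speech + 3 + 2
emb = torch.zeros(new_vocab, hidden)
speech_embedding = torch.randn(n_speech + 3, hidden)  # codec tokens + 3 specials
llm_embedding = torch.randn(2, hidden)                # two extra task rows

# Same slice layout as the patch, with the verbose names shortened.
emb[orig_vocab:orig_vocab + n_speech + 3] = speech_embedding
emb[orig_vocab + n_speech + 3:orig_vocab + n_speech + 3 + 2] = llm_embedding

# The patch widens eos_token_id from three candidates to four, ending at +3.
eos_token_ids = [orig_vocab + n_speech + i for i in range(4)]
```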
diff --git a/examples/grpo/cosyvoice2/reward_tts.py b/examples/grpo/cosyvoice2/reward_tts.py
index f49dc6d..4c40761 100644
--- a/examples/grpo/cosyvoice2/reward_tts.py
+++ b/examples/grpo/cosyvoice2/reward_tts.py
@@ -18,7 +18,10 @@ Reward calculation for CosyVoice2-0.5B.
 
 from __future__ import annotations
 
-import os, re, warnings, json, time, argparse
+import re
+import json
+import time
+import argparse
 from typing import List
 
 import numpy as np
@@ -31,6 +34,7 @@ REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
 
 def _parse_ids(token_str: str) -> List[int]:
     return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
 
+
 def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
     """Send token IDs and ground-truth text to the Triton server and get reward."""
@@ -100,7 +104,6 @@ def compute_score(
     try:
         reward = _remote_reward(ids, ground_truth)
     except Exception as e:
-        warnings.warn(f"Remote reward server error: {e}; returning 0.0")
         reward = 0.0
 
     if debug_dump:
@@ -110,46 +113,46 @@ def compute_score(
 
     return reward
 
+
 # CLI quick test
 if __name__ == "__main__":
     import sys
-    
+
     def get_args():
         """Parse command line arguments."""
         parser = argparse.ArgumentParser(
             description="Test TTS CER scoring with data from JSONL file",
             formatter_class=argparse.ArgumentDefaultsHelpFormatter
         )
-    
+
         parser.add_argument(
             "--input", "-i",
             type=str,
             default="data/emilia_zh-cosy-tiny-test.jsonl",
             help="Path to input JSONL file"
        )
-    
+
         parser.add_argument(
             "--max-samples", "-n",
             type=int,
             default=None,
             help="Maximum number of samples to process (default: all)"
         )
-    
+
         parser.add_argument(
             "--no-interactive",
             action="store_true",
             help="Run in non-interactive mode (process all samples without prompts)"
         )
-    
-    
+
         parser.add_argument(
             "--debug",
             action="store_true",
             help="Enable debug mode"
         )
-    
+
         return parser.parse_args()
-    
+
     def load_jsonl(file_path: str):
         """Load data from jsonl file."""
         data = []
@@ -157,37 +160,37 @@ if __name__ == "__main__":
         for line in f:
             data.append(json.loads(line.strip()))
         return data
-    
+
     def code_to_solution_str(code_list: List[int]) -> str:
         """Convert code list to solution string format."""
         return ''.join([f"<|s_{code}|>" for code in code_list])
-    
+
     # Parse command line arguments
     args = get_args()
-    
+
     try:
         # Load data from jsonl file
         print(f"Loading data from: {args.input}")
         data_list = load_jsonl(args.input)
         print(f"Loaded {len(data_list)} samples")
-    
+
         # Limit samples if specified
         if args.max_samples is not None:
             data_list = data_list[:args.max_samples]
             print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
-    
+
         # Process each sample
         begin_time = time.time()
         for i, sample in enumerate(data_list):
             print(f"\n--- Sample {i+1}/{len(data_list)} ---")
             print(f"Index: {sample.get('index', 'unknown')}")
             print(f"Text: {sample['text']}")
-    
+
             # Extract required fields
             code_list = sample['code']
             ground_truth = sample['text']
             data_source = sample.get('index', f'sample_{i}')  # Use index as data_source
-    
+
             # Convert code list to solution string
             solution_str = code_to_solution_str(code_list)
             print(f"Solution tokens: {len(code_list)} tokens")
@@ -195,7 +198,7 @@ if __name__ == "__main__":
                 print(f"Solution string: {solution_str}")
             else:
                 print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
-            
+
             # Call compute_score function
             try:
                 score = compute_score(
@@ -208,7 +211,7 @@ if __name__ == "__main__":
                 print(f"Final Score: {score:.4f}")
             except Exception as e:
                 print(f"Error computing score: {e}")
-            
+
             # Ask user if they want to continue (for interactive mode)
             if not args.no_interactive and i < len(data_list) - 1:
                 try:
@@ -218,7 +221,7 @@ if __name__ == "__main__":
                 except KeyboardInterrupt:
                     print("\nStopped by user")
                     break
-    
+
         print(f"\nProcessed {min(i+1, len(data_list))} samples")
         end_time = time.time()
         print(f"Time taken: {end_time - begin_time} seconds")
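Reviewer note: `compute_score` receives the solution as one concatenated `<|s_N|>` string, and `_parse_ids` (unchanged above) strips it back to integers before calling the reward server; with the `warnings.warn` removal, any server error now silently yields a 0.0 reward. A quick parse-only check using the exact regex from `reward_tts.py`:

```python
import re
from typing import List

def _parse_ids(token_str: str) -> List[int]:
    # Same pattern as reward_tts.py: capture N from each <|s_N|> token.
    return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]

solution_str = "<|s_12|><|s_345|><|s_6789|>"
assert _parse_ids(solution_str) == [12, 345, 6789]
# compute_score(...) would forward these IDs to the Triton server at
# REWARD_SERVER_URL and fall back to 0.0 on any exception.
```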
diff --git a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py
index 35fc03d..847d434 100644
--- a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py
+++ b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py
@@ -102,6 +102,7 @@ import string
 punctuation_all = punctuation + string.punctuation
 Pathlike = Union[str, Path]
 
+
 def remove_punctuation(text: str) -> str:
     for x in punctuation_all:
         if x == '\'':
@@ -109,6 +110,7 @@ def remove_punctuation(text: str) -> str:
         text = text.replace(x, '')
     return text
 
+
 def store_transcripts(
     filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
 ) -> None:
@@ -304,6 +306,7 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -533,7 +536,7 @@ def get_args():
         default=None,
         help="wav_base_name label",
     )
-    
+
     # Dataset related arguments for loading labels when label file is not provided
     parser.add_argument(
         "--dataset-name",
@@ -541,14 +544,14 @@ def get_args():
         default="yuekai/seed_tts_cosy2",
         help="Huggingface dataset name for loading labels",
     )
-    
+
     parser.add_argument(
         "--split-name",
         type=str,
         default="wenetspeech4tts",
         help="Dataset split name for loading labels",
     )
-    
+
     return parser.parse_args()
 
@@ -590,7 +593,7 @@ def normalize_text_alimeeting(text: str) -> str:
     See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
     """
     import re
-    text = text.replace('\u00A0', '') # test_hard
+    text = text.replace('\u00A0', '')  # test_hard
     text = text.replace(" ", "")
     text = text.replace("", "")
     text = text.replace("<%>", "")
@@ -685,10 +688,10 @@ def main():
     print(
         f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
     )
-    
+
     # Load labels either from file or from dataset
     labels_dict = {}
-    
+
     if args.label:
         # Load labels from file (original functionality)
         print(f"Loading labels from file: {args.label}")
@@ -716,11 +719,11 @@ def main():
         dataset = load_dataset(
             args.dataset_name,
             split=args.split_name,
             trust_remote_code=True,
         )
-        
+
         for item in dataset:
             audio_id = item["id"]
             labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
-        
+
         print(f"Loaded {len(labels_dict)} labels from dataset")
 
     # Perform evaluation if labels are available
@@ -750,4 +753,4 @@ def main():
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
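Reviewer note: the new fallback path loads references from a Hugging Face dataset when `--label` is not given. A condensed sketch of that lookup, using the argparse defaults above (normalization omitted; the script runs `normalize_text_alimeeting` on each label, and the `id`/`target_text` field names are taken from the diff):

```python
from datasets import load_dataset

# Defaults taken from --dataset-name / --split-name above.
dataset = load_dataset(
    "yuekai/seed_tts_cosy2",
    split="wenetspeech4tts",
    trust_remote_code=True,
)

# Map utterance id -> reference text for the later CER/WER evaluation.
labels_dict = {item["id"]: item["target_text"] for item in dataset}
print(f"Loaded {len(labels_dict)} labels from dataset")
```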
diff --git a/examples/grpo/cosyvoice2/token2wav_asr_server.py b/examples/grpo/cosyvoice2/token2wav_asr_server.py
index 1273c18..8a6cb6e 100644
--- a/examples/grpo/cosyvoice2/token2wav_asr_server.py
+++ b/examples/grpo/cosyvoice2/token2wav_asr_server.py
@@ -14,6 +14,13 @@
 # limitations under the License.
 """Pytriton server for token2wav conversion and ASR"""
 
+from datasets import load_dataset
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from omnisense.models import OmniSenseVoiceSmall
+from pytriton.proxy.types import Request
+from pytriton.triton import Triton, TritonConfig
+from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
+from pytriton.decorators import batch
 import argparse
 import io
 import logging
@@ -37,15 +44,6 @@ zh_tn_model = ZhNormalizer(
     overwrite_cache=True,
 )
 
-from pytriton.decorators import batch
-from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
-from pytriton.triton import Triton, TritonConfig
-from pytriton.proxy.types import Request
-
-from omnisense.models import OmniSenseVoiceSmall
-from cosyvoice.cli.cosyvoice import CosyVoice2
-
-from datasets import load_dataset
 
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
@@ -78,7 +76,6 @@ class _ASR_Server:
 
         return {"TRANSCRIPTS": transcripts}
 
-
 def audio_decode_cosyvoice2(
     audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
 ):
@@ -123,12 +120,12 @@ def get_random_prompt_from_dataset(dataset):
     """
     random_idx = random.randint(0, len(dataset) - 1)
     sample = dataset[random_idx]
-    
+
     # Extract audio data
     audio_data = sample["audio"]
     audio_array = audio_data["array"]
     sample_rate = audio_data["sampling_rate"]
-    
+
     # Convert audio to 16kHz if needed
     if sample_rate != 16000:
         num_samples = int(len(audio_array) * (16000 / sample_rate))
@@ -141,6 +138,7 @@ def get_random_prompt_from_dataset(dataset):
     prompt_text = prompt_text.replace(" ", "")
     return prompt_text, prompt_speech_16k
 
+
 class _Token2Wav_ASR:
     """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
 
@@ -163,6 +161,7 @@ class _Token2Wav_ASR:
         self.codec_decoder = CosyVoice2(
             "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
         )
+
     @batch
     def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
         """
@@ -236,7 +235,6 @@ class _Token2Wav_ASR:
         transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
         rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
 
-
         return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
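Reviewer note: the server exposes `TOKENS`, `TOKEN_LENS`, and `GT_TEXT` inputs and returns `REWARDS` plus `TRANSCRIPTS`. The body of `_remote_reward` is not shown in this diff, so the client below is only a sketch of a generic KServe-v2 HTTP request against `REWARD_SERVER_URL`; the datatypes and payload shape are assumptions and may differ from what the repo actually sends:

```python
import requests

URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
tokens = [1234, 5678, 910]  # hypothetical codec IDs

payload = {
    "inputs": [
        # Datatypes assumed; check the server's Tensor specs before relying on this.
        {"name": "TOKENS", "shape": [1, len(tokens)], "datatype": "INT32", "data": tokens},
        {"name": "TOKEN_LENS", "shape": [1, 1], "datatype": "INT32", "data": [len(tokens)]},
        {"name": "GT_TEXT", "shape": [1, 1], "datatype": "BYTES", "data": ["ground truth text"]},
    ]
}

resp = requests.post(URL, json=payload, timeout=200.0)
resp.raise_for_status()
# Look outputs up by name rather than position.
outputs = {o["name"]: o for o in resp.json()["outputs"]}
reward = outputs["REWARDS"]["data"][0]
```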