From 8811e9f33a5e7a14ad308f821b967f394e72bdcc Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 9 Oct 2025 14:49:22 +0800
Subject: [PATCH] fix whitespace

---
 examples/grpo/cosyvoice2/README.md           |  4 +--
 examples/grpo/cosyvoice2/run.sh              |  6 ++--
 .../model_repo/cosyvoice2_dit/1/model.py     |  4 +--
 .../model_repo/token2wav_dit/1/model.py      |  6 ++--
 .../token2wav_dit/1/token2wav_dit.py         | 29 ++++++++-----------
 runtime/triton_trtllm/offline_inference.py   |  1 -
 runtime/triton_trtllm/streaming_inference.py | 10 +++----
 7 files changed, 27 insertions(+), 33 deletions(-)

diff --git a/examples/grpo/cosyvoice2/README.md b/examples/grpo/cosyvoice2/README.md
index 8783aa1..1f5c6a0 100644
--- a/examples/grpo/cosyvoice2/README.md
+++ b/examples/grpo/cosyvoice2/README.md
@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
 ```bash
 bash run.sh 0 0
 ```
-Create two JSONL files—`train.jsonl` and `test.jsonl`. 
+Create two JSONL files—`train.jsonl` and `test.jsonl`.
 
 The script will then generate two Parquet files:
 
 ```
@@ -111,7 +111,7 @@ bash run.sh 5 5
 The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
 
 > [!TIP]
-> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. 
+> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
 
 ## Results
 
diff --git a/examples/grpo/cosyvoice2/run.sh b/examples/grpo/cosyvoice2/run.sh
index ce97ab3..b1658e2 100644
--- a/examples/grpo/cosyvoice2/run.sh
+++ b/examples/grpo/cosyvoice2/run.sh
@@ -33,7 +33,7 @@ fi
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
-  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path 
+  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
   python3 pretrained_to_huggingface.py \
     --pretrained-cosyvoice2-path $model_scope_model_path \
     --save-path $sft_model_path
@@ -61,7 +61,7 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: start token2wav asr server for reward function"
   python3 token2wav_asr_server.py --number-of-devices 8
-fi 
+fi
 
 exp_name=official_llm_aishell3_grpo
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --backend fsdp \
     --local_dir $llm_path/actor \
     --target_dir $llm_path/merged_hf_model || exit 1
-fi 
+fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "stage 4: Test the model"
diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
index 523a5b8..8e2b28b 100644
--- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
@@ -254,7 +254,7 @@ class TritonPythonModel:
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
-        
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
@@ -362,7 +362,7 @@ class TritonPythonModel:
                         chunk_index += 1
                 else:
                     break
-            
+
             this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
             sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
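Note: the final hunk above is the end-of-stream path, where every token still waiting past `token_offset` is flushed as one last chunk with `finalize=True`. A minimal sketch of that flush logic follows; the wrapper function and its argument plumbing are hypothetical, only the tensor construction and the `forward_token2wav` call order mirror the hunk.

```python
# Sketch of the finalize flush shown in the cosyvoice2_dit hunk above.
# `semantic_token_ids_arr`, `token_offset`, and `forward_token2wav` mirror the
# names in model.py; the wrapper itself is a hypothetical illustration.
import torch

async def flush_last_chunk(semantic_token_ids_arr, token_offset, forward_token2wav,
                           chunk_index, request_id, wav, wav_len, device="cuda"):
    # Everything past token_offset has not been synthesized yet; send it as the
    # final chunk with finalize=True so the vocoder can drain its caches.
    remaining = semantic_token_ids_arr[token_offset:]
    this_tts_speech_token = (
        torch.tensor(remaining).unsqueeze(dim=0).to(torch.int32).to(device)
    )
    return await forward_token2wav(
        chunk_index, this_tts_speech_token, request_id, wav, wav_len, True
    )
```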
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
index 1f6b591..1f90644 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
@@ -62,7 +62,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
     # Create a SHA-256 hash of the byte string
     hasher = hashlib.sha256()
     hasher.update(tensor_bytes)
-    
+
     return hasher.hexdigest()
 
 class TritonPythonModel:
@@ -111,9 +111,9 @@ class TritonPythonModel:
             target_speech_tokens = target_speech_tokens.squeeze().tolist()
 
             finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
-            
+
             request_id = request.request_id()
-            
+
             wav_array = pb_utils.get_input_tensor_by_name(
                 request, "reference_wav").as_numpy()
 
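Note: `get_spk_id_from_prompt_audio` above derives a deterministic speaker cache key by hashing the raw prompt-audio bytes, so identical prompts map to the same cached speaker entry. A self-contained sketch, assuming `tensor_bytes` comes from a CPU round-trip of the tensor (the hunk does not show that line):

```python
# Sketch of the speaker-id derivation in the token2wav_dit model above.
import hashlib
import torch

def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
    # Assumption: a stable byte view of the audio via a CPU/numpy round-trip.
    tensor_bytes = tensor.cpu().numpy().tobytes()
    # Create a SHA-256 hash of the byte string (as in the hunk).
    hasher = hashlib.sha256()
    hasher.update(tensor_bytes)
    return hasher.hexdigest()  # hex digest doubles as the speaker cache key
```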
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
index 3d50325..d413003 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
@@ -133,7 +133,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             option.intra_op_num_threads = 1
             self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
 
-
         self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
 
         gpu="l20"
@@ -253,7 +252,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
             prompt_speech_tokens_list.append(speech_tokens_i)
         return prompt_speech_tokens_list
-    
+
     def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
         spk_emb_for_flow = []
         for audio in prompt_audios_list:
@@ -263,11 +262,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb = self.forward_spk_embedding(spk_feat)
             spk_emb_for_flow.append(spk_emb)
 
-        spk_emb_for_flow = torch.tensor(spk_emb_for_flow) 
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
         if self.dtype != torch.float32:
             spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
         return spk_emb_for_flow
-    
+
     def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
         prompt_mels_for_flow = []
         prompt_mels_lens_for_flow = []
@@ -283,7 +282,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
         prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
         return prompt_mels_for_flow, prompt_mels_lens_for_flow
-    
+
     def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
         batch_size = prompt_mels_for_flow.shape[0]
         flow_inputs = []
@@ -318,28 +317,24 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
-
         prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
 
         generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
 
         generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
-
         return generated_wavs
 
     def prepare_prompt_audio(
         self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
-
         prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
 
         prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
 
         spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
-
         return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
@@ -365,7 +360,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     @torch.inference_mode()
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
-    ): 
+    ):
         if speaker_id not in self.speaker_cache:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -384,7 +379,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
-                mel = torch.zeros(1, 80, 0, device='cuda'), 
+                mel = torch.zeros(1, 80, 0, device='cuda'),
                 source = torch.zeros(1, 1, 0, device='cuda'),
                 speech = torch.zeros(1, 0, device='cuda'),
             )
@@ -445,7 +440,7 @@ def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     for i, item in enumerate(batch):
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
-        audio = torch.from_numpy(item['prompt_audio']['array']).float() 
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)
         prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
         ids.append(item['id'])
@@ -473,20 +468,20 @@ if __name__ == "__main__":
 
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
 
-    
-    
+
+
     for epoch in range(args.warmup):
         start_time = time.time()
-        
+
         for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
             generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
-            
+
             for id, wav in zip(ids, generated_wavs):
                 torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
-        
+
         end_time = time.time()
         epoch_time = end_time - start_time
         print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
\ No newline at end of file
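Note: the `forward_streaming` hunks above seed per-request state lazily on the first chunk: the flow cache is cloned from the shared speaker entry, and the HiFT vocoder cache starts with zero-length tensors that grow as chunks arrive. A sketch of that initialization, assuming `speaker_cache[speaker_id]['cache_dict']` holds tensors (other fields of the speaker entry are not shown in the diff):

```python
# Sketch of the per-request streaming state from forward_streaming above.
import torch

def init_request_state(streaming_flow_cache, hift_cache_dict, speaker_cache,
                       request_id: str, speaker_id: str, device: str = "cuda"):
    if request_id not in streaming_flow_cache:
        # Clone so one request's updates never mutate the shared speaker cache.
        streaming_flow_cache[request_id] = {
            k: v.clone() for k, v in speaker_cache[speaker_id]['cache_dict'].items()
        }
        hift_cache_dict[request_id] = dict(
            mel=torch.zeros(1, 80, 0, device=device),    # [B, n_mels, T=0]
            source=torch.zeros(1, 1, 0, device=device),  # excitation source tail
            speech=torch.zeros(1, 0, device=device),     # overlap-add waveform tail
        )
```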
diff --git a/runtime/triton_trtllm/offline_inference.py b/runtime/triton_trtllm/offline_inference.py
index d309d18..77f2915 100644
--- a/runtime/triton_trtllm/offline_inference.py
+++ b/runtime/triton_trtllm/offline_inference.py
@@ -365,7 +365,6 @@ def main(args):
         runner = None
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
-
     if 'Step-Audio-2-mini' in args.token2wav_path:
         from token2wav_dit import CosyVoice2_Token2Wav
     else:
diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py
index a5404e2..e9c2ebb 100644
--- a/runtime/triton_trtllm/streaming_inference.py
+++ b/runtime/triton_trtllm/streaming_inference.py
@@ -14,7 +14,7 @@ def collate_fn(batch):
     prompt_speech_tokens_list, prompt_text_list = [], []
     for i, item in enumerate(batch):
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
-        audio = torch.from_numpy(item['prompt_audio']['array']).float() 
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)
         prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
         ids.append(item['id'])
@@ -37,7 +37,7 @@ def get_args():
 
 if __name__ == "__main__":
     args = get_args()
-    
+
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
 
@@ -46,7 +46,7 @@ if __name__ == "__main__":
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
 
     token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
-    
+
     CHUNK_SIZE = 25
     token_frame_rate = 25
     OVERLAP_SIZE = 0
@@ -68,7 +68,7 @@ if __name__ == "__main__":
 
         semantic_token_ids_arr, token_offset = [], 0
         flow_prompt_speech_token_len = len(prompt_speech_tokens)
-        
+
        buffer = generated_speech_tokens
         output_wavs = []
         chunk_index = 0
@@ -97,7 +97,7 @@ if __name__ == "__main__":
 
             output_wavs[i] = wav.cpu().numpy().squeeze()
 
-        audios = output_wavs 
+        audios = output_wavs
         reconstructed_audio = np.concatenate(audios)
 
         sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
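Note: `streaming_inference.py` above drives the model chunk by chunk; with `CHUNK_SIZE = 25` and a 25 Hz token rate, each full chunk corresponds to roughly one second of audio, and the leftover tokens are flushed with `last_chunk=True` before the per-chunk wavs are concatenated and written out. A sketch of such a consumption loop, assuming the speaker was already registered with a prompt audio; only the `forward_streaming` call order and the final concatenation mirror the hunks, the loop body itself is illustrative:

```python
# Illustrative chunked decode loop in the spirit of streaming_inference.py.
import numpy as np

def synthesize_streaming(token2wav_model, generated_speech_tokens, request_id,
                         speaker_id, chunk_size=25):
    buffer, output_wavs = list(generated_speech_tokens), []
    while len(buffer) >= chunk_size:
        # Non-overlapping chunks (OVERLAP_SIZE = 0 in the script above).
        chunk, buffer = buffer[:chunk_size], buffer[chunk_size:]
        output_wavs.append(
            token2wav_model.forward_streaming(chunk, False, request_id, speaker_id)
        )
    # Final (possibly short) chunk lets the model drain its caches.
    output_wavs.append(
        token2wav_model.forward_streaming(buffer, True, request_id, speaker_id)
    )
    return np.concatenate([w.cpu().numpy().squeeze() for w in output_wavs])
```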