From 444b7ff5dfafb0c98eecab5e7db461f997843a48 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 19 Sep 2025 13:48:32 +0800
Subject: [PATCH] fix cache shallow copy

---
 .../run_stepaudio2_dit_token2wav.sh    | 11 +++++++++++
 runtime/triton_trtllm/token2wav_dit.py | 19 +++++++++++--------
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
index 7c7f3cd..c0034c2 100644
--- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
+++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh
@@ -2,6 +2,9 @@
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
 export CUDA_VISIBLE_DEVICES=0
 cosyvoice_path=/workspace/CosyVoice
+cosyvoice_path=/workspace_yuekai/tts/CosyVoice
+stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
+export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
 stage=$1
@@ -140,3 +143,11 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     done
   done
 fi
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+
+    python3 benchmark_streaming_token2wav.py --enable-trt
+
+
+fi
\ No newline at end of file
diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/token2wav_dit.py
index 69db946..fdc1a12 100644
--- a/runtime/triton_trtllm/token2wav_dit.py
+++ b/runtime/triton_trtllm/token2wav_dit.py
@@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb_for_flow.to(self.device),
             n_timesteps=10
         )
-
-        # cache dict's tensor batch dim is 1 for now
+        # Hack: clone to avoid in-place changes to cache['estimator_att_cache'] and cache['estimator_cnn_cache']
+        cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
+        cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()


         return cache
@@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
-
         if speaker_id not in self.speaker_cache:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}

         if request_id not in self.streaming_flow_cache:
-            self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
+            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
                 mel = torch.zeros(1, 80, 0, device='cuda'),
                 source = torch.zeros(1, 1, 0, device='cuda'),
@@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             )

         current_request_cache = self.streaming_flow_cache[request_id]
-        prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+
+        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']

         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+
         chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
             token=generated_speech_tokens,
-            spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
             cache=current_request_cache,
             last_chunk=last_chunk,
             n_timesteps=10,
@@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

         self.streaming_flow_cache[request_id] = new_streaming_flow_cache

-        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
             self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
-                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
             ], dim=4)
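
Reviewer note (not part of the patch): the core bug is that dict.copy() is a
shallow copy. The copied dict holds the *same* tensor objects as the
per-speaker template, so an in-place update made for one request also mutates
the template that every later request starts from. A minimal standalone
sketch of the failure and of the per-tensor clone() fix this commit adopts
(tensor shapes below are made up for illustration):

    import torch

    # Per-speaker template cache; shapes are illustrative, not the real ones.
    speaker_cache_dict = {
        'estimator_att_cache': torch.zeros(1, 2, 4, 8, 16),
        'estimator_cnn_cache': torch.zeros(1, 2, 4, 8),
    }

    # Buggy: dict.copy() duplicates the dict, not the tensors it holds.
    request_cache = speaker_cache_dict.copy()
    request_cache['estimator_att_cache'] += 1.0             # in-place write for one request
    print(speaker_cache_dict['estimator_att_cache'].sum())  # tensor(1024.): template corrupted

    # Fixed, as in this commit: clone each tensor so every request owns its cache.
    speaker_cache_dict['estimator_att_cache'].zero_()
    request_cache = {k: v.clone() for k, v in speaker_cache_dict.items()}
    request_cache['estimator_att_cache'] += 1.0
    print(speaker_cache_dict['estimator_att_cache'].sum())  # tensor(0.): template intact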
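
Reviewer note (not part of the patch): the final hunk bounds cache growth.
Once the attention cache's time axis (dim 4) exceeds the prompt length plus
100 frames, it is trimmed to the prompt prefix plus the 100 most recent
frames. A standalone sketch of that trim with hypothetical shapes (the real
cache layout comes from flow.inference_chunk):

    import torch

    prompt_len = 50                            # stands in for prompt_mels_for_flow.shape[1]
    att_cache = torch.randn(4, 1, 8, 64, 180)  # last dim is time; 180 > 50 + 100

    if att_cache.shape[4] > prompt_len + 100:
        att_cache = torch.cat([
            att_cache[:, :, :, :, :prompt_len],  # always keep the prompt frames
            att_cache[:, :, :, :, -100:],        # plus a 100-frame sliding window
        ], dim=4)

    print(att_cache.shape)  # torch.Size([4, 1, 8, 64, 150])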