fix cache shallow copy

This commit is contained in:
root
2025-09-19 13:48:32 +08:00
parent b207c60885
commit 444b7ff5df
2 changed files with 22 additions and 8 deletions

View File

@@ -2,6 +2,9 @@
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=0
cosyvoice_path=/workspace/CosyVoice
cosyvoice_path=/workspace_yuekai/tts/CosyVoice
stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
stage=$1
@@ -140,3 +143,11 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
done
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
python3 benchmark_streaming_token2wav.py --enable-trt
fi

View File

@@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
spk_emb_for_flow.to(self.device),
n_timesteps=10
)
# cache dict's tensor batch dim is 1 for now
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
return cache
@@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
):
if speaker_id not in self.speaker_cache:
assert prompt_audio is not None, "prompt_audio is required for new speaker"
assert prompt_audio_sample_rate == 16000
@@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
mel = torch.zeros(1, 80, 0, device='cuda'),
source = torch.zeros(1, 1, 0, device='cuda'),
@@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
)
current_request_cache = self.streaming_flow_cache[request_id]
prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
cache=current_request_cache,
last_chunk=last_chunk,
n_timesteps=10,
@@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)