mirror of https://github.com/FunAudioLLM/CosyVoice.git
fix cache shallow copy
@@ -2,6 +2,9 @@
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
 export CUDA_VISIBLE_DEVICES=0
-cosyvoice_path=/workspace/CosyVoice
+cosyvoice_path=/workspace_yuekai/tts/CosyVoice
+
+stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
+export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
 stage=$1
@@ -140,3 +143,11 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
         done
     done
 fi
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+
+    python3 benchmark_streaming_token2wav.py --enable-trt
+
+
+fi
@@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb_for_flow.to(self.device),
             n_timesteps=10
         )
 
         # cache dict's tensor batch dim is 1 for now
+        # Hack: clone to avoid in-place changes to cache['estimator_att_cache'] and cache['estimator_cnn_cache']
         cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
         cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
         return cache
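Why the clone matters here (a minimal standalone sketch, not the repo's code; shapes are invented): the returned cache tensors are reused across requests, so any later in-place write through an un-cloned reference would silently corrupt the shared prompt cache.

import torch

# Toy stand-in for the per-speaker cache; only alias-vs-clone behavior matters.
speaker_cache = {'estimator_att_cache': torch.zeros(1, 4, 2, 8, 16)}

alias = speaker_cache['estimator_att_cache']        # same underlying storage
alias[..., -1] = 1.0                                # in-place write through the alias
print(speaker_cache['estimator_att_cache'].sum())   # tensor(64.) -- shared cache corrupted

speaker_cache['estimator_att_cache'].zero_()
private = speaker_cache['estimator_att_cache'].clone()   # what the hunk above does
private[..., -1] = 1.0
print(speaker_cache['estimator_att_cache'].sum())   # tensor(0.) -- shared cache intact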
@@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
-
         if speaker_id not in self.speaker_cache:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
 
         if request_id not in self.streaming_flow_cache:
-            self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
+            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
                 mel = torch.zeros(1, 80, 0, device='cuda'),
                 source = torch.zeros(1, 1, 0, device='cuda'),
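This hunk is the actual shallow-copy fix: dict.copy() duplicates only the dict, not the tensors inside it, so every request_id still pointed at the same tensor objects as the speaker-level cache, and one stream's updates leaked into all others. A minimal sketch of the difference, with made-up keys and shapes:

import torch

cache_dict = {'estimator_att_cache': torch.zeros(1, 2, 3)}

shallow = cache_dict.copy()                      # new dict, same tensor objects
shallow['estimator_att_cache'] += 1              # in-place update by one request
print(cache_dict['estimator_att_cache'].sum())   # tensor(6.) -- leaked into the shared cache

cache_dict['estimator_att_cache'].zero_()
deep = {k: v.clone() for k, v in cache_dict.items()}   # the fix: private tensor per request
deep['estimator_att_cache'] += 1
print(cache_dict['estimator_att_cache'].sum())   # tensor(0.) -- shared cache untouched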
@@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
         )
 
         current_request_cache = self.streaming_flow_cache[request_id]
-        prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+
+        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+
 
         chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
             token=generated_speech_tokens,
-            spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
             cache=current_request_cache,
             last_chunk=last_chunk,
             n_timesteps=10,
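For context, a hedged sketch of how a caller might drive forward_streaming chunk by chunk; model, tokens, and prompt_wav are assumed to exist, and the chunk size is invented. Only the signature comes from this diff.

# Hypothetical driver; a chunk size of 25 tokens is illustrative only.
chunks = [tokens[i:i + 25] for i in range(0, len(tokens), 25)]
for i, chunk in enumerate(chunks):
    audio = model.forward_streaming(
        generated_speech_tokens=chunk,
        last_chunk=(i == len(chunks) - 1),
        request_id='req-0',
        speaker_id='spk-0',
        prompt_audio=prompt_wav,             # only required the first time 'spk-0' is seen
        prompt_audio_sample_rate=16000,
    )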
@@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
 
-        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
             self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
-                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
             ], dim=4)
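The hunk above keeps the attention cache bounded: once the time axis (dim 4) of estimator_att_cache grows past prompt length + 100 frames, only the prompt prefix plus the most recent 100 frames are kept. A standalone sketch of the same trimming rule; the 5-D layout is modeled on the diff, the sizes are invented:

import torch

prompt_len, window = 6, 100
# Toy 5-D cache whose time axis (dim 4) has grown past the limit.
att_cache = torch.randn(1, 2, 4, 8, prompt_len + window + 30)

if att_cache.shape[4] > prompt_len + window:
    att_cache = torch.cat([
        att_cache[:, :, :, :, :prompt_len],   # keep the prompt prefix
        att_cache[:, :, :, :, -window:],      # keep the most recent frames
    ], dim=4)

print(att_cache.shape[4])   # 106 == prompt_len + window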