Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
fix cache shallow copy
@@ -2,6 +2,9 @@
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
 export CUDA_VISIBLE_DEVICES=0
 cosyvoice_path=/workspace/CosyVoice
+cosyvoice_path=/workspace_yuekai/tts/CosyVoice
+stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
+export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
 stage=$1
@@ -140,3 +143,11 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     done
   done
 fi
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+
+  python3 benchmark_streaming_token2wav.py --enable-trt
+
+
+fi
@@ -362,8 +362,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb_for_flow.to(self.device),
             n_timesteps=10
         )
-        # cache dict's tensor batch dim is 1 for now
+        # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
+        cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
+        cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
         return cache

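For context, a minimal PyTorch sketch (not the repository's code) of the aliasing problem the two `.clone()` calls above guard against: when the same tensor object is reachable from more than one place, an in-place write through one reference silently changes what every other reference sees.

```python
import torch

att = torch.zeros(1, 2, 3)                        # stands in for cache['estimator_att_cache']
shared = {'estimator_att_cache': att}             # no clone: dict value aliases the original tensor
isolated = {'estimator_att_cache': att.clone()}   # clone: independent storage

att[0, 0, 0] = 1.0                                # in-place change through the original reference

print(shared['estimator_att_cache'][0, 0, 0].item())    # 1.0, the aliased value was mutated too
print(isolated['estimator_att_cache'][0, 0, 0].item())  # 0.0, the cloned value is unaffected
```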
@@ -371,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
     ):
-
         if speaker_id not in self.speaker_cache:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
@@ -388,7 +388,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}

         if request_id not in self.streaming_flow_cache:
-            self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
+            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
             self.hift_cache_dict[request_id] = dict(
                 mel = torch.zeros(1, 80, 0, device='cuda'),
                 source = torch.zeros(1, 1, 0, device='cuda'),
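This hunk is the shallow copy named in the commit title: `dict.copy()` duplicates only the keys, so a request's "copied" cache still pointed at the same tensors as the speaker-level `cache_dict`, and in-place updates made while serving one request could leak back into the cache used by later requests for that speaker. Cloning each value gives every request its own tensors. A minimal sketch (assuming PyTorch, not the repository's code):

```python
import torch

speaker_cache = {'estimator_att_cache': torch.zeros(1, 4)}

shallow = speaker_cache.copy()                              # old behaviour: tensors are shared
cloned = {k: v.clone() for k, v in speaker_cache.items()}   # new behaviour: tensors are copied

shallow['estimator_att_cache'][0, 0] = 1.0                  # per-request in-place update
cloned['estimator_att_cache'][0, 1] = 2.0

print(speaker_cache['estimator_att_cache'])  # tensor([[1., 0., 0., 0.]])
# The shallow copy leaked its update back into the speaker-level cache;
# the cloned copy left it untouched.
```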
@@ -396,12 +396,14 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             )

         current_request_cache = self.streaming_flow_cache[request_id]
-        prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+
         generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')

+
         chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
             token=generated_speech_tokens,
-            spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
             cache=current_request_cache,
             last_chunk=last_chunk,
             n_timesteps=10,
@@ -409,9 +411,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

         self.streaming_flow_cache[request_id] = new_streaming_flow_cache

-        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
             self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
-                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                 self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
             ], dim=4)

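The trimming above bounds per-request memory during long streams: once the attention cache grows past the prompt length plus 100 frames along its time dimension (dim 4), only the prompt prefix and the most recent 100 frames are kept. A toy sketch of that step (the tensor sizes and the meaning of the other dimensions are illustrative assumptions, not taken from the diff):

```python
import torch

prompt_len = 50
# toy 5-D cache; only the last dimension (time) matters for the trimming logic
att_cache = torch.randn(1, 2, 2, 8, prompt_len + 180)

if att_cache.shape[4] > prompt_len + 100:
    att_cache = torch.cat([
        att_cache[:, :, :, :, :prompt_len],   # frames covering the prompt
        att_cache[:, :, :, :, -100:],         # the 100 most recent frames
    ], dim=4)

print(att_cache.shape)  # torch.Size([1, 2, 2, 8, 150])
```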