Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
remove cache router
This commit is contained in:
@@ -31,6 +31,7 @@ def get_args():
 parser.add_argument("--output-dir", type=str, default="generated_wavs")
 parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
 parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
+parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
 return parser.parse_args()


@@ -53,12 +54,14 @@ if __name__ == "__main__":
 token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)

 flow_pre_lookahead_len = 3
-CHUNK_SIZE = 25
+CHUNK_SIZE = 15
+token_frame_rate = 25
 OVERLAP_SIZE = 0

 warmup_times = 3
 for _ in range(warmup_times):
 start_time = time.time()
+total_forward_count = 0
 for batch in data_loader:
 tts_speech_list = []
 ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
@@ -83,17 +86,26 @@ if __name__ == "__main__":

 buffer = generated_speech_tokens
 output_wavs = []
+chunk_index = 0
 while True:
+if args.strategy == "equal":
+this_chunk_size = CHUNK_SIZE
+elif args.strategy == "exponential":
+this_chunk_size = token_frame_rate * (2 ** chunk_index)

-if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len:
-wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
-buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:]
+if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
+wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+buffer = buffer[this_chunk_size - OVERLAP_SIZE:]

 output_wavs.append(wavs)
+total_forward_count += 1
+chunk_index += 1

 else:
 wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
 output_wavs.append(wavs)
+total_forward_count += 1
+# chunk_index += 1
 break

 for i, wav in enumerate(output_wavs):
@@ -112,4 +124,4 @@ if __name__ == "__main__":
|
||||
if _ == 0:
|
||||
token2wav_model.speaker_cache = {}
|
||||
print(f"Warmup time: {end_time - start_time} seconds")
|
||||
|
||||
print(f"Total forward count: {total_forward_count}")
|
||||
|
||||
Reference in New Issue
Block a user