mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -9,6 +9,7 @@ import time
|
||||
from token2wav_dit import CosyVoice2_Token2Wav
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
|
||||
prompt_speech_tokens_list, prompt_text_list = [], []
|
||||
@@ -23,6 +24,7 @@ def collate_fn(batch):
|
||||
|
||||
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--enable-trt", action="store_true")
|
||||
@@ -79,7 +81,11 @@ if __name__ == "__main__":
|
||||
this_chunk_size = token_frame_rate * (2 ** chunk_index)
|
||||
|
||||
if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
|
||||
- wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
|
||||
+ wavs = token2wav_model.forward_streaming(
|
||||
+ buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
|
||||
+ False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
|
||||
+ prompt_audio_sample_rate=prompt_audio_sample_rate
|
||||
+ )
|
||||
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
|
||||
|
||||
output_wavs.append(wavs)
|
||||
@@ -87,7 +93,10 @@ if __name__ == "__main__":
|
||||
chunk_index += 1
|
||||
|
||||
else:
|
||||
- wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
|
||||
+ wavs = token2wav_model.forward_streaming(
|
||||
+ buffer, True, request_id=id, speaker_id=f"{id}",
|
||||
+ prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
|
||||
+ )
|
||||
output_wavs.append(wavs)
|
||||
total_forward_count += 1
|
||||
# chunk_index += 1
|
||||
@@ -96,7 +105,6 @@ if __name__ == "__main__":
|
||||
for i, wav in enumerate(output_wavs):
|
||||
output_wavs[i] = wav.cpu().numpy().squeeze()
|
||||
|
||||
|
||||
audios = output_wavs
|
||||
reconstructed_audio = np.concatenate(audios)
|
||||
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
|
||||
@@ -111,4 +119,4 @@ if __name__ == "__main__":
|
||||
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
|
||||
else:
|
||||
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
|
||||
- print(f"Total flow matching forward calls: {total_forward_count}")
|
||||
+ print(f"Total flow matching forward calls: {total_forward_count}")
|
||||
|
||||
Reference in New Issue
Block a user