mirror of https://github.com/FunAudioLLM/CosyVoice.git
add streaming dit
runtime/triton_trtllm/streaming_inference.py (new file, 115 lines)
@@ -0,0 +1,115 @@
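"""Offline driver for the streaming DiT token2wav path.

Replays pre-generated CosyVoice2 speech tokens through CosyVoice2_Token2Wav in
fixed-size chunks and writes the reconstructed 24 kHz waveforms to disk, e.g.:

    python3 streaming_inference.py --model-dir ./Step-Audio-2-mini/token2wav --enable-trt
"""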
import argparse
import os
import time

import numpy as np
import soundfile as sf
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

from token2wav_dit import CosyVoice2_Token2Wav


def collate_fn(batch):
    """Unpack a HuggingFace batch into parallel per-field lists."""
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    prompt_speech_tokens_list, prompt_text_list = [], []
    for item in batch:
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
        prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
        prompt_text_list.append(item['prompt_text'])
    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true", help="use TensorRT acceleration in the token2wav model")
    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
    return parser.parse_args()


def fake_generated_id_iter(generated_speech_tokens_list):
    """Stand-in for a live LLM token stream: yields pre-generated tokens one at a time."""
    yield from generated_speech_tokens_list


if __name__ == "__main__":
    args = get_args()

    os.makedirs(args.output_dir, exist_ok=True)

    dataset = load_dataset(args.dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
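
    # streaming=True selects the chunked synthesis path; the request_id/speaker_id
    # passed to forward_streaming below appear to key the model's per-request
    # state and speaker cache (see the speaker_cache reset further down).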
    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)

    CHUNK_SIZE = 25    # speech tokens synthesized per streaming step
    OVERLAP_SIZE = 0   # tokens re-fed across chunk boundaries (none here)

    warmup_times = 3
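    # Replay the dataset several times: the first pass doubles as a warmup, so
    # the timings printed for later passes better reflect steady-state speed.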
    for run_idx in range(warmup_times):
        start_time = time.time()
        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch

            # batch_size is 1 in this script, so only the first element is used.
            utt_id = ids[0]
            generated_speech_tokens = generated_speech_tokens_list[0]
            prompt_audio = prompt_audios_list[0]
            prompt_audio_sample_rate = prompt_audios_sample_rate[0]

            assert prompt_audio_sample_rate == 16000  # dataset prompts are 16 kHz

            # Available in the dataset but unused by the streaming path:
            prompt_text = prompt_text_list[0]
            prompt_speech_tokens = prompt_speech_tokens_list[0]

            # A live deployment would consume tokens from an iterator instead, e.g.:
            # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens)

            buffer = generated_speech_tokens  # tokens not yet synthesized
            output_wavs = []
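            # Chunked synthesis: each step feeds CHUNK_SIZE tokens plus the flow
            # model's pre-lookahead tokens as right context, then advances the
            # buffer by CHUNK_SIZE - OVERLAP_SIZE; the remainder is flushed as a
            # final chunk.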
            while True:
                if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len:
                    # Intermediate chunk (second positional argument False = not final).
                    wavs = token2wav_model.forward_streaming(
                        buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False,
                        request_id=utt_id, speaker_id=f"{utt_id}",
                        prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
                    buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:]
                    output_wavs.append(wavs)
                else:
                    # Final chunk: flush whatever is left (True = final).
                    wavs = token2wav_model.forward_streaming(
                        buffer, True,
                        request_id=utt_id, speaker_id=f"{utt_id}",
                        prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
                    output_wavs.append(wavs)
                    break

            # Move each chunk to the CPU and stitch the chunks into one waveform.
            output_wavs = [wav.cpu().numpy().squeeze() for wav in output_wavs]
            reconstructed_audio = np.concatenate(output_wavs)

            # Save the reconstructed audio (token2wav outputs 24 kHz).
            sf.write(os.path.join(args.output_dir, f"{utt_id}.wav"), reconstructed_audio, 24000, "PCM_16")
            print(f"Saved {utt_id}")
        end_time = time.time()

        if run_idx == 0:
            # Reset the speaker cache after the first pass so subsequent timed
            # passes start cold rather than hitting cached speaker prompts.
            token2wav_model.speaker_cache = {}
        print(f"Warmup time: {end_time - start_time} seconds")