mirror of https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00

fix lint
@@ -122,7 +122,10 @@ def write_triton_stats(stats, summary_file):
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
-f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
+f"queue time {total_queue_time_s:<5.2f} s, "
+f"compute infer time {total_infer_time_s:<5.2f} s, "
+f"compute input time {total_input_time_s:<5.2f} s, "
+f"compute output time {total_output_time_s:<5.2f} s \n"
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:

@@ -136,7 +139,12 @@ def write_triton_stats(stats, summary_file):
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
-f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
+f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
+f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
+f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
+f"{compute_infer_time_ms / batch_count:.2f} ms, "
+f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
+f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "

@@ -25,7 +25,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import json
import numpy as np
import argparse

@@ -25,12 +25,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import threading
import time
from typing import Dict, List, Tuple, Optional, Union
import numpy as np
import torch

@@ -178,7 +178,6 @@ class TritonPythonModel:
yield final_id
buffer = buffer[match.end():]

def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.

@@ -263,7 +262,7 @@ class TritonPythonModel:
],
inputs=inputs_tensor,
request_id=request_id,
-parameters={"priority": index+1},
+parameters={"priority": index + 1},
)

inference_response = await inference_request.async_exec()

@@ -28,7 +28,6 @@ import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack

@@ -48,9 +48,11 @@ import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)

def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.

@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
return hasher.hexdigest()

class TritonPythonModel:
"""Triton Python model for vocoder.

@@ -114,7 +117,6 @@ class TritonPythonModel:
request_id = request.request_id()

wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(

@@ -125,7 +127,10 @@ class TritonPythonModel:
spk_id = get_spk_id_from_prompt_audio(wav)

-audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+audio_hat = self.token2wav_model.forward_streaming(
+target_speech_tokens, finalize, request_id=request_id,
+speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+)

outputs = []

@@ -35,7 +35,7 @@ import numpy as np
from hyperpyyaml import load_hyperpyyaml

-def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
+def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)

@@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel

def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")

@@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")

class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)

@@ -108,6 +110,7 @@ class TrtContextWrapper:
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])

class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()

@@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
-self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-providers=["CPUExecutionProvider"])
+self.spk_model = onnxruntime.InferenceSession(
+f"{model_dir}/campplus.onnx", sess_options=option,
+providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()

-gpu="l20"
+gpu = "l20"
if enable_trt:
if streaming:
-self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
-f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
-1,
-self.dtype, streaming)
+self.load_trt(
+f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+1,
+self.dtype, streaming
+)
else:
-self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
-f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-1,
-self.dtype)
-self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-f'{model_dir}/campplus.onnx',
-1,
-False)
+self.load_trt(
+f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+1,
+self.dtype
+)
+self.load_spk_trt(
+f'{model_dir}/campplus.{gpu}.fp32.trt',
+f'{model_dir}/campplus.onnx',
+1,
+False
+)

self.streaming_flow_cache = {}
self.speaker_cache = {}

@@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
opt_batch_size = 2
max_batch_size = 16
if streaming:
-opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
+opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator

@@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
-opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
-max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
+opt_shape = [
+(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+(opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+(16, opt_batch_size * 2, 8, 100, 128)
+]
+max_shape = [
+(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+(max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+(16, max_batch_size * 2, 8, 1000, 128)
+]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
-max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+opt_shape = [
+(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+(opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+]
+max_shape = [
+(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+(max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

@@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
-prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
+prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+prompt_mels_for_flow, batch_first=True, padding_value=0
+) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow

-def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+generated_speech_tokens_list: list[list[int]],
+prompt_mels_for_flow: torch.Tensor,
+prompt_mels_lens_for_flow: torch.Tensor,
+spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []

@@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
generated_wavs.append(wav)
return generated_wavs

@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]

@@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)

-generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+generated_mels, generated_mels_lens = self.forward_flow(
+prompt_speech_tokens_list, generated_speech_tokens_list,
+prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+)

generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs

@@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow

def get_prompt_audio_cache_for_streaming_tts(
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):

@@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache

@torch.inference_mode()
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000

@@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
-mel = torch.zeros(1, 80, 0, device='cuda'),
-source = torch.zeros(1, 1, 0, device='cuda'),
-speech = torch.zeros(1, 0, device='cuda'),
+mel=torch.zeros(1, 80, 0, device='cuda'),
+source=torch.zeros(1, 1, 0, device='cuda'),
+speech=torch.zeros(1, 0, device='cuda'),
)

current_request_cache = self.streaming_flow_cache[request_id]

@@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')

chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),

@@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
self.streaming_flow_cache[request_id] = new_streaming_flow_cache

if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)

hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()

@@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# update vocoder cache
self.hift_cache_dict[request_id] = dict(
-mel = mel[..., -self.mel_cache_len:].clone().detach(),
-source = source[:, :, -self.source_cache_len:].clone().detach(),
-speech = speech[:, -self.source_cache_len:].clone().detach(),
+mel=mel[..., -self.mel_cache_len:].clone().detach(),
+source=source[:, :, -self.source_cache_len:].clone().detach(),
+speech=speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]

@@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
return speech

def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for i, item in enumerate(batch):

@@ -447,6 +473,7 @@ def collate_fn(batch):
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")

@@ -457,6 +484,7 @@ def get_args():
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()

if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)

@@ -466,22 +494,17 @@ if __name__ == "__main__":
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)

data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

for epoch in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
-print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
+print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

@@ -28,7 +28,6 @@ import argparse
import json
import os
import sys
from pathlib import Path
import torch
import torch.distributed as dist

@@ -15,11 +15,6 @@
# limitations under the License.
import argparse
import ast
import csv
import os
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch

@@ -9,6 +9,7 @@ import time
from token2wav_dit import CosyVoice2_Token2Wav
import soundfile as sf

def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
prompt_speech_tokens_list, prompt_text_list = [], []

@@ -23,6 +24,7 @@ def collate_fn(batch):
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")

@@ -79,7 +81,11 @@ if __name__ == "__main__":
this_chunk_size = token_frame_rate * (2 ** chunk_index)

if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
-wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+wavs = token2wav_model.forward_streaming(
+buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
+False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
+prompt_audio_sample_rate=prompt_audio_sample_rate
+)
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]

output_wavs.append(wavs)

@@ -87,7 +93,10 @@ if __name__ == "__main__":
chunk_index += 1

else:
-wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+wavs = token2wav_model.forward_streaming(
+buffer, True, request_id=id, speaker_id=f"{id}",
+prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
+)
output_wavs.append(wavs)
total_forward_count += 1
# chunk_index += 1

@@ -96,7 +105,6 @@ if __name__ == "__main__":
for i, wav in enumerate(output_wavs):
output_wavs[i] = wav.cpu().numpy().squeeze()

audios = output_wavs
reconstructed_audio = np.concatenate(audios)
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")

@@ -111,4 +119,4 @@ if __name__ == "__main__":
print(f"Cost time without speaker cache: {end_time - start_time} seconds")
else:
print(f"Cost time with speaker cache: {end_time - start_time} seconds")
-print(f"Total flow matching forward calls: {total_forward_count}")
+print(f"Total flow matching forward calls: {total_forward_count}")