This commit is contained in:
yuekaiz
2025-10-09 15:13:43 +08:00
parent 8811e9f33a
commit 33aee03ed5
14 changed files with 100 additions and 72 deletions

View File

@@ -48,9 +48,11 @@ import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
return hasher.hexdigest()
class TritonPythonModel:
"""Triton Python model for vocoder.
@@ -114,7 +117,6 @@ class TritonPythonModel:
request_id = request.request_id()
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
@@ -125,7 +127,10 @@ class TritonPythonModel:
spk_id = get_spk_id_from_prompt_audio(wav)
audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
audio_hat = self.token2wav_model.forward_streaming(
target_speech_tokens, finalize, request_id=request_id,
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
)
outputs = []

View File

@@ -35,7 +35,7 @@ import numpy as np
from hyperpyyaml import load_hyperpyyaml
def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)
@@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")
@@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -108,6 +110,7 @@ class TrtContextWrapper:
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()
@@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
providers=["CPUExecutionProvider"])
self.spk_model = onnxruntime.InferenceSession(
f"{model_dir}/campplus.onnx", sess_options=option,
providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
gpu="l20"
gpu = "l20"
if enable_trt:
if streaming:
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
1,
self.dtype, streaming)
self.load_trt(
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
1,
self.dtype, streaming
)
else:
self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
self.dtype)
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
self.load_trt(
f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
1,
self.dtype
)
self.load_spk_trt(
f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False
)
self.streaming_flow_cache = {}
self.speaker_cache = {}
@@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
opt_batch_size = 2
max_batch_size = 16
if streaming:
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator
@@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
opt_shape = [
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
(opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
(16, opt_batch_size * 2, 8, 100, 128)
]
max_shape = [
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
(max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
(16, max_batch_size * 2, 8, 1000, 128)
]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
opt_shape = [
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
(opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
]
max_shape = [
(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
(max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
@@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
prompt_mels_for_flow, batch_first=True, padding_value=0
) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
generated_speech_tokens_list: list[list[int]],
prompt_mels_for_flow: torch.Tensor,
prompt_mels_lens_for_flow: torch.Tensor,
spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
@@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
@@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
generated_mels, generated_mels_lens = self.forward_flow(
prompt_speech_tokens_list, generated_speech_tokens_list,
prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
@@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
def get_prompt_audio_cache_for_streaming_tts(
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):
@@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache
@torch.inference_mode()
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
@@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
mel = torch.zeros(1, 80, 0, device='cuda'),
source = torch.zeros(1, 1, 0, device='cuda'),
speech = torch.zeros(1, 0, device='cuda'),
mel=torch.zeros(1, 80, 0, device='cuda'),
source=torch.zeros(1, 1, 0, device='cuda'),
speech=torch.zeros(1, 0, device='cuda'),
)
current_request_cache = self.streaming_flow_cache[request_id]
@@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
@@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
@@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# update vocoder cache
self.hift_cache_dict[request_id] = dict(
mel = mel[..., -self.mel_cache_len:].clone().detach(),
source = source[:, :, -self.source_cache_len:].clone().detach(),
speech = speech[:, -self.source_cache_len:].clone().detach(),
mel=mel[..., -self.mel_cache_len:].clone().detach(),
source=source[:, :, -self.source_cache_len:].clone().detach(),
speech=speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
@@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
return speech
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for i, item in enumerate(batch):
@@ -447,6 +473,7 @@ def collate_fn(batch):
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
@@ -457,6 +484,7 @@ def get_args():
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -466,22 +494,17 @@ if __name__ == "__main__":
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for epoch in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")