yuekaiz
2025-10-09 15:13:43 +08:00
parent 8811e9f33a
commit 33aee03ed5
14 changed files with 100 additions and 72 deletions

View File

@@ -53,7 +53,7 @@ except RuntimeError:
pass
-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"  # noqa: E501
def audio_decode_cosyvoice2(

View File

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#

View File

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
@@ -195,7 +193,7 @@ def write_error_stats(
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
-for cut_id, ref, hyp in results:
+for _cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:

View File

@@ -295,7 +295,7 @@ def main():
metrics_port=8002,
)
-device_ids = [i for i in range(args.number_of_devices)]
+device_ids = list(range(args.number_of_devices))
device_ids = device_ids * args.number_of_instances_per_device
with Triton(config=triton_config) as triton:

View File

@@ -122,7 +122,10 @@ def write_triton_stats(stats, summary_file):
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
-    f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
+    f"queue time {total_queue_time_s:<5.2f} s, "
+    f"compute infer time {total_infer_time_s:<5.2f} s, "
+    f"compute input time {total_input_time_s:<5.2f} s, "
+    f"compute output time {total_output_time_s:<5.2f} s \n"
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
@@ -136,7 +139,12 @@ def write_triton_stats(stats, summary_file):
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
-    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
+    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
+    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
+    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
+    f"{compute_infer_time_ms / batch_count:.2f} ms, "
+    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
+    f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "

View File

@@ -25,7 +25,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
-import json
import numpy as np
import argparse

View File

@@ -25,12 +25,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
-import math
import os
-import re
import threading
import time
-from typing import Dict, List, Tuple, Optional, Union
import numpy as np
import torch

View File

@@ -178,7 +178,6 @@ class TritonPythonModel:
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
@@ -263,7 +262,7 @@ class TritonPythonModel:
],
inputs=inputs_tensor,
request_id=request_id,
-parameters={"priority": index+1},
+parameters={"priority": index + 1},
)
inference_response = await inference_request.async_exec()

View File

@@ -28,7 +28,6 @@ import json
import os
import logging
-from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack

View File

@@ -48,9 +48,11 @@ import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
return hasher.hexdigest()
class TritonPythonModel:
"""Triton Python model for vocoder.
@@ -114,7 +117,6 @@ class TritonPythonModel:
request_id = request.request_id()
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
@@ -125,7 +127,10 @@ class TritonPythonModel:
spk_id = get_spk_id_from_prompt_audio(wav)
-audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+audio_hat = self.token2wav_model.forward_streaming(
+    target_speech_tokens, finalize, request_id=request_id,
+    speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+)
outputs = []
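For context, get_spk_id_from_prompt_audio turns the prompt-audio tensor into a stable speaker key for the cache. A minimal sketch of that idea, assuming the tensor bytes are fed to hashlib (the diff only shows the hashlib import and hasher.hexdigest(); the specific algorithm here is an assumption):

    import hashlib
    import torch

    def spk_id_from_tensor(tensor: torch.Tensor) -> str:
        hasher = hashlib.md5()                                # assumed hash algorithm
        hasher.update(tensor.detach().cpu().numpy().tobytes())
        return hasher.hexdigest()                             # hex string used as speaker_id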

View File

@@ -35,7 +35,7 @@ import numpy as np
from hyperpyyaml import load_hyperpyyaml
-def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
+def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)
@@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")
@@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -108,6 +110,7 @@ class TrtContextWrapper:
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])
class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()
@@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
-self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-    providers=["CPUExecutionProvider"])
+self.spk_model = onnxruntime.InferenceSession(
+    f"{model_dir}/campplus.onnx", sess_options=option,
+    providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
-gpu="l20"
+gpu = "l20"
if enable_trt:
if streaming:
-self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
-    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
-    1,
-    self.dtype, streaming)
+self.load_trt(
+    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+    1,
+    self.dtype, streaming
+)
else:
-self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
-    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-    1,
-    self.dtype)
-self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-    f'{model_dir}/campplus.onnx',
-    1,
-    False)
+self.load_trt(
+    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+    1,
+    self.dtype
+)
+self.load_spk_trt(
+    f'{model_dir}/campplus.{gpu}.fp32.trt',
+    f'{model_dir}/campplus.onnx',
+    1,
+    False
+)
self.streaming_flow_cache = {}
self.speaker_cache = {}
@@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
opt_batch_size = 2
max_batch_size = 16
if streaming:
-opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
+opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator
@@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
-opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
-max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
+opt_shape = [
+    (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+    (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+    (16, opt_batch_size * 2, 8, 100, 128)
+]
+max_shape = [
+    (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+    (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+    (16, max_batch_size * 2, 8, 1000, 128)
+]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
-opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
-max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+opt_shape = [
+    (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+    (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+]
+max_shape = [
+    (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+    (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
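For reference, these min/opt/max shapes are what a TensorRT dynamic-shape build consumes. A minimal sketch of how kwargs like these typically become an optimization profile, following standard TensorRT Python API usage rather than the exact body of convert_onnx_to_trt:

    import tensorrt as trt

    def add_dynamic_profile(builder: "trt.Builder", config: "trt.IBuilderConfig", trt_kwargs: dict):
        profile = builder.create_optimization_profile()
        for name, mn, op, mx in zip(trt_kwargs["input_names"], trt_kwargs["min_shape"],
                                    trt_kwargs["opt_shape"], trt_kwargs["max_shape"]):
            profile.set_shape(name, mn, op, mx)   # min/opt/max range for each dynamic input
        config.add_optimization_profile(profile)
        return config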
@@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
-prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
+prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+    prompt_mels_for_flow, batch_first=True, padding_value=0
+) # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow
-def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+                 generated_speech_tokens_list: list[list[int]],
+                 prompt_mels_for_flow: torch.Tensor,
+                 prompt_mels_lens_for_flow: torch.Tensor,
+                 spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
@@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
generated_wavs.append(wav)
return generated_wavs
@torch.inference_mode()
def forward(
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
@@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
-generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+generated_mels, generated_mels_lens = self.forward_flow(
+    prompt_speech_tokens_list, generated_speech_tokens_list,
+    prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+)
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
@@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
def get_prompt_audio_cache_for_streaming_tts(
self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):
@@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache
@torch.inference_mode()
def forward_streaming(
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
@@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
-    mel = torch.zeros(1, 80, 0, device='cuda'),
-    source = torch.zeros(1, 1, 0, device='cuda'),
-    speech = torch.zeros(1, 0, device='cuda'),
+    mel=torch.zeros(1, 80, 0, device='cuda'),
+    source=torch.zeros(1, 1, 0, device='cuda'),
+    speech=torch.zeros(1, 0, device='cuda'),
)
current_request_cache = self.streaming_flow_cache[request_id]
@@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
token=generated_speech_tokens,
spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
@@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
self.streaming_flow_cache[request_id] = new_streaming_flow_cache
if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)
hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
@@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# update vocoder cache
self.hift_cache_dict[request_id] = dict(
-    mel = mel[..., -self.mel_cache_len:].clone().detach(),
-    source = source[:, :, -self.source_cache_len:].clone().detach(),
-    speech = speech[:, -self.source_cache_len:].clone().detach(),
+    mel=mel[..., -self.mel_cache_len:].clone().detach(),
+    source=source[:, :, -self.source_cache_len:].clone().detach(),
+    speech=speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
@@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
return speech
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for i, item in enumerate(batch):
@@ -447,6 +473,7 @@ def collate_fn(batch):
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
@@ -457,6 +484,7 @@ def get_args():
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -466,22 +494,17 @@ if __name__ == "__main__":
dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
for epoch in range(args.warmup):
start_time = time.time()
for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
end_time = time.time()
epoch_time = end_time - start_time
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

View File

@@ -28,7 +28,6 @@ import argparse
import json
import os
import sys
-from pathlib import Path
import torch
import torch.distributed as dist

View File

@@ -15,11 +15,6 @@
# limitations under the License.
import argparse
-import ast
-import csv
-import os
-from pathlib import Path
-from typing import List, Optional
import numpy as np
import torch

View File

@@ -9,6 +9,7 @@ import time
from token2wav_dit import CosyVoice2_Token2Wav
import soundfile as sf
def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
prompt_speech_tokens_list, prompt_text_list = [], []
@@ -23,6 +24,7 @@ def collate_fn(batch):
return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
@@ -79,7 +81,11 @@ if __name__ == "__main__":
this_chunk_size = token_frame_rate * (2 ** chunk_index)
if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
-wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+wavs = token2wav_model.forward_streaming(
+    buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
+    False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
+    prompt_audio_sample_rate=prompt_audio_sample_rate
+)
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
output_wavs.append(wavs)
@@ -87,7 +93,10 @@ if __name__ == "__main__":
chunk_index += 1
else:
-wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+wavs = token2wav_model.forward_streaming(
+    buffer, True, request_id=id, speaker_id=f"{id}",
+    prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
+)
output_wavs.append(wavs)
total_forward_count += 1
# chunk_index += 1
@@ -96,7 +105,6 @@ if __name__ == "__main__":
for i, wav in enumerate(output_wavs):
output_wavs[i] = wav.cpu().numpy().squeeze()
audios = output_wavs
reconstructed_audio = np.concatenate(audios)
sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
@@ -111,4 +119,4 @@ if __name__ == "__main__":
print(f"Cost time without speaker cache: {end_time - start_time} seconds") print(f"Cost time without speaker cache: {end_time - start_time} seconds")
else: else:
print(f"Cost time with speaker cache: {end_time - start_time} seconds") print(f"Cost time with speaker cache: {end_time - start_time} seconds")
print(f"Total flow matching forward calls: {total_forward_count}") print(f"Total flow matching forward calls: {total_forward_count}")