diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py new file mode 100644 index 0000000..97659ad --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py @@ -0,0 +1,455 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import math +import os +import re +import threading +import time +from typing import Dict, List, Tuple, Optional, Union + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer + +import torchaudio + + +from matcha.utils.audio import mel_spectrogram + +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + self.logger = pb_utils.Logger + # Parse model parameters + self.model_config = json.loads(args['model_config']) + parameters = self.model_config['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + self.logger.log_info(f"model_params:{model_params}") + self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.prompt_template = "<|sos|>{input_text}<|task_id|>" + self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") + + self.device = torch.device("cuda") + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) + + self.token_frame_rate = 25 + self.flow_pre_lookahead_len = 3 + self.token_hop_len = 15 + + spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + self.default_spk_info = spk_info["001"] + + def forward_llm(self, input_ids): + """ + Prepares the response from the language model based on the provided + inputs. Creates a `pb_utils.InferenceRequest` object with passed + `llm_request_inputs` to send to a decoupled TensorRTLLM model. + For each response from the language model: + - Checks for errors and raise an exception if any are found. + - Extracts the "output_ids" tensor from the response. + - Determines the finish reason based on the presence of the + end-of-sequence token or reaching the maximum length. + - Appends the generated token IDs to `output_ids`. + - If the finish reason is determined, decodes the output IDs to text + and prepares the final response. + + The final response includes the generated text, finish reason, + completion tokens, prompt tokens, and total tokens. + + Parameters + ---------- + - llm_request_inputs (dict): A dictionary containing the inputs for the language model. + + Returns + ------- + - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. + """ + # convert input_ids to numpy, with shape [1, sequence_length] + input_ids = input_ids.cpu().numpy() + max_tokens = 750 + input_dict = { + "request_output_len": np.array([[max_tokens]], dtype=np.int32), + "end_id": np.array([[self.eos_token_id]], dtype=np.int32), + "pad_id": np.array([[self.eos_token_id]], dtype=np.int32), + "streaming": np.array([[self.decoupled]], dtype=np.bool_), + "runtime_top_p": np.array([[0.95]], dtype=np.float32), + "runtime_top_k": np.array([[50]], dtype=np.int32), + "temperature": np.array([[0.8]], dtype=np.float32), + "repetition_penalty": np.array([[1.1]], dtype=np.float32), + "random_seed": np.array([[42]], dtype=np.uint64), + "input_ids": input_ids, + "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), + } + + # Convert inputs to Triton tensors + input_tensor_list = [ + pb_utils.Tensor(k, v) for k, v in input_dict.items() + ] + + # Create and execute inference request + llm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + requested_output_names=["output_ids", "sequence_length"], + inputs=input_tensor_list, + ) + + llm_responses = llm_request.exec(decoupled=self.decoupled) + if self.decoupled: + for llm_response in llm_responses: + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + else: + llm_response = llm_responses + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + yield actual_output_ids + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. + + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['prompt_speech_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') + prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() + + return prompt_speech_tokens + + def forward_speaker_embedding(self, wav): + """Forward pass through the speaker embedding component. + + Args: + wav: Input waveform tensor + + Returns: + Prompt speaker embedding tensor + """ + inference_request = pb_utils.InferenceRequest( + model_name='speaker_embedding', + requested_output_names=['prompt_spk_embedding'], + inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') + prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) + + return prompt_spk_embedding + + def forward_token2wav( + self, + target_speech_tokens: torch.Tensor, + request_id: str, + prompt_speech_tokens: torch.Tensor = None, + prompt_speech_feat: torch.Tensor = None, + prompt_spk_embedding: torch.Tensor = None, + token_offset: int = None, + finalize: bool = None) -> torch.Tensor: + """Forward pass through the vocoder component. + + Args: + prompt_speech_tokens: Prompt speech tokens tensor + prompt_speech_feat: Prompt speech feat tensor + prompt_spk_embedding: Prompt spk embedding tensor + target_speech_tokens: Target speech tokens tensor + + Returns: + Generated waveform tensor + """ + target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + + inputs_tensor = [target_speech_tokens_tensor] + + if token_offset is not None: + assert finalize is not None + token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor.append(token_offset_tensor) + inputs_tensor.append(finalize_tensor) + + if prompt_spk_embedding is not None: + assert prompt_speech_feat is not None + prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) + prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) + prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) + inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor]) + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='token2wav', + requested_output_names=['waveform'], + inputs=inputs_tensor, + request_id=request_id, + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def parse_input(self, text, prompt_text, prompt_speech_tokens): + total_text = f"{prompt_text}{text}" + prompt = self.prompt_template.format(input_text=total_text) + input_ids = self.tokenizer.encode(prompt) + input_ids = torch.tensor([input_ids], dtype=torch.int32) + input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1) + return input_ids + + def _extract_speech_feat(self, speech): + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 1).to( + self.device) + speech_feat = speech_feat.unsqueeze(dim=0) + return speech_feat + + def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag): + for generated_ids in generated_ids_iter: + generated_ids = generated_ids.tolist() + if len(generated_ids) == 0: + break + semantic_token_ids_arr.extend(generated_ids) + llm_is_done_flag[0] = True + + def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + responses = [] + + for request in requests: + request_id = request.request_id() + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + + # Process reference audio through audio tokenizer + if wav is not None: + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) + prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) + + wav_tensor = wav.as_numpy() + wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) + speech_feat = self._extract_speech_feat(prompt_speech_resample) + token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) + prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() + prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + else: + # using pre-cached reference text + reference_text = self.default_spk_info["prompt_text"] + prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE + prompt_speech_feat = None + prompt_spk_embedding = None + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + # Prepare prompt for LLM + input_ids = self.parse_input( + text=target_text, + prompt_text=reference_text, + prompt_speech_tokens=prompt_speech_tokens, + ) + + # Generate semantic tokens with LLM + generated_ids_iter = self.forward_llm(input_ids) + + if self.decoupled: + response_sender = request.get_response_sender() + + semantic_token_ids_arr = [] + llm_is_done_flag = [False] + + llm_thread = threading.Thread( + target=self._llm_gen_thread, + args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag) + ) + + llm_thread.start() + + token_offset, chunk_index = 0, 0 + start_time = time.time() + this_token_hop_len = self.token_hop_len + + while True: + pending_num = len(semantic_token_ids_arr) - token_offset + + if llm_is_done_flag[0]: + break + + if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: + this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] + this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) + + sub_tts_speech = self.forward_token2wav( + this_tts_speech_token, request_id, prompt_speech_tokens, + prompt_speech_feat, prompt_spk_embedding, token_offset, False + ) + + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + token_offset += this_token_hop_len + self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") + + if self.dynamic_chunk_strategy == "exponential": + this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "time_based": + # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 + cost_time = time.time() - start_time + duration = token_offset / self.token_frame_rate + if chunk_index > 0 and cost_time > 0: + avg_chunk_processing_time = cost_time / (chunk_index + 1) + if avg_chunk_processing_time > 0: + multiples = (duration - cost_time) / avg_chunk_processing_time + self.logger.log_info(f"multiples: {multiples}") + next_pending_num = len(semantic_token_ids_arr) - token_offset + if multiples > 4: + this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len + elif multiples > 2: + this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len + else: + this_token_hop_len = self.token_hop_len + this_token_hop_len = max(self.token_hop_len, this_token_hop_len) + chunk_index += 1 + else: + time.sleep(0.02) + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + llm_thread.join() + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + self.logger.log_info("send tritonserver_response_complete_final to end") + else: + generated_ids = next(generated_ids_iter) + generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) + if generated_ids is None or len(generated_ids) == 0: + raise pb_utils.TritonModelException("Generated IDs is None or empty") + + audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + responses.append(inference_response) + + if not self.decoupled: + return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt new file mode 100644 index 0000000..73a9a05 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -0,0 +1,73 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "cosyvoice2" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +model_transaction_policy { + decoupled: ${decoupled_mode} +} +parameters [ + { + key: "llm_tokenizer_dir", + value: {string_value:"${llm_tokenizer_dir}"} + }, + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + optional: true + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + optional: true + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + optional: true + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: ${bls_instance_num} + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py new file mode 100644 index 0000000..10bc272 --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -0,0 +1,278 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import os + +import logging +from typing import List, Dict + +import torch +from torch.utils.dlpack import to_dlpack +from torch.nn import functional as F + +import triton_python_backend_utils as pb_utils + +from hyperpyyaml import load_hyperpyyaml +from cosyvoice.utils.common import fade_in_out +from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm +from cosyvoice.utils.common import TrtContextWrapper +from collections import defaultdict +import numpy as np + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +ORIGINAL_VOCAB_SIZE = 151663 +torch.set_num_threads(1) + + +class CosyVoice2: + + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'): + + self.model_dir = model_dir + self.fp16 = fp16 + + hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError('{} not found!'.format(hyper_yaml_path)) + with open(hyper_yaml_path, 'r') as f: + configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')}) + self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device) + self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) + if load_jit: + self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) + if load_trt: + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + trt_concurrent, + self.fp16) + + +class CosyVoice2Model: + + def __init__(self, + flow: torch.nn.Module, + hift: torch.nn.Module, + fp16: bool = False, + device: str = 'cuda'): + self.device = device + self.flow = flow + self.hift = hift + self.fp16 = fp16 + if self.fp16 is True: + self.flow.half() + + # streaming tts config + self.token_hop_len = 25 + self.mel_cache_len = 8 + self.source_cache_len = int(self.mel_cache_len * 480) + self.speech_window = np.hamming(2 * self.source_cache_len) + self.hift_cache_dict = defaultdict(lambda: None) + + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder + + def load(self, flow_model, hift_model): + self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True) + self.flow.to(self.device).eval() + # in case hift_model is a hifigan model + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs(self): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] + opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] + max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] + input_names = ["x", "mask", "mu", "cond"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): + with torch.cuda.amp.autocast(self.fp16): + tts_mel, _ = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + streaming=stream, + finalize=finalize) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] + # append hift cache + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + # keep overlap mel and hift cache + if finalize is False: + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], + 'source': tts_source[:, :, -self.source_cache_len:], + 'speech': tts_speech[:, -self.source_cache_len:]} + tts_speech = tts_speech[:, :-self.source_cache_len] + else: + if speed != 1.0: + assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' + tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + return tts_speech + + +class TritonPythonModel: + """Triton Python model for vocoder. + + This model takes global and semantic tokens as input and generates audio waveforms + using the BiCodec vocoder. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {key: value["string_value"] for key, value in parameters.items()} + model_dir = model_params["model_dir"] + + # Initialize device and vocoder + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + logger.info(f"Initializing vocoder from {model_dir} on {self.device}") + + self.token2wav_model = CosyVoice2( + model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device + ) + + spk_info_path = os.path.join(model_dir, "spk2info.pt") + if not os.path.exists(spk_info_path): + raise ValueError(f"spk2info.pt not found in {model_dir}") + spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) + self.default_spk_info = spk_info["001"] + + logger.info("Token2Wav initialized successfully") + + def execute(self, requests): + """Execute inference on the batched requests. + + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated waveforms + """ + responses = [] + # Process each request in batch + for request in requests: + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device) + + prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens") + if prompt_speech_tokens_tensor is not None: + prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy() + prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy() + prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy() + prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device) + prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device) + prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device) + prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE + else: + prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device) + prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device) + prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device) + + # shift the speech tokens according to the original vocab size + target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + + # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. + token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset") + if token_offset is not None: + token_offset = token_offset.as_numpy().item() + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() + if not finalize: + stream = True + else: + stream = False + request_id = request.request_id() + audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, + prompt_token=prompt_speech_tokens, + prompt_feat=prompt_speech_feat, + embedding=prompt_spk_embedding, + token_offset=token_offset, + uuid=request_id, + stream=stream, + finalize=finalize) + if finalize: + self.token2wav_model.model.hift_cache_dict.pop(request_id) + + else: + tts_mel, _ = self.token2wav_model.model.flow.inference( + token=target_speech_tokens, + token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_token=prompt_speech_tokens, + prompt_token_len=torch.tensor( + [prompt_speech_tokens.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_speech_feat, + prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=prompt_spk_embedding, + streaming=False, + finalize=True, + ) + + audio_hat, _ = self.token2wav_model.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + generated_wave = audio_hat.squeeze(0).cpu().numpy() + + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) + inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) + responses.append(inference_response) + + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt new file mode 100644 index 0000000..c33a85f --- /dev/null +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -0,0 +1,80 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "token2wav" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "target_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "prompt_speech_tokens" + data_type: TYPE_INT32 + dims: [-1] + optional: true + }, + { + name: "prompt_speech_feat" + data_type: TYPE_FP16 + dims: [-1, 80] + optional: true + }, + { + name: "prompt_spk_embedding" + data_type: TYPE_FP16 + dims: [-1] + optional: true + }, + { + name: "token_offset" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "finalize" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh new file mode 100644 index 0000000..7c7f3cd --- /dev/null +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -0,0 +1,142 @@ +#!/bin/bash +# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) +export CUDA_VISIBLE_DEVICES=0 +cosyvoice_path=/workspace/CosyVoice +export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH +export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH +stage=$1 +stop_stage=$2 + +huggingface_model_local_dir=./cosyvoice2_llm +model_scope_model_local_dir=./CosyVoice2-0.5B +trt_dtype=bfloat16 +trt_weights_dir=./trt_weights_${trt_dtype} +trt_engines_dir=./trt_engines_${trt_dtype} + +model_repo=./model_repo_cosyvoice2 + +use_spk2info_cache=False + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + echo "Cloning CosyVoice" + git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path + cd $cosyvoice_path + git submodule update --init --recursive + cd runtime/triton_trtllm +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Downloading CosyVoice2-0.5B" + # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py + huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir + # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings + wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt +fi + + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "Converting checkpoint to TensorRT weights" + python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \ + --output_dir $trt_weights_dir \ + --dtype $trt_dtype || exit 1 + + echo "Building TensorRT engines" + trtllm-build --checkpoint_dir $trt_weights_dir \ + --output_dir $trt_engines_dir \ + --max_batch_size 16 \ + --max_num_tokens 32768 \ + --gemm_plugin $trt_dtype || exit 1 + + echo "Testing TensorRT engines" + python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \ + --tokenizer_dir $huggingface_model_local_dir \ + --top_k 50 --top_p 0.95 --temperature 0.8 \ + --engine_dir=$trt_engines_dir || exit 1 +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "Creating model repository" + rm -rf $model_repo + mkdir -p $model_repo + cosyvoice2_dir="cosyvoice2" + + cp -r ./model_repo/${cosyvoice2_dir} $model_repo + cp -r ./model_repo/tensorrt_llm $model_repo + cp -r ./model_repo/token2wav $model_repo + if [ $use_spk2info_cache == "False" ]; then + cp -r ./model_repo/audio_tokenizer $model_repo + cp -r ./model_repo/speaker_embedding $model_repo + fi + + ENGINE_PATH=$trt_engines_dir + MAX_QUEUE_DELAY_MICROSECONDS=0 + MODEL_DIR=$model_scope_model_local_dir + LLM_TOKENIZER_DIR=$huggingface_model_local_dir + BLS_INSTANCE_NUM=4 + TRITON_MAX_BATCH_SIZE=16 + DECOUPLED_MODE=True # True for streaming, False for offline + + python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 + if [ $use_spk2info_cache == "False" ]; then + python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Starting Triton server" + tritonserver --model-repository $model_repo +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + echo "Single request test http, only work for offline TTS mode" + python3 client_http.py \ + --reference-audio ./assets/prompt_audio.wav \ + --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ + --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ + --model-name cosyvoice2 +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Running benchmark client grpc" + num_task=4 + + mode=streaming + BLS_INSTANCE_NUM=4 + + python3 client_grpc.py \ + --server-addr localhost \ + --model-name cosyvoice2 \ + --num-tasks $num_task \ + --mode $mode \ + --use-spk2info-cache $use_spk2info_cache \ + --huggingface-dataset yuekai/seed_tts_cosy2 \ + --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache} +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "stage 6: Offline inference benchmark" + n_gpus=1 + datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh + backend=trtllm # hf, trtllm, vllm + + batch_sizes=(16 8 4 2 1) + token2wav_batch_size=1 + for batch_size in ${batch_sizes[@]}; do + for dataset in ${datasets[@]}; do + output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size} + CUDA_VISIBLE_DEVICES=0 \ + python3 offline_inference.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $huggingface_model_local_dir \ + --token2wav-path $model_scope_model_local_dir \ + --backend $backend \ + --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ + --engine-dir $trt_engines_dir \ + --split-name ${dataset} || exit 1 + done + done +fi diff --git a/runtime/triton_trtllm/token2wav_dit.py b/runtime/triton_trtllm/token2wav_dit.py new file mode 100644 index 0000000..69db946 --- /dev/null +++ b/runtime/triton_trtllm/token2wav_dit.py @@ -0,0 +1,496 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage + CUDA_VISIBLE_DEVICES=0 \ + python3 token2wav.py --enable-trt || exit 1 +""" +import torch +# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec +from flashcosyvoice.modules.hifigan import HiFTGenerator +from flashcosyvoice.utils.audio import mel_spectrogram +import torchaudio.compliance.kaldi as kaldi +import onnxruntime +import s3tokenizer +from torch.utils.data import DataLoader +from datasets import load_dataset +import torchaudio +import os +import logging +import argparse +import queue +import time +import numpy as np +from hyperpyyaml import load_hyperpyyaml + + +def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor): + """perform fade_in_out in tensor style + """ + mel_overlap_len = int(window.shape[0] / 2) + fade_in_mel = fade_in_mel.clone() + fade_in_mel[..., :mel_overlap_len] = \ + fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] + return fade_in_mel + +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype): + import tensorrt as trt + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB + if dtype == torch.float16: + config.set_flag(trt.BuilderFlag.FP16) + elif dtype == torch.bfloat16: + config.set_flag(trt.BuilderFlag.BF16) + elif dtype == torch.float32: + config.set_flag(trt.BuilderFlag.FP32) + profile = builder.create_optimization_profile() + # load onnx model + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError('failed to parse {}'.format(onnx_model)) + # set input shapes + for i in range(len(trt_kwargs['input_names'])): + profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) + if dtype == torch.float16: + tensor_dtype = trt.DataType.HALF + elif dtype == torch.bfloat16: + tensor_dtype = trt.DataType.BF16 + elif dtype == torch.float32: + tensor_dtype = trt.DataType.FLOAT + else: + raise ValueError('invalid dtype {}'.format(dtype)) + # set input and output data type + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + # save trt engine + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + self.device = device + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) + +class CosyVoice2_Token2Wav(torch.nn.Module): + def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16): + super().__init__() + self.device_id = device_id + self.device = f"cuda:{device_id}" + with open(f"{model_dir}/flow.yaml", "r") as f: + configs = load_hyperpyyaml(f) + self.flow = configs['flow'] + + self.dtype = dtype + self.flow.to(self.dtype) + + self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True) + self.flow.to(self.device).eval() + + self.hift = HiFTGenerator() + hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()} + self.hift.load_state_dict(hift_state_dict, strict=True) + self.hift.to(self.device).eval() + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, + providers=["CPUExecutionProvider"]) + + self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval() + + gpu="l20" + if enable_trt: + if streaming: + self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx', + 1, + self.dtype, streaming) + else: + self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan', + f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx', + 1, + self.dtype) + self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt', + f'{model_dir}/campplus.onnx', + 1, + False) + + + self.streaming_flow_cache = {} + self.speaker_cache = {} + + self.mel_cache_len = 8 # hard-coded, 160ms + self.source_cache_len = int(self.mel_cache_len * 480) # 50hz mel -> 24kHz wave + self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda() + + # hifigan cache for streaming tts + self.hift_cache_dict = {} + + def forward_spk_embedding(self, spk_feat): + if isinstance(self.spk_model, onnxruntime.InferenceSession): + return self.spk_model.run( + None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} + )[0].flatten().tolist() + else: + [spk_model, stream], trt_engine = self.spk_model.acquire_estimator() + # NOTE need to synchronize when switching stream + with torch.cuda.device(self.device_id): + torch.cuda.current_stream().synchronize() + spk_feat = spk_feat.unsqueeze(dim=0).to(self.device) + batch_size = spk_feat.size(0) + + with stream: + spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80)) + output_tensor = torch.empty((batch_size, 192), device=spk_feat.device) + + data_ptrs = [spk_feat.contiguous().data_ptr(), + output_tensor.contiguous().data_ptr()] + for i, j in enumerate(data_ptrs): + + spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.spk_model.release_estimator(spk_model, stream) + + return output_tensor.cpu().numpy().flatten().tolist() + + def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True): + if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0: + trt_kwargs = self.get_spk_trt_kwargs() + convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16) + import tensorrt as trt + with open(spk_model, 'rb') as f: + spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert spk_engine is not None, 'failed to load trt {}'.format(spk_model) + self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_spk_trt_kwargs(self): + min_shape = [(1, 4, 80)] + opt_shape = [(1, 500, 80)] + max_shape = [(1, 3000, 80)] + input_names = ["input"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False): + assert torch.cuda.is_available(), 'tensorrt only supports gpu!' + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: + opt_batch_size = 2 + max_batch_size = 16 + if streaming: + opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts + trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming) + convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype) + del self.flow.decoder.estimator + import tensorrt as trt + with open(flow_decoder_estimator_model, 'rb') as f: + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device) + + def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False): + if streaming: + min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)] + input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"] + else: + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)] + opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)] + max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)] + input_names = ["x", "mask", "mu", "cond", "t", "spks"] + return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + + def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]: + prompt_speech_tokens_list, prompt_speech_mels_list = [], [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] + prompt_speech_mels_list.append(log_mel) + prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list) + prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize( + prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device) + ) + for i in range(len(prompt_speech_tokens)): + speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist() + prompt_speech_tokens_list.append(speech_tokens_i) + return prompt_speech_tokens_list + + def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor: + spk_emb_for_flow = [] + for audio in prompt_audios_list: + assert len(audio.shape) == 1 + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) + spk_emb = self.forward_spk_embedding(spk_feat) + + spk_emb_for_flow.append(spk_emb) + spk_emb_for_flow = torch.tensor(spk_emb_for_flow) + if self.dtype != torch.float32: + spk_emb_for_flow = spk_emb_for_flow.to(self.dtype) + return spk_emb_for_flow + + def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]): + prompt_mels_for_flow = [] + prompt_mels_lens_for_flow = [] + for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate): + assert len(audio.shape) == 1 + audio = audio.unsqueeze(0) + if sample_rate != 24000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) + mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] + mel_len = mel.shape[0] + prompt_mels_for_flow.append(mel) + prompt_mels_lens_for_flow.append(mel_len) + prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80] + prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow) + return prompt_mels_for_flow, prompt_mels_lens_for_flow + + def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor): + batch_size = prompt_mels_for_flow.shape[0] + flow_inputs = [] + flow_inputs_lens = [] + for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list): + flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens)) + flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens)) + + flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0) + flow_inputs_lens = torch.tensor(flow_inputs_lens) + + with torch.amp.autocast(self.device, dtype=torch.float16): + generated_mels, generated_mels_lens = self.flow.inference( + flow_inputs.to(self.device), flow_inputs_lens.to(self.device), + prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10 + ) + + return generated_mels, generated_mels_lens + + def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor): + batch_size = generated_mels.shape[0] + generated_wavs = [] + for i in range(batch_size): + mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0) + wav, _ = self.hift(speech_feat=mel) + generated_wavs.append(wav) + return generated_wavs + + + @torch.inference_mode() + def forward( + self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate) + + generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + + generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow) + + return generated_wavs + + def prepare_prompt_audio( + self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int] + ): + # assert all item in prompt_audios_sample_rate is 16000 + assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate) + + + prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list) + + prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate) + + spk_emb_for_flow = self.get_spk_emb(prompt_audios_list) + + return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + + + def get_prompt_audio_cache_for_streaming_tts( + self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow + ): + assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts" + for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list): + prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3]) + prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0) + + cache = self.flow.setup_cache( + prompt_speech_tokens_tensor.to(self.device), + prompt_mels_for_flow.to(self.device), + spk_emb_for_flow.to(self.device), + n_timesteps=10 + ) + + # cache dict's tensor batch dim is 1 for now + return cache + + + @torch.inference_mode() + def forward_streaming( + self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 + ): + + if speaker_id not in self.speaker_cache: + assert prompt_audio is not None, "prompt_audio is required for new speaker" + assert prompt_audio_sample_rate == 16000 + + prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate]) + + token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0])) + prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous() + prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len] + + cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) + prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} + + self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} + + if request_id not in self.streaming_flow_cache: + self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy() + self.hift_cache_dict[request_id] = dict( + mel = torch.zeros(1, 80, 0, device='cuda'), + source = torch.zeros(1, 1, 0, device='cuda'), + speech = torch.zeros(1, 0, device='cuda'), + ) + + current_request_cache = self.streaming_flow_cache[request_id] + prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict'] + generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda') + + chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk( + token=generated_speech_tokens, + spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device), + cache=current_request_cache, + last_chunk=last_chunk, + n_timesteps=10, + ) + + self.streaming_flow_cache[request_id] = new_streaming_flow_cache + + if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100): + self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([ + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]], + self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:], + ], dim=4) + + + + hift_cache_mel = self.hift_cache_dict[request_id]['mel'] + hift_cache_source = self.hift_cache_dict[request_id]['source'] + hift_cache_speech = self.hift_cache_dict[request_id]['speech'] + mel = torch.concat([hift_cache_mel, chunk_mel], dim=2) + + speech, source = self.hift(mel, hift_cache_source) + + # overlap speech smooth + if hift_cache_speech.shape[-1] > 0: + speech = fade_in_out(speech, hift_cache_speech, self.speech_window) + + # update vocoder cache + self.hift_cache_dict[request_id] = dict( + mel = mel[..., -self.mel_cache_len:].clone().detach(), + source = source[:, :, -self.source_cache_len:].clone().detach(), + speech = speech[:, -self.source_cache_len:].clone().detach(), + ) + if not last_chunk: + speech = speech[:, :-self.source_cache_len] + + if last_chunk: + assert request_id in self.streaming_flow_cache + self.streaming_flow_cache.pop(request_id) + self.hift_cache_dict.pop(request_id) + + return speech + +def collate_fn(batch): + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], [] + for i, item in enumerate(batch): + generated_speech_tokens_list.append(item['target_audio_cosy2_tokens']) + audio = torch.from_numpy(item['prompt_audio']['array']).float() + prompt_audios_list.append(audio) + prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate']) + ids.append(item['id']) + + return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--enable-trt", action="store_true") + parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--output-dir", type=str, default="generated_wavs") + parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") + parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch") + return parser.parse_args() + +if __name__ == "__main__": + args = get_args() + model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt) + # mkdir output_dir if not exists + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + dataset_name = "yuekai/seed_tts_cosy2" + + dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True) + + + data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0) + + + for epoch in range(args.warmup): + start_time = time.time() + + for batch in data_loader: + ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch + + generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate) + + + for id, wav in zip(ids, generated_wavs): + torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000) + + end_time = time.time() + epoch_time = end_time - start_time + print(f"Measurement epoch time taken: {epoch_time:.4f} seconds") \ No newline at end of file