add triton solution

This commit is contained in:
Yuekai Zhang
2025-07-22 06:50:13 -07:00
parent b048a2d6db
commit 5427c274e3
18 changed files with 3448 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import s3tokenizer
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
mels = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array).to(self.device)
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
mels.append(s3tokenizer.log_mel_spectrogram(wav))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
        # Shift the speech-token IDs into the LLM vocabulary range; 151663 is
        # the offset of the speech-token block in the extended CosyVoice2 tokenizer.
        codes = codes.clone() + 151663
responses = []
for i in range(len(requests)):
prompt_speech_tokens = codes[i, :codes_lens[i].item()]
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
"prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_speech_tokens_tensor])
responses.append(inference_response)
return responses
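# ---------------------------------------------------------------------------
# Minimal client-side sketch (not part of the model above; it would normally
# live in a separate client script). It shows one way to call this
# audio_tokenizer model over gRPC. The server URL, audio path, 16 kHz mono
# input, and the use of soundfile are assumptions; tensor names, dtypes, and
# shapes follow the config.pbtxt below (reference_wav FP32 [-1],
# reference_wav_len INT32 [1], output prompt_speech_tokens INT32 [-1]).
if __name__ == "__main__":
    import soundfile as sf
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient(url="localhost:8001")
    wav, _ = sf.read("prompt.wav", dtype="float32")  # assumed 16 kHz mono

    wav_in = np.expand_dims(wav, axis=0)                     # [1, num_samples]
    len_in = np.array([[wav_in.shape[1]]], dtype=np.int32)   # [1, 1]

    inputs = [
        grpcclient.InferInput("reference_wav", list(wav_in.shape), "FP32"),
        grpcclient.InferInput("reference_wav_len", list(len_in.shape), "INT32"),
    ]
    inputs[0].set_data_from_numpy(wav_in)
    inputs[1].set_data_from_numpy(len_in)

    result = client.infer("audio_tokenizer", inputs)
    print("prompt_speech_tokens:", result.as_numpy("prompt_speech_tokens").shape)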

View File

@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "audio_tokenizer"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
}
]
output [
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]
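# Note: the ${...} placeholders above (triton_max_batch_size,
# max_queue_delay_microseconds, model_dir) are not valid pbtxt values on their
# own; they are expected to be substituted before serving, for example with the
# fill_template.py helper shipped with the TensorRT-LLM backend (exact script
# location may differ):
#   python3 fill_template.py -i audio_tokenizer/config.pbtxt \
#       triton_max_batch_size:16,max_queue_delay_microseconds:0,model_dir:</path/to/model_dir>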

View File

@@ -0,0 +1,331 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
from typing import Dict, List, Tuple, Optional, Union
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import onnxruntime
from matcha.utils.audio import mel_spectrogram
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
def forward_llm(self, input_ids):
"""
Prepares the response from the language model based on the provided
inputs. Creates a `pb_utils.InferenceRequest` object with passed
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
For each response from the language model:
- Checks for errors and raise an exception if any are found.
- Extracts the "output_ids" tensor from the response.
- Determines the finish reason based on the presence of the
end-of-sequence token or reaching the maximum length.
- Appends the generated token IDs to `output_ids`.
- If the finish reason is determined, decodes the output IDs to text
and prepares the final response.
The final response includes the generated text, finish reason,
completion tokens, prompt tokens, and total tokens.
Parameters
----------
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
Returns
-------
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 1024
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
# Convert inputs to Triton tensors
input_tensor_list = [
pb_utils.Tensor(k, v) for k, v in input_dict.items()
]
# Create and execute inference request
llm_request = pb_utils.InferenceRequest(
model_name="tensorrt_llm",
requested_output_names=["output_ids", "sequence_length"],
inputs=input_tensor_list,
)
llm_responses = llm_request.exec(decoupled=self.decoupled)
if self.decoupled:
for llm_response in llm_responses:
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
else:
llm_response = llm_responses
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
            Prompt speech tokens tensor (moved to CPU)
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
        self.logger.log_info(
            f"token2wav inputs: prompt_tokens {prompt_speech_tokens.shape}, "
            f"prompt_feat {prompt_speech_feat.shape}, spk_embedding {prompt_spk_embedding.shape}, "
            f"target_tokens {target_speech_tokens.shape}")
# Convert tensors to Triton format
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
requested_output_names=['waveform'],
inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def parse_input(self, text, prompt_text, prompt_speech_tokens):
total_text = f"{prompt_text}{text}"
prompt = self.prompt_template.format(input_text=total_text)
input_ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([input_ids], dtype=torch.int32)
        # Append the prompt speech tokens after the text prompt
        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
        self.logger.log_info(
            f"parse_input: input_ids {input_ids.shape}, prompt_speech_tokens {prompt_speech_tokens.shape}")
return input_ids
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device).half()
return embedding
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
responses = []
for request in requests:
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
# Process reference audio through audio tokenizer
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
# TODO: FIX ME
            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            # The reference audio arrives at 16 kHz; the mel features for the
            # flow decoder are computed at 24 kHz, so resample first.
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)
            # Keep a 2:1 ratio between mel frames (50 per second at 24 kHz with
            # hop 480) and speech tokens, truncating both to the shorter length.
            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
            self.logger.log_info(
                f"prompt_speech_feat {prompt_speech_feat.shape}, "
                f"prompt_speech_tokens {prompt_speech_tokens.shape}, token_len {token_len}")
# Extract text inputs
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
# Prepare prompt for LLM
input_ids = self.parse_input(
text=target_text,
prompt_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
)
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
if self.decoupled:
response_sender = request.get_response_sender()
request_id = request.request_id()
for generated_ids in generated_ids_iter:
raise NotImplementedError("Decoupled mode is not implemented")
else:
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
                generated_ids = torch.tensor([generated_ids]).to(self.device)
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
responses.append(inference_response)
if self.decoupled:
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info(f"send tritonserver_response_complete_final to end")
if not self.decoupled:
return responses
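# ---------------------------------------------------------------------------
# Minimal end-to-end client sketch (not part of the model above; it would
# normally live in a separate client script). It calls the "cosyvoice2" BLS
# model over gRPC with a reference clip, its transcript, and the target text,
# assuming the model was deployed with decoupled_mode disabled. The server URL,
# file paths, 16 kHz mono input, and the use of soundfile are assumptions;
# tensor names, dtypes, and shapes follow the config.pbtxt below.
if __name__ == "__main__":
    import soundfile as sf
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient(url="localhost:8001")
    wav, _ = sf.read("prompt.wav", dtype="float32")  # assumed 16 kHz mono

    arrays = [
        ("reference_wav", np.expand_dims(wav, axis=0), "FP32"),
        ("reference_wav_len", np.array([[wav.shape[0]]], dtype=np.int32), "INT32"),
        ("reference_text", np.array([["transcript of the prompt audio"]], dtype=object), "BYTES"),
        ("target_text", np.array([["text to synthesize"]], dtype=object), "BYTES"),
    ]
    inputs = []
    for name, arr, dtype in arrays:
        tensor = grpcclient.InferInput(name, list(arr.shape), dtype)
        tensor.set_data_from_numpy(arr)
        inputs.append(tensor)

    waveform = client.infer("cosyvoice2", inputs).as_numpy("waveform").squeeze()
    sf.write("output.wav", waveform, 24000)  # the token2wav stage produces 24 kHz audio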

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
parameters [
{
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: ${bls_instance_num}
kind: KIND_CPU
}
]
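# Note: the BLS model above does not implement decoupled (streaming) responses
# yet (its decoupled branch raises NotImplementedError), so ${decoupled_mode}
# should normally be filled with False for this deployment.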

View File

@@ -0,0 +1,857 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
name: "tensorrt_llm"
backend: "${triton_backend}"
max_batch_size: ${triton_max_batch_size}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
dynamic_batching {
preferred_batch_size: [ ${triton_max_batch_size} ]
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
default_queue_policy: { max_queue_size: ${max_queue_size} }
}
input [
{
name: "input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_input_features"
data_type: ${encoder_input_features_data_type}
dims: [ -1, -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_output_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "request_output_len"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "num_return_sequences"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "draft_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
reshape: { shape: [ ] }
},
{
name: "draft_logits"
data_type: ${logits_datatype}
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "draft_acceptance_threshold"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "end_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "pad_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "bad_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "embedding_bias"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "beam_width"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_k"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_min"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_decay"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_reset_ids"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "len_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "early_stopping"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "repetition_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "min_length"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "beam_search_diversity_rate"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "presence_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "frequency_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "random_seed"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_log_probs"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_context_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_generation_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_perf_metrics"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "exclude_input_in_output"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "streaming"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "prompt_embedding_table"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_table_extra_ids"
data_type: TYPE_UINT64
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_vocab_size"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
{
name: "cross_attention_mask"
data_type: TYPE_BOOL
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# Mrope param when mrope is used
{
name: "mrope_rotary_cos_sin"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
},
{
name: "mrope_position_deltas"
data_type: TYPE_INT64
dims: [ 1 ]
optional: true
},
# the unique task ID for the given LoRA.
# To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights` and `lora_config` must all be given.
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
{
name: "lora_task_id"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
# each of the in / out tensors is first flattened and then concatenated together in the format above.
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
{
name: "lora_weights"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# module identifier (same size as the first dimension of lora_weights)
# See LoraModule::ModuleType for model id mapping
#
# "attn_qkv": 0 # compbined qkv adapter
# "attn_q": 1 # q adapter
# "attn_k": 2 # k adapter
# "attn_v": 3 # v adapter
# "attn_dense": 4 # adapter for the dense layer in attention
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
#
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
{
name: "lora_config"
data_type: TYPE_INT32
dims: [ -1, 3 ]
optional: true
allow_ragged_batch: true
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
{
name: "skip_cross_attn_blocks"
data_type: TYPE_BOOL
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_starts"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_ends"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_priorities"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_durations_ms"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_decode_priority"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_decode_duration_ms"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "guided_decoding_guide_type"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "guided_decoding_guide"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_window_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_ngram_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_verification_set_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
}
]
output [
{
name: "output_ids"
data_type: TYPE_INT32
dims: [ -1, -1 ]
},
{
name: "sequence_length"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "cum_log_probs"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "output_log_probs"
data_type: TYPE_FP32
dims: [ -1, -1 ]
},
{
name: "context_logits"
data_type: ${logits_datatype}
dims: [ -1, -1 ]
},
{
name: "generation_logits"
data_type: ${logits_datatype}
dims: [ -1, -1, -1 ]
},
{
name: "batch_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "sequence_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
},
{
name: "kv_cache_alloc_new_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_reused_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_alloc_total_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "arrival_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "first_scheduled_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "first_token_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "last_token_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "acceptance_rate"
data_type: TYPE_FP32
dims: [ 1 ]
},
{
name: "total_accepted_draft_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "total_draft_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
}
]
instance_group [
{
count: 1
kind : KIND_CPU
}
]
parameters: {
key: "max_beam_width"
value: {
string_value: "${max_beam_width}"
}
}
parameters: {
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
value: {
string_value: "no"
}
}
parameters: {
key: "gpt_model_type"
value: {
string_value: "${batching_strategy}"
}
}
parameters: {
key: "gpt_model_path"
value: {
string_value: "${engine_dir}"
}
}
parameters: {
key: "encoder_model_path"
value: {
string_value: "${encoder_engine_dir}"
}
}
parameters: {
key: "max_tokens_in_paged_kv_cache"
value: {
string_value: "${max_tokens_in_paged_kv_cache}"
}
}
parameters: {
key: "max_attention_window_size"
value: {
string_value: "${max_attention_window_size}"
}
}
parameters: {
key: "sink_token_length"
value: {
string_value: "${sink_token_length}"
}
}
parameters: {
key: "batch_scheduler_policy"
value: {
string_value: "${batch_scheduler_policy}"
}
}
parameters: {
key: "kv_cache_free_gpu_mem_fraction"
value: {
string_value: "${kv_cache_free_gpu_mem_fraction}"
}
}
parameters: {
key: "cross_kv_cache_fraction"
value: {
string_value: "${cross_kv_cache_fraction}"
}
}
parameters: {
key: "kv_cache_host_memory_bytes"
value: {
string_value: "${kv_cache_host_memory_bytes}"
}
}
# kv_cache_onboard_blocks is for internal implementation.
parameters: {
key: "kv_cache_onboard_blocks"
value: {
string_value: "${kv_cache_onboard_blocks}"
}
}
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
# parameters: {
# key: "enable_trt_overlap"
# value: {
# string_value: "${enable_trt_overlap}"
# }
# }
parameters: {
key: "exclude_input_in_output"
value: {
string_value: "${exclude_input_in_output}"
}
}
parameters: {
key: "cancellation_check_period_ms"
value: {
string_value: "${cancellation_check_period_ms}"
}
}
parameters: {
key: "stats_check_period_ms"
value: {
string_value: "${stats_check_period_ms}"
}
}
parameters: {
key: "iter_stats_max_iterations"
value: {
string_value: "${iter_stats_max_iterations}"
}
}
parameters: {
key: "request_stats_max_iterations"
value: {
string_value: "${request_stats_max_iterations}"
}
}
parameters: {
key: "enable_kv_cache_reuse"
value: {
string_value: "${enable_kv_cache_reuse}"
}
}
parameters: {
key: "normalize_log_probs"
value: {
string_value: "${normalize_log_probs}"
}
}
parameters: {
key: "enable_chunked_context"
value: {
string_value: "${enable_chunked_context}"
}
}
parameters: {
key: "gpu_device_ids"
value: {
string_value: "${gpu_device_ids}"
}
}
parameters: {
key: "participant_ids"
value: {
string_value: "${participant_ids}"
}
}
parameters: {
key: "lora_cache_optimal_adapter_size"
value: {
string_value: "${lora_cache_optimal_adapter_size}"
}
}
parameters: {
key: "lora_cache_max_adapter_size"
value: {
string_value: "${lora_cache_max_adapter_size}"
}
}
parameters: {
key: "lora_cache_gpu_memory_fraction"
value: {
string_value: "${lora_cache_gpu_memory_fraction}"
}
}
parameters: {
key: "lora_cache_host_memory_bytes"
value: {
string_value: "${lora_cache_host_memory_bytes}"
}
}
parameters: {
key: "lora_prefetch_dir"
value: {
string_value: "${lora_prefetch_dir}"
}
}
parameters: {
key: "decoding_mode"
value: {
string_value: "${decoding_mode}"
}
}
parameters: {
key: "executor_worker_path"
value: {
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
}
}
parameters: {
key: "lookahead_window_size"
value: {
string_value: "${lookahead_window_size}"
}
}
parameters: {
key: "lookahead_ngram_size"
value: {
string_value: "${lookahead_ngram_size}"
}
}
parameters: {
key: "lookahead_verification_set_size"
value: {
string_value: "${lookahead_verification_set_size}"
}
}
parameters: {
key: "medusa_choices"
value: {
string_value: "${medusa_choices}"
}
}
parameters: {
key: "eagle_choices"
value: {
string_value: "${eagle_choices}"
}
}
parameters: {
key: "gpu_weights_percent"
value: {
string_value: "${gpu_weights_percent}"
}
}
parameters: {
key: "enable_context_fmha_fp32_acc"
value: {
string_value: "${enable_context_fmha_fp32_acc}"
}
}
parameters: {
key: "multi_block_mode"
value: {
string_value: "${multi_block_mode}"
}
}
parameters: {
key: "cuda_graph_mode"
value: {
string_value: "${cuda_graph_mode}"
}
}
parameters: {
key: "cuda_graph_cache_size"
value: {
string_value: "${cuda_graph_cache_size}"
}
}
parameters: {
key: "speculative_decoding_fast_logits"
value: {
string_value: "${speculative_decoding_fast_logits}"
}
}
parameters: {
key: "tokenizer_dir"
value: {
string_value: "${tokenizer_dir}"
}
}
parameters: {
key: "guided_decoding_backend"
value: {
string_value: "${guided_decoding_backend}"
}
}
parameters: {
key: "xgrammar_tokenizer_info_path"
value: {
string_value: "${xgrammar_tokenizer_info_path}"
}
}
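# Note: this file follows the standard TensorRT-LLM backend config template.
# The cosyvoice2 BLS model only populates a small subset of these inputs
# (input_ids, input_lengths, request_output_len, end_id, pad_id, streaming,
# runtime_top_p, runtime_top_k, temperature) and reads back output_ids and
# sequence_length; the remaining tensors keep their backend defaults.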

View File

@@ -0,0 +1,198 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
#import sys
#sys.path.append("/home/scratch.yuekaiz_wwfo_1/tts/cosyvoice/CosyVoice/third_party/Matcha-TTS")
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.flow.half()
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load(self, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
)
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
            # Undo the LLM vocabulary offset added by the audio tokenizer / BLS
            # so the tokens index into the speech codebook again.
            prompt_speech_tokens = prompt_speech_tokens - 151663
            target_speech_tokens = target_speech_tokens - 151663
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
            # Return the waveform via DLPack; it can stay on the GPU and Triton
            # copies it out as needed.
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
responses.append(inference_response)
return responses
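# ---------------------------------------------------------------------------
# Shape/dtype sketch (not part of the model above; it would normally live in a
# separate client script). It shows what a direct request to the "token2wav"
# model looks like. In the normal pipeline these tensors come from the
# cosyvoice2 BLS model; the random values below are only placeholders that
# illustrate the layout declared in config.pbtxt, and the 192-dim speaker
# embedding size is an assumption based on the CAM++ model used upstream.
if __name__ == "__main__":
    import numpy as np
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient(url="localhost:8001")
    num_prompt, num_target = 100, 200

    arrays = [
        # Speech tokens still carry the +151663 LLM vocabulary offset here.
        ("target_speech_tokens", np.random.randint(151663, 151763, (1, num_target)).astype(np.int32), "INT32"),
        ("prompt_speech_tokens", np.random.randint(151663, 151763, (1, num_prompt)).astype(np.int32), "INT32"),
        ("prompt_speech_feat", np.random.randn(1, 2 * num_prompt, 80).astype(np.float16), "FP16"),
        ("prompt_spk_embedding", np.random.randn(1, 192).astype(np.float16), "FP16"),
    ]
    inputs = []
    for name, arr, dtype in arrays:
        tensor = grpcclient.InferInput(name, list(arr.shape), dtype)
        tensor.set_data_from_numpy(arr)
        inputs.append(tensor)

    waveform = client.infer("token2wav", inputs).as_numpy("waveform")
    print("waveform shape:", waveform.shape)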

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "token2wav"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "target_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "prompt_speech_feat"
data_type: TYPE_FP16
dims: [-1, 80]
},
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]