This commit is contained in:
root
2025-07-29 08:39:41 +00:00
parent 1b8d194b67
commit 07cbc51cd1
8 changed files with 165 additions and 157 deletions

View File

@@ -35,33 +35,34 @@ import s3tokenizer
ORIGINAL_VOCAB_SIZE = 151663
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
@@ -79,18 +80,18 @@ class TritonPythonModel:
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
mels.append(s3tokenizer.log_mel_spectrogram(wav))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
codes = codes.clone() + ORIGINAL_VOCAB_SIZE
responses = []
for i in range(len(requests)):
prompt_speech_tokens = codes[i, :codes_lens[i].item()]
prompt_speech_tokens = codes[i, :codes_lens[i].item()]
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
"prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_speech_tokens_tensor])
responses.append(inference_response)
return responses
return responses