This commit is contained in:
Yuekai Zhang
2025-09-03 03:45:17 -07:00
parent 633b991290
commit ad257b06e3
4 changed files with 14 additions and 20 deletions

View File

@@ -57,13 +57,13 @@ class TritonPythonModel:
self.device = torch.device("cuda")
model_dir = model_params["model_dir"]
gpu="l20"
gpu = "l20"
enable_trt = True
if enable_trt:
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
f'{model_dir}/campplus.onnx',
1,
False)
else:
campplus_model = f'{model_dir}/campplus.onnx'
option = onnxruntime.SessionOptions()
@@ -121,7 +121,7 @@ class TritonPythonModel:
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return embedding.half()
def execute(self, requests):
@@ -142,7 +142,6 @@ class TritonPythonModel:
wav_array = torch.from_numpy(wav_array).to(self.device)
embedding = self._extract_spk_embedding(wav_array)
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
"prompt_spk_embedding", to_dlpack(embedding))