fix decoupled mode

This commit is contained in:
yuekaiz
2025-07-29 11:13:07 +08:00
parent 178da09993
commit dc196df940
3 changed files with 52 additions and 33 deletions

View File

@@ -295,11 +295,26 @@ class TritonPythonModel:
if self.decoupled:
response_sender = request.get_response_sender()
request_id = request.request_id()
for generated_ids in generated_ids_iter:
raise NotImplementedError("Decoupled mode is not implemented")
generated_ids = []
for generated_id in generated_ids_iter:
# convert the numpy array into a int32 tensor
generated_id = generated_id.tolist()
if len(generated_id) > 0:
assert len(generated_id) == 1, "Generated ID is not a single integer"
generated_ids.append(generated_id[0])
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info(f"send tritonserver_response_complete_final to end")
else:
generated_ids = next(generated_ids_iter)
generated_ids = torch.tensor([generated_ids]).to(self.device)
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
@@ -311,9 +326,5 @@ class TritonPythonModel:
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
responses.append(inference_response)
if self.decoupled:
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info(f"send tritonserver_response_complete_final to end")
if not self.decoupled:
return responses