remove cache router

root
2025-09-26 15:14:31 +08:00
parent 31a0adc73d
commit 79116ac32e
7 changed files with 219 additions and 243 deletions

View File

@@ -109,7 +109,6 @@ class TritonPythonModel:
         spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
         self.default_spk_info = spk_info["001"]
         self.http_client = httpx.AsyncClient()
-        self.runtime_cache = {}
 
     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -264,38 +263,11 @@ class TritonPythonModel:
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
-        # optional cache inputs
-        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
-            # inputs_tensor.extend([
-            #     pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
-            #     pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
-            #     pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
-            #     pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
-            #     pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
-            #     pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
-            #     pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
-            # ])
-            inputs_tensor.extend([
-                self.runtime_cache[request_id]["conformer_cnn_cache"],
-                self.runtime_cache[request_id]["conformer_att_cache"],
-                self.runtime_cache[request_id]["estimator_cnn_cache"],
-                self.runtime_cache[request_id]["estimator_att_cache"],
-                self.runtime_cache[request_id]["mel"],
-                self.runtime_cache[request_id]["source"],
-                self.runtime_cache[request_id]["speech"],
-            ])
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
             requested_output_names=[
                 "waveform",
-                "conformer_cnn_cache",
-                "conformer_att_cache",
-                "estimator_cnn_cache",
-                "estimator_att_cache",
-                "mel",
-                "source",
-                "speech",
             ],
             inputs=inputs_tensor,
             request_id=request_id,
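
The hunk above is Triton's BLS (Business Logic Scripting) idiom: assemble pb_utils.Tensor inputs, build a pb_utils.InferenceRequest against the composed model, and await it from the Python backend's event loop. A hedged sketch of the post-commit call shape (model and tensor names are taken from the diff; the standalone function is illustrative):

import numpy as np
import triton_python_backend_utils as pb_utils  # only importable inside Triton

async def run_token2wav(target_speech_tokens_tensor, reference_wav,
                        reference_wav_len, request_id, finalize):
    # Scalar flag packed as a [1, 1] boolean tensor, matching the diff.
    finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
    inference_request = pb_utils.InferenceRequest(
        model_name="token2wav_dit",
        # Cache outputs were dropped by this commit; only audio remains.
        requested_output_names=["waveform"],
        inputs=[target_speech_tokens_tensor, reference_wav,
                reference_wav_len, finalize_tensor],
        request_id=request_id,
    )
    # async_exec() resolves to a pb_utils.InferenceResponse without
    # blocking other requests served by the same model instance.
    return await inference_request.async_exec()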
@@ -306,14 +278,6 @@ class TritonPythonModel:
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
-        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
-        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
-        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
-        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
-        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
-        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
-        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")
         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
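
Output extraction goes through DLPack, so the waveform can cross from the Triton tensor into PyTorch without a host copy until .cpu() is called. A sketch of the idiom as it appears above (the helper function is illustrative):

import torch
import torch.utils.dlpack
import triton_python_backend_utils as pb_utils  # only importable inside Triton

def extract_waveform(inference_response) -> torch.Tensor:
    if inference_response.has_error():
        raise pb_utils.TritonModelException(inference_response.error().message())
    waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
    # pb_utils.Tensor -> DLPack capsule -> torch.Tensor is zero-copy;
    # .cpu() then materializes the data on the host if it lived on GPU.
    return torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()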
@@ -339,16 +303,6 @@ class TritonPythonModel:
     async def _process_request(self, request):
         request_id = request.request_id()
-        if request_id not in self.runtime_cache:
-            self.runtime_cache[request_id] = {
-                "conformer_cnn_cache": None,
-                "conformer_att_cache": None,
-                "estimator_cnn_cache": None,
-                "estimator_att_cache": None,
-                "mel": None,
-                "source": None,
-                "speech": None,
-            }
 
         # Extract input tensors
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
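
_process_request pulls its inputs off the request by name; STRING inputs arrive as numpy object arrays of bytes and need an explicit decode, as the next hunk shows. A sketch of that input path (the reference_wav_len name is assumed from its use elsewhere in the diff; the helper is illustrative):

import triton_python_backend_utils as pb_utils  # only importable inside Triton

def read_inputs(request):
    request_id = request.request_id()
    # Named inputs arrive as pb_utils.Tensor objects.
    wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
    wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
    # TYPE_STRING tensors surface as object arrays of bytes; decode to str.
    text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
    reference_text = text[0][0].decode("utf-8")
    return request_id, wav, wav_len, reference_text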
@@ -369,7 +323,7 @@ class TritonPythonModel:
         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')
-        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+        # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
         # reference_text = self.default_spk_info["prompt_text"]
         # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -453,9 +407,7 @@ class TritonPythonModel:
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
-            if request_id in self.runtime_cache:
-                del self.runtime_cache[request_id]
-                self.logger.log_info(f"Deleted cache for request_id: {request_id}")
             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
             self.logger.log_info("send tritonserver_response_complete_final to end")
         else:
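
The send path above is Triton's decoupled-mode idiom: each synthesized chunk is wrapped via DLPack and pushed through the request's response sender, and an empty send carrying the FINAL flag closes the stream. A sketch, assuming sub_tts_speech is a torch tensor (the helper is illustrative):

import triton_python_backend_utils as pb_utils  # only importable inside Triton
from torch.utils.dlpack import to_dlpack

def stream_chunk(response_sender, sub_tts_speech, last=False):
    # Wrap the torch tensor as a Triton tensor without copying (DLPack).
    audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
    response_sender.send(pb_utils.InferenceResponse(output_tensors=[audio_tensor]))
    if last:
        # The FINAL flag on an otherwise-empty send ends the response stream;
        # with the cache gone, no per-request teardown is needed here.
        response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)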

View File

@@ -31,7 +31,7 @@ parameters [
   value: {string_value:"${model_dir}"}
   }
 ]
+parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
 input [
   {
     name: "reference_wav"