clean code

Author: root
Date: 2025-10-08 16:48:00 +08:00
Commit: a019a2504e (parent: f186ec3338)
5 changed files with 46 additions and 193 deletions

@@ -106,13 +106,10 @@ class TritonPythonModel:
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-           target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
            # shift the speech tokens according to the original vocab size
+           target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
            target_speech_tokens = target_speech_tokens.squeeze().tolist()
            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
            request_id = request.request_id()
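
The subtraction above undoes the offset the TTS LLM applies to speech tokens: the codec ids presumably sit above the text vocabulary, so subtracting ORIGINAL_VOCAB_SIZE maps the sampled ids back into the codec's own range starting at 0. A minimal sketch of the same step in isolation (the constant's value here is a made-up placeholder, not the model's real vocab size):

    import torch

    ORIGINAL_VOCAB_SIZE = 150_000  # placeholder; the real model.py derives this from the LLM's text vocab

    def shift_speech_tokens(batched_ids: torch.Tensor) -> list[int]:
        # The LLM samples speech tokens in [ORIGINAL_VOCAB_SIZE, ORIGINAL_VOCAB_SIZE + codec_vocab);
        # subtracting the offset recovers codec token ids starting at 0.
        return (batched_ids - ORIGINAL_VOCAB_SIZE).squeeze().tolist()
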
@@ -124,23 +121,14 @@ class TritonPythonModel:
                request, "reference_wav_len").as_numpy().item()
            wav_array = torch.from_numpy(wav_array)
            # Prepare inputs
            wav = wav_array[:, :wav_len].squeeze(0)
            spk_id = get_spk_id_from_prompt_audio(wav)
-           # wav = wav.to(self.device)
-           # update cache before forward
-           # self.token2wav_model.streaming_flow_cache[request_id]
-           # self.token2wav_model.hift_cache_dict[request_id]
            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
-           # get the cache after forward
            outputs = []
-           generated_wave = audio_hat.squeeze(0).cpu().numpy()
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            outputs.append(wav_tensor)
            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
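
forward_streaming keys per-stream state on request_id: the deleted comments reference self.token2wav_model.streaming_flow_cache and hift_cache_dict, which carry flow-matching and HiFT vocoder state across chunks of one stream until finalize arrives. A rough sketch of that contract; the stub class and its internals are assumptions, only the attribute names and the call signature come from the diff:

    import torch

    class Token2WavStub:
        # Hypothetical stand-in illustrating the per-request cache contract.
        def __init__(self):
            self.streaming_flow_cache = {}  # request_id -> flow-matching carry-over
            self.hift_cache_dict = {}       # request_id -> vocoder overlap state

        def forward_streaming(self, tokens, finalize, request_id, speaker_id,
                              prompt_audio, prompt_audio_sample_rate):
            # Look up (or create) this stream's state so consecutive chunks
            # continue seamlessly instead of restarting from silence.
            self.streaming_flow_cache.setdefault(request_id, None)
            self.hift_cache_dict.setdefault(request_id, None)
            audio = torch.zeros(1, 4800)  # placeholder for the synthesized chunk
            if finalize:
                # Final chunk of the stream: release the per-request caches.
                self.streaming_flow_cache.pop(request_id, None)
                self.hift_cache_dict.pop(request_id, None)
            return audio
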
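The response path hands the waveform to Triton through the DLPack protocol, sharing the tensor's buffer instead of materialising a NumPy copy on the host; that is presumably why the unused generated_wave = audio_hat.squeeze(0).cpu().numpy() line could be deleted. A minimal sketch of this last step, runnable only inside the Triton Python backend where pb_utils is importable:

    import torch
    from torch.utils.dlpack import to_dlpack
    import triton_python_backend_utils as pb_utils

    def build_response(audio_hat: torch.Tensor):
        # DLPack shares audio_hat's buffer with Triton directly,
        # avoiding a GPU->CPU round trip through NumPy.
        wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
        return pb_utils.InferenceResponse(output_tensors=[wav_tensor])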