mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
remove cache router
@@ -109,7 +109,6 @@ class TritonPythonModel:
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]
        self.http_client = httpx.AsyncClient()
        self.runtime_cache = {}

    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
        """Converts a tensor or list of speech token IDs to a string representation."""
@@ -264,38 +263,11 @@ class TritonPythonModel:
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # optional cache inputs
        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
            # inputs_tensor.extend([
            #     pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
            #     pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
            #     pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
            #     pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
            #     pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
            #     pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
            #     pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
            # ])
            inputs_tensor.extend([
                self.runtime_cache[request_id]["conformer_cnn_cache"],
                self.runtime_cache[request_id]["conformer_att_cache"],
                self.runtime_cache[request_id]["estimator_cnn_cache"],
                self.runtime_cache[request_id]["estimator_att_cache"],
                self.runtime_cache[request_id]["mel"],
                self.runtime_cache[request_id]["source"],
                self.runtime_cache[request_id]["speech"],
            ])
        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav_dit',
            requested_output_names=[
                "waveform",
                "conformer_cnn_cache",
                "conformer_att_cache",
                "estimator_cnn_cache",
                "estimator_att_cache",
                "mel",
                "source",
                "speech",
            ],
            inputs=inputs_tensor,
            request_id=request_id,
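For orientation, this hunk uses Triton's BLS (business logic scripting) API to call the token2wav_dit model from inside another Python model. A minimal sketch of that call pattern follows; call_token2wav is an illustrative wrapper name, not a function in the repo, and the code only runs inside Triton's Python backend.

import triton_python_backend_utils as pb_utils  # provided by Triton's Python backend

async def call_token2wav(inputs_tensor, request_id):
    # Build a BLS request against the token2wav_dit model shown above.
    infer_req = pb_utils.InferenceRequest(
        model_name="token2wav_dit",
        requested_output_names=["waveform"],
        inputs=inputs_tensor,
        request_id=request_id,
    )
    # async_exec() awaits the downstream model without blocking the event loop.
    infer_resp = await infer_req.async_exec()
    if infer_resp.has_error():
        raise pb_utils.TritonModelException(infer_resp.error().message())
    return pb_utils.get_output_tensor_by_name(infer_resp, "waveform")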
@@ -306,14 +278,6 @@ class TritonPythonModel:
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
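The waveform conversion above moves tensors between pb_utils.Tensor and torch.Tensor through DLPack, which avoids a host copy. A minimal sketch of the round trip, with roundtrip as an illustrative name:

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils

def roundtrip(pb_tensor):
    # pb_utils.Tensor -> torch.Tensor without copying; the data stays on
    # whatever device Triton placed it on.
    torch_view = from_dlpack(pb_tensor.to_dlpack())
    # torch.Tensor -> pb_utils.Tensor, again zero-copy via DLPack.
    return pb_utils.Tensor.from_dlpack("waveform", to_dlpack(torch_view))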
@@ -339,16 +303,6 @@ class TritonPythonModel:

    async def _process_request(self, request):
        request_id = request.request_id()
        if request_id not in self.runtime_cache:
            self.runtime_cache[request_id] = {
                "conformer_cnn_cache": None,
                "conformer_att_cache": None,
                "estimator_cnn_cache": None,
                "estimator_att_cache": None,
                "mel": None,
                "source": None,
                "speech": None,
            }
        # Extract input tensors
        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
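The lifecycle of this per-request cache is: create empty slots on the first chunk of a stream, drop them when the final chunk is sent (see the deletion hunk further below). A minimal sketch of that pattern under illustrative names (RequestCache is not a class in the repo):

CACHE_KEYS = ("conformer_cnn_cache", "conformer_att_cache",
              "estimator_cnn_cache", "estimator_att_cache",
              "mel", "source", "speech")

class RequestCache(dict):
    def get_or_create(self, request_id):
        # First chunk of a stream: start with empty (None) cache slots.
        if request_id not in self:
            self[request_id] = {k: None for k in CACHE_KEYS}
        return self[request_id]

    def drop(self, request_id):
        # Final chunk: release the state so ids from finished streams
        # do not accumulate.
        self.pop(request_id, None)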
@@ -369,7 +323,7 @@ class TritonPythonModel:

        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
        reference_text = reference_text[0][0].decode('utf-8')
        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
        # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)

        # reference_text = self.default_spk_info["prompt_text"]
        # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
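The double indexing plus decode above is the standard unwrap for Triton string inputs: TYPE_STRING/BYTES tensors surface as numpy object arrays of Python bytes. A minimal sketch, with get_text_input as an illustrative helper name:

import triton_python_backend_utils as pb_utils

def get_text_input(request, name="reference_text"):
    # A [1, 1]-shaped BYTES input arrives as a numpy array of bytes objects,
    # so two indexes and a decode recover the Python string.
    arr = pb_utils.get_input_tensor_by_name(request, name).as_numpy()
    return arr[0][0].decode("utf-8")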
@@ -453,9 +407,7 @@ class TritonPythonModel:
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            response_sender.send(inference_response)
            if request_id in self.runtime_cache:
                del self.runtime_cache[request_id]
                self.logger.log_info(f"Deleted cache for request_id: {request_id}")

            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            self.logger.log_info("send tritonserver_response_complete_final to end")
        else:
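This hunk follows Triton's decoupled-mode contract: each audio chunk goes out through the request's response sender, and an empty send carrying only the FINAL flag closes the stream. A minimal sketch of that contract (stream_chunks is an illustrative name):

import triton_python_backend_utils as pb_utils

def stream_chunks(request, chunks):
    # One response per chunk, then the FINAL flag with no payload.
    sender = request.get_response_sender()
    for audio_tensor in chunks:
        sender.send(pb_utils.InferenceResponse(output_tensors=[audio_tensor]))
    sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)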
@@ -31,7 +31,7 @@ parameters [
    value: {string_value:"${model_dir}"}
  }
]
parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}

input [
  {
    name: "reference_wav"
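Aside: FORCE_CPU_ONLY_INPUT_TENSORS is consumed by the Python backend itself ("no" lets input tensors stay on GPU), while custom parameters such as model_dir are surfaced to the model as JSON in args["model_config"]. A minimal sketch of reading one, assuming the parameter names above:

import json

def read_model_dir(args):
    # In initialize(self, args), args["model_config"] is the config.pbtxt
    # serialized to JSON; each parameter value sits under ["string_value"].
    model_config = json.loads(args["model_config"])
    return model_config["parameters"]["model_dir"]["string_value"]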
@@ -103,91 +103,47 @@ class TritonPythonModel:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            request_id = request.request_id()

            # Get inputs
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
            target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)  # .to(self.device)
            # shift the speech tokens according to the original vocab size
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
            target_speech_tokens = target_speech_tokens.squeeze().tolist()

            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.

            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
            wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
            wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)

            request_id = request.request_id()

            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_len = pb_utils.get_input_tensor_by_name(
                request, "reference_wav_len").as_numpy().item()

            wav_array = torch.from_numpy(wav_array)
            # Prepare inputs
            wav = wav_array[:, :wav_len].squeeze(0)

            spk_id = get_spk_id_from_prompt_audio(wav)
            # wav = wav.to(self.device)

            # Handle cache
            conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
            if conformer_cnn_cache is not None:
                self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())

                conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
                self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1)

                estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
                self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
            # update cache before forward
            # self.token2wav_model.streaming_flow_cache[request_id]
            # self.token2wav_model.hift_cache_dict[request_id]

                estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
                self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)

                mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
                self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())

                source_np = pb_utils.get_input_tensor_by_name(request, "source")
                self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())

                speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
                self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())

            # Forward pass
            audio_hat = self.token2wav_model.forward_streaming(
                target_speech_tokens,
                finalize,
                request_id=request_id,
                speaker_id=f"{spk_id}",
                prompt_audio=wav,
                prompt_audio_sample_rate=16000
            )

            # Prepare outputs
            # get the cache after forward
            outputs = []

            generated_wave = audio_hat.squeeze(0).cpu().numpy()

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            outputs.append(wav_tensor)

            if request_id in self.token2wav_model.streaming_flow_cache:
                cache = self.token2wav_model.streaming_flow_cache[request_id]
                hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
                conformer_cnn_cache = cache['conformer_cnn_cache']
                conformer_att_cache = cache['conformer_att_cache'].transpose(0,1)
                estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
                estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
                mel = hifigan_cache['mel']
                source = hifigan_cache['source']
                speech = hifigan_cache['speech']

                outputs.extend([
                    pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
                    pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
                    pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
                    pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
                    pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
                    pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
                    pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
                ])
            else:
                outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
                                pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
                                pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
                                pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
                                pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
                                pb_utils.Tensor("source", np.array([], dtype=np.float32)),
                                pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
                                ])

            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
            responses.append(inference_response)
        return responses

    def finalize(self):
        self.logger.log_info("Finalizing Token2WavDiT model")
        return responses
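Note the marshalling symmetry in the hunk above: the attention cache is sent as transpose(0,1) on output and transposed back on input, and the estimator caches are unsqueeze(0)/squeeze(0) pairs. A minimal round-trip sketch with dummy sizes (the concrete values are not from the repo):

import torch

# Internal layout of an attention cache; 8/10/4/128 are arbitrary dummies.
att = torch.zeros(8, 10, 4, 128, dtype=torch.float16)

wire = att.transpose(0, 1)        # layout sent through Triton, e.g. [10, 8, 4, 128]
restored = wire.transpose(0, 1)   # receiving side inverts the transpose
assert restored.shape == att.shape
assert torch.equal(restored, att)  # transpose is a view, so this is lossless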
@@ -22,7 +22,6 @@ dynamic_batching {
  default_priority_level: 10
}

parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
parameters [
  {
    key: "model_dir",
@@ -52,48 +51,6 @@ input [
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "conformer_cnn_cache"
    data_type: TYPE_FP16
    dims: [ 512, -1 ]
    optional: true
  },
  {
    name: "conformer_att_cache"
    data_type: TYPE_FP16
    dims: [ 10, 8, -1, 128 ]
    optional: true
  },
  {
    name: "estimator_cnn_cache"
    data_type: TYPE_FP16
    dims: [ 10, 16, -1, 1024, 2 ]
    optional: true
  },
  {
    name: "estimator_att_cache"
    data_type: TYPE_FP16
    dims: [ 10, 16, -1, 8, -1, 128 ]
    optional: true
  },
  {
    name: "mel"
    data_type: TYPE_FP32
    dims: [ 80, -1 ]
    optional: true
  },
  {
    name: "source"
    data_type: TYPE_FP32
    dims: [ 1, -1 ]
    optional: true
  },
  {
    name: "speech"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
  }
]
output [
@@ -101,41 +58,6 @@ output [
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "conformer_cnn_cache"
    data_type: TYPE_FP16
    dims: [ 512, -1 ]
  },
  {
    name: "conformer_att_cache"
    data_type: TYPE_FP16
    dims: [ 10, 8, -1, 128 ]
  },
  {
    name: "estimator_cnn_cache"
    data_type: TYPE_FP16
    dims: [ 10, 16, -1, 1024, 2 ]
  },
  {
    name: "estimator_att_cache"
    data_type: TYPE_FP16
    dims: [ 10, 16, -1, 8, -1, 128 ]
  },
  {
    name: "mel"
    data_type: TYPE_FP32
    dims: [ 80, -1 ]
  },
  {
    name: "source"
    data_type: TYPE_FP32
    dims: [ 1, -1 ]
  },
  {
    name: "speech"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]
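For reference, the dtype/dims pairs in this config pin down how a caller would have allocated the optional cache inputs before this commit removed them. A minimal sketch with arbitrary lengths for the -1 (variable) axes; none of these values come from the repo:

import numpy as np

conformer_cnn_cache = np.zeros((512, 16), dtype=np.float16)               # [512, -1]
conformer_att_cache = np.zeros((10, 8, 4, 128), dtype=np.float16)         # [10, 8, -1, 128]
estimator_cnn_cache = np.zeros((10, 16, 4, 1024, 2), dtype=np.float16)    # [10, 16, -1, 1024, 2]
estimator_att_cache = np.zeros((10, 16, 4, 8, 4, 128), dtype=np.float16)  # [10, 16, -1, 8, -1, 128]
mel = np.zeros((80, 16), dtype=np.float32)                                # [80, -1]
source = np.zeros((1, 256), dtype=np.float32)                             # [1, -1]
speech = np.zeros((256,), dtype=np.float32)                               # [-1]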