mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
fix decoupled mode
@@ -295,11 +295,26 @@ class TritonPythonModel:
         if self.decoupled:
             response_sender = request.get_response_sender()
             request_id = request.request_id()
-            for generated_ids in generated_ids_iter:
-                raise NotImplementedError("Decoupled mode is not implemented")
+            generated_ids = []
+            for generated_id in generated_ids_iter:
+                # convert the numpy array into a int32 tensor
+                generated_id = generated_id.tolist()
+                if len(generated_id) > 0:
+                    assert len(generated_id) == 1, "Generated ID is not a single integer"
+                    generated_ids.append(generated_id[0])
+            generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
+            prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
+            audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+
+            # Prepare response
+            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+            response_sender.send(inference_response)
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+            self.logger.log_info(f"send tritonserver_response_complete_final to end")
         else:
             generated_ids = next(generated_ids_iter)
-            generated_ids = torch.tensor([generated_ids]).to(self.device)
+            generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
             if generated_ids is None or len(generated_ids) == 0:
                 raise pb_utils.TritonModelException("Generated IDs is None or empty")

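The new accumulation loop in the first hunk can be exercised on its own. A minimal runnable sketch of that logic, assuming a dummy iterator in place of generated_ids_iter and "cpu" in place of self.device:

import numpy as np
import torch

# Stand-ins for the model's state; the real code uses self.device and the
# iterator produced by the LLM generation step.
device = "cpu"
generated_ids_iter = iter([np.array([101]), np.array([202]), np.array([], dtype=np.int64), np.array([303])])

generated_ids = []
for generated_id in generated_ids_iter:
    # Each step yields a numpy array holding at most one token id.
    generated_id = generated_id.tolist()
    if len(generated_id) > 0:
        assert len(generated_id) == 1, "Generated ID is not a single integer"
        generated_ids.append(generated_id[0])

# Batch the collected ids into a (1, T) int32 tensor, as the patch does.
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(device)
print(generated_ids)  # tensor([[101, 202, 303]], dtype=torch.int32)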
@@ -311,9 +326,5 @@ class TritonPythonModel:
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             responses.append(inference_response)
 
-        if self.decoupled:
-            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-            self.logger.log_info(f"send tritonserver_response_complete_final to end")
-
         if not self.decoupled:
             return responses
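The second hunk removes the duplicate final-flag send: after this commit only the decoupled branch itself calls response_sender.send(flags=...), and execute() returns a response list only in non-decoupled mode, which matches the Triton Python backend contract. A minimal sketch of that contract, assuming a hypothetical generate_waveform helper in place of the model-specific token2wav pipeline:

import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import to_dlpack


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            audio = self.generate_waveform(request)  # hypothetical helper
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            if self.decoupled:
                # Decoupled mode: push the response through the request's own
                # sender, then mark the stream complete exactly once.
                response_sender = request.get_response_sender()
                response_sender.send(inference_response)
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            else:
                responses.append(inference_response)
        # Only non-decoupled models return responses from execute();
        # decoupled models return None.
        if not self.decoupled:
            return responses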