Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
support streaming tts
@@ -114,6 +114,7 @@ class TritonPythonModel:
             "runtime_top_p": np.array([[0.95]], dtype=np.float32),
             "runtime_top_k": np.array([[50]], dtype=np.int32),
             "temperature": np.array([[0.8]], dtype=np.float32),
+            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
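Each sampling knob above is a batch-major [1, 1] array, i.e. one scalar per request in the batch. A minimal sketch of turning such a dict into BLS input tensors; the helper name is hypothetical, not part of this commit:

```python
import triton_python_backend_utils as pb_utils  # only importable inside the Python backend

def dict_to_bls_inputs(params: dict) -> list:
    # hypothetical helper: wrap each named numpy array (sampling scalars,
    # input_ids, input_lengths) as a host-side tensor for an InferenceRequest
    return [pb_utils.Tensor(name, value) for name, value in params.items()]
```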
@@ -144,6 +145,7 @@ class TritonPythonModel:
 
                 # Get actual output IDs up to the sequence length
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+                print(f"actual_output_ids: {actual_output_ids}")
 
                 yield actual_output_ids
             else:
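The yield above implies forward_llm is a generator draining a decoupled BLS stream. A minimal sketch of that pattern, assuming the LLM is served by a model named tensorrt_llm and the request inputs are already built elsewhere:

```python
import triton_python_backend_utils as pb_utils  # only importable inside the Python backend

def stream_llm_tokens(llm_inputs):
    # assumed model and tensor names; exec(decoupled=True) yields responses
    # as the backend produces them, a token (or a few) at a time
    llm_request = pb_utils.InferenceRequest(
        model_name='tensorrt_llm',
        requested_output_names=['output_ids', 'sequence_length'],
        inputs=llm_inputs,
    )
    for llm_response in llm_request.exec(decoupled=True):
        if llm_response.has_error():
            raise pb_utils.TritonModelException(llm_response.error().message())
        output_ids = pb_utils.get_output_tensor_by_name(llm_response, 'output_ids').as_numpy()
        seq_lens = pb_utils.get_output_tensor_by_name(llm_response, 'sequence_length').as_numpy()
        yield output_ids[0][0][:seq_lens[0][0]]
```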
@@ -193,7 +195,10 @@ class TritonPythonModel:
                           prompt_speech_tokens: torch.Tensor,
                           prompt_speech_feat: torch.Tensor,
                           prompt_spk_embedding: torch.Tensor,
-                          target_speech_tokens: torch.Tensor) -> torch.Tensor:
+                          target_speech_tokens: torch.Tensor,
+                          request_id: str,
+                          token_offset: int = None,
+                          finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
 
         Args:
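A hedged reading of the extended signature: request_id ties the vocoder call back to the originating client request, token_offset tells token2wav how many semantic tokens it has already rendered, and finalize marks the flush of the remaining tail. Roughly, the two call patterns look like this (all_tokens and prefix_tokens are illustrative names, not from the commit):

```python
# offline path: token_offset/finalize stay None, one call renders everything
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat,
                               prompt_spk_embedding, all_tokens, request_id)

# streaming path: repeated calls over a growing token prefix,
# finalize=True only on the last flush
chunk = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat,
                               prompt_spk_embedding, prefix_tokens, request_id,
                               token_offset, False)
```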
@@ -210,11 +215,22 @@ class TritonPythonModel:
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
 
+        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+
+        if token_offset is not None:
+            assert finalize is not None
+            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
+            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+            inputs_tensor.append(token_offset_tensor)
+            inputs_tensor.append(finalize_tensor)
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
-            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+            inputs=inputs_tensor,
+            request_id=request_id,
         )
 
         inference_response = inference_request.exec()
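Note the two tensor constructors in this hunk: the bulk speech tensors are handed over zero-copy from GPU memory via DLPack, while the new scalar controls travel as plain host-side numpy arrays. A minimal illustration of the distinction, assuming a CUDA device is available:

```python
import numpy as np
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils  # only importable inside the Python backend

# scalar control input: a small host-side numpy tensor
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[0]], dtype=np.int32))

# bulk activation: zero-copy DLPack handoff of a GPU tensor
feat = torch.randn(1, 100, 80, device="cuda").half()
feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(feat))
```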
@@ -275,6 +291,7 @@ class TritonPythonModel:
         responses = []
 
         for request in requests:
+            request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
@@ -292,6 +309,11 @@ class TritonPythonModel:
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
 
+            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
+            token_hop_len = 25
+            flow_pre_lookahead_len = 3
+
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')
 
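These three constants drive the chunk scheduling in the decoupled loop below: the prompt length is padded up to a multiple of token_hop_len, so the first chunk waits for token_hop_len + prompt_token_pad + flow_pre_lookahead_len tokens, and every later chunk fires after each additional token_hop_len tokens. A worked example using the diff's own formula, with an assumed 37-token prompt:

```python
import numpy as np

# illustrative prompt length; hop and lookahead match the constants above
flow_prompt_speech_token_len = 37
token_hop_len, flow_pre_lookahead_len = 25, 3

prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len)
                       * token_hop_len - flow_prompt_speech_token_len)  # 50 - 37 = 13
first_hop = token_hop_len + prompt_token_pad                            # 25 + 13 = 38
print(first_hop + flow_pre_lookahead_len)  # 41 tokens buffered before the first audio chunk
```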
@@ -308,24 +330,46 @@ class TritonPythonModel:
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
 
+            prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
+            print(f"here2")
             if self.decoupled:
                 response_sender = request.get_response_sender()
-                request_id = request.request_id()
-                generated_ids = []
-                for generated_id in generated_ids_iter:
-                    # convert the numpy array into a int32 tensor
-                    generated_id = generated_id.tolist()
-                    if len(generated_id) > 0:
-                        assert len(generated_id) == 1, "Generated ID is not a single integer"
-                        generated_ids.append(generated_id[0])
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
-
-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                semantic_token_ids_arr, token_offset = [], 0
+                for generated_ids in generated_ids_iter:
+                    generated_ids = generated_ids.tolist()
+                    print(f"generated_id: {generated_ids}")
+                    semantic_token_ids_arr.extend(generated_ids)
+
+                    prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len - flow_prompt_speech_token_len)
+                    this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
+                    print(f"this_token_hop_len: {this_token_hop_len}")
+                    if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len]
+                        print(f"this_tts_speech_token: {this_tts_speech_token}")
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                        print(f"here3")
+
+                        sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
+                        print(f"here4")
+                        # Prepare response to send
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        self.logger.log_info(f"[{request_id}]")
+                        token_offset += this_token_hop_len
+                        print(f"here")
+
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                response_sender.send(inference_response)
 
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
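On the wire, each response_sender.send above becomes one message on a gRPC stream, and the COMPLETE_FINAL flag closes it. A hypothetical client for this decoupled endpoint; only reference_wav, reference_wav_len, reference_text, and waveform appear in the diff, so the model name, target-text tensor name, and dtypes here are assumptions:

```python
import queue

import numpy as np
import tritonclient.grpc as grpcclient

audio_chunks = queue.Queue()

def on_response(result, error):
    if error is not None:
        audio_chunks.put(error)
    else:
        # the final-flag-only message carries no tensor, so this yields None
        audio_chunks.put(result.as_numpy("waveform"))

def make_input(name, arr, dtype):
    t = grpcclient.InferInput(name, list(arr.shape), dtype)
    t.set_data_from_numpy(arr)
    return t

wav = np.zeros((1, 16000), dtype=np.float32)  # placeholder reference audio
inputs = [
    make_input("reference_wav", wav, "FP32"),
    make_input("reference_wav_len", np.array([[wav.shape[1]]], dtype=np.int32), "INT32"),
    make_input("reference_text", np.array([[b"reference transcript"]], dtype=np.object_), "BYTES"),
    make_input("target_text", np.array([[b"text to speak"]], dtype=np.object_), "BYTES"),  # name assumed
]

client = grpcclient.InferenceServerClient("localhost:8001")
client.start_stream(callback=on_response)
client.async_stream_infer("cosyvoice2", inputs, request_id="1")  # model name assumed

received = []
while True:
    chunk = audio_chunks.get()
    if isinstance(chunk, Exception):
        raise chunk
    if chunk is None:  # final flag: stream complete
        break
    received.append(chunk)
client.stop_stream()
```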
@@ -334,8 +378,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")
 
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
 
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))