From 73d261dd480cc29f1a580ab5b39e1467fc15943f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Sep 2025 18:32:12 +0800 Subject: [PATCH] support streaming tts --- runtime/triton_trtllm/client_grpc.py | 62 +++++----- .../model_repo/cosyvoice2/1/model.py | 77 ++++++++++--- .../model_repo/token2wav/1/model.py | 108 +++++++++++++++--- .../model_repo/token2wav/config.pbtxt | 14 +++ 4 files changed, 199 insertions(+), 62 deletions(-) diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 881b519..4f1e1c3 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -395,38 +395,45 @@ def run_sync_streaming_inference( # Reconstruct audio using cross-fade (from client_grpc_streaming.py) actual_duration = 0 if audios: - cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) - fade_out = np.linspace(1, 0, cross_fade_samples) - fade_in = np.linspace(0, 1, cross_fade_samples) - reconstructed_audio = None + # Only spark_tts model uses cross-fade + if model_name == "spark_tts": + cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + reconstructed_audio = None - # Simplified reconstruction based on client_grpc_streaming.py - if not audios: - print("Warning: No audio chunks received.") - reconstructed_audio = np.array([], dtype=np.float32) # Empty array - elif len(audios) == 1: - reconstructed_audio = audios[0] + # Simplified reconstruction based on client_grpc_streaming.py + if not audios: + print("Warning: No audio chunks received.") + reconstructed_audio = np.array([], dtype=np.float32) # Empty array + elif len(audios) == 1: + reconstructed_audio = audios[0] + else: + reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + for i in range(1, len(audios)): + # Cross-fade section + cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + audios[i - 1][-cross_fade_samples:] * fade_out) + # Middle section of the current chunk + middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # Concatenate + reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # Add the last part of the final chunk + reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) + + if reconstructed_audio is not None and reconstructed_audio.size > 0: + actual_duration = len(reconstructed_audio) / save_sample_rate + # Save reconstructed audio + sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") + else: + print("Warning: No audio chunks received or reconstructed.") + actual_duration = 0 # Set duration to 0 if no audio else: - reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap - for i in range(1, len(audios)): - # Cross-fade section - cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + - audios[i - 1][-cross_fade_samples:] * fade_out) - # Middle section of the current chunk - middle_part = audios[i][cross_fade_samples:-cross_fade_samples] - # Concatenate - reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) - # Add the last part of the final chunk - reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) - - if reconstructed_audio is not None and reconstructed_audio.size > 0: + reconstructed_audio = np.concatenate(audios) + print(f"reconstructed_audio: {reconstructed_audio.shape}") actual_duration = len(reconstructed_audio) / save_sample_rate # Save reconstructed audio - os.makedirs(os.path.dirname(audio_save_path), exist_ok=True) sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") - else: - print("Warning: No audio chunks received or reconstructed.") - actual_duration = 0 # Set duration to 0 if no audio else: print("Warning: No audio chunks received.") @@ -667,6 +674,7 @@ async def main(): manifest_item_list = split_data(manifest_item_list, num_tasks) os.makedirs(args.log_dir, exist_ok=True) + tasks = [] start_time = time.time() for i in range(num_tasks): diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index 77a440b..c85b20b 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -114,6 +114,7 @@ class TritonPythonModel: "runtime_top_p": np.array([[0.95]], dtype=np.float32), "runtime_top_k": np.array([[50]], dtype=np.int32), "temperature": np.array([[0.8]], dtype=np.float32), + "repetition_penalty": np.array([[1.1]], dtype=np.float32), "input_ids": input_ids, "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), } @@ -144,6 +145,7 @@ class TritonPythonModel: # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + print(f"actual_output_ids: {actual_output_ids}") yield actual_output_ids else: @@ -193,7 +195,10 @@ class TritonPythonModel: prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, - target_speech_tokens: torch.Tensor) -> torch.Tensor: + target_speech_tokens: torch.Tensor, + request_id: str, + token_offset: int = None, + finalize: bool = None) -> torch.Tensor: """Forward pass through the vocoder component. Args: @@ -210,11 +215,22 @@ class TritonPythonModel: prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) + inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] + + if token_offset is not None: + assert finalize is not None + token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)) + finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) + inputs_tensor.append(token_offset_tensor) + inputs_tensor.append(finalize_tensor) + + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav', requested_output_names=['waveform'], - inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] + inputs=inputs_tensor, + request_id=request_id, ) inference_response = inference_request.exec() @@ -275,6 +291,7 @@ class TritonPythonModel: responses = [] for request in requests: + request_id = request.request_id() # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") @@ -292,6 +309,11 @@ class TritonPythonModel: prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() + + flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1] + token_hop_len = 25 + flow_pre_lookahead_len = 3 + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') @@ -308,24 +330,46 @@ class TritonPythonModel: # Generate semantic tokens with LLM generated_ids_iter = self.forward_llm(input_ids) + prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) + print(f"here2") if self.decoupled: response_sender = request.get_response_sender() - request_id = request.request_id() - generated_ids = [] - for generated_id in generated_ids_iter: - # convert the numpy array into a int32 tensor - generated_id = generated_id.tolist() - if len(generated_id) > 0: - assert len(generated_id) == 1, "Generated ID is not a single integer" - generated_ids.append(generated_id[0]) - generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device) - prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) - audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) - # Prepare response - audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + + + semantic_token_ids_arr, token_offset = [], 0 + for generated_ids in generated_ids_iter: + + generated_ids = generated_ids.tolist() + print(f"generated_id: {generated_ids}") + semantic_token_ids_arr.extend(generated_ids) + + prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len - flow_prompt_speech_token_len) + this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len + print(f"this_token_hop_len: {this_token_hop_len}") + if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len: + this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len] + print(f"this_tts_speech_token: {this_tts_speech_token}") + this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) + print(f"here3") + + sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False) + print(f"here4") + # Prepare response to send + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + response_sender.send(inference_response) + + self.logger.log_info(f"[{request_id}]") + token_offset += this_token_hop_len + print(f"here") + + this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) + sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True) + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: @@ -334,8 +378,7 @@ class TritonPythonModel: if generated_ids is None or len(generated_ids) == 0: raise pb_utils.TritonModelException("Generated IDs is None or empty") - prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) - audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) + audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id) # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index d38f8a4..00dc258 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -32,12 +32,16 @@ from typing import List, Dict import torch from torch.utils.dlpack import to_dlpack +from torch.nn import functional as F import triton_python_backend_utils as pb_utils from hyperpyyaml import load_hyperpyyaml +from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm from cosyvoice.utils.common import TrtContextWrapper +from collections import defaultdict +import numpy as np logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -81,6 +85,13 @@ class CosyVoice2Model: if self.fp16 is True: self.flow.half() + # streaming tts config + self.token_hop_len = 25 + self.mel_cache_len = 8 + self.source_cache_len = int(self.mel_cache_len * 480) + self.speech_window = np.hamming(2 * self.source_cache_len) + self.hift_cache_dict = defaultdict(lambda: None) + def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder @@ -112,6 +123,43 @@ class CosyVoice2Model: return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): + with torch.cuda.amp.autocast(self.fp16): + tts_mel, _ = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + streaming=stream, + finalize=finalize) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] + # append hift cache + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] + tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = torch.zeros(1, 1, 0) + # keep overlap mel and hift cache + if finalize is False: + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], + 'source': tts_source[:, :, -self.source_cache_len:], + 'speech': tts_speech[:, -self.source_cache_len:]} + tts_speech = tts_speech[:, :-self.source_cache_len] + else: + if speed != 1.0: + assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' + tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) + return tts_speech + + class TritonPythonModel: """Triton Python model for vocoder. @@ -166,25 +214,49 @@ class TritonPythonModel: prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE - tts_mel, _ = self.token2wav_model.model.flow.inference( - token=target_speech_tokens, - token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( - self.device - ), - prompt_token=prompt_speech_tokens, - prompt_token_len=torch.tensor( - [prompt_speech_tokens.shape[1]], dtype=torch.int32 - ).to(self.device), - prompt_feat=prompt_speech_feat, - prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=prompt_spk_embedding, - streaming=False, - finalize=True, - ) + # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. + token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset") + if token_offset is not None: + token_offset = token_offset.as_numpy().item() + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() + if not finalize: + stream = True + else: + stream = False + request_id = request.request_id() + print(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}") + audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, + prompt_token=prompt_speech_tokens, + prompt_feat=prompt_speech_feat, + embedding=prompt_spk_embedding, + token_offset=token_offset, + uuid=request_id, + stream=stream, + finalize=finalize) + if finalize: + print(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}") + self.token2wav_model.model.hift_cache_dict.pop(request_id) - audio_hat, _ = self.token2wav_model.model.hift.inference( - speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) - ) + else: + tts_mel, _ = self.token2wav_model.model.flow.inference( + token=target_speech_tokens, + token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( + self.device + ), + prompt_token=prompt_speech_tokens, + prompt_token_len=torch.tensor( + [prompt_speech_tokens.shape[1]], dtype=torch.int32 + ).to(self.device), + prompt_feat=prompt_speech_feat, + prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=prompt_spk_embedding, + streaming=False, + finalize=True, + ) + + audio_hat, _ = self.token2wav_model.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) generated_wave = audio_hat.squeeze(0).cpu().numpy() diff --git a/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt index 36489ff..9ea3b88 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav/config.pbtxt @@ -45,6 +45,20 @@ input [ name: "prompt_spk_embedding" data_type: TYPE_FP16 dims: [-1] + }, + { + name: "token_offset" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "finalize" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true } ] output [