diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py index b4a6348..c472968 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py @@ -43,6 +43,7 @@ import torchaudio from matcha.utils.audio import mel_spectrogram +from datetime import datetime ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) @@ -86,6 +87,7 @@ class TritonPythonModel: model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" + # self.dynamic_chunk_strategy = "equal" self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer @@ -105,7 +107,9 @@ class TritonPythonModel: if not os.path.exists(spk_info_path): raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) - # self.default_spk_info = spk_info["001"] + self.default_spk_info = spk_info["001"] + self.http_client = httpx.AsyncClient() + self.runtime_cache = {} def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" @@ -131,7 +135,6 @@ class TritonPythonModel: {"role": "user", "content": full_text}, {"role": "assistant", "content": prompt_speech_tokens_str} ] - print(chat) payload = { "model": "trt_engines_bfloat16", @@ -148,31 +151,33 @@ class TritonPythonModel: api_base = "http://localhost:8000/v1/chat/completions" buffer = "" - async with httpx.AsyncClient() as client: - async with client.stream("POST", api_base, json=payload, timeout=None) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if line.startswith("data: "): - line_data = line[len("data: "):].strip() - if line_data == "[DONE]": - break - try: - json_data = json.loads(line_data) - content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") - if content: - buffer += content - while True: - match = re.search(r"<\|s_(\d+)\|>", buffer) - if not match: - break + async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: + print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") + print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") + response.raise_for_status() + async for line in response.aiter_lines(): + if line.startswith("data: "): + line_data = line[len("data: "):].strip() + if line_data == "[DONE]": + break + try: + json_data = json.loads(line_data) + content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") + if content: + buffer += content + print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") + while True: + match = re.search(r"<\|s_(\d+)\|>", buffer) + if not match: + break - token_num = int(match.group(1)) - final_id = token_num + ORIGINAL_VOCAB_SIZE - yield final_id - buffer = buffer[match.end():] - except json.JSONDecodeError: - self.logger.log_info(f"Skipping non-JSON line: {line_data}") - continue + token_num = int(match.group(1)) + final_id = token_num + ORIGINAL_VOCAB_SIZE + yield final_id + buffer = buffer[match.end():] + except json.JSONDecodeError: + self.logger.log_info(f"Skipping non-JSON line: {line_data}") + continue # Process any remaining complete tokens in the buffer after the stream ends while True: @@ -236,7 +241,7 @@ class TritonPythonModel: return prompt_spk_embedding - def forward_token2wav( + async def forward_token2wav( self, index: int, target_speech_tokens: torch.Tensor, @@ -258,20 +263,57 @@ class TritonPythonModel: target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - + + # optional cache inputs + if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None: + # inputs_tensor.extend([ + # pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()), + # pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()), + # pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()), + # pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()), + # pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()), + # pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()), + # pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()), + # ]) + inputs_tensor.extend([ + self.runtime_cache[request_id]["conformer_cnn_cache"], + self.runtime_cache[request_id]["conformer_att_cache"], + self.runtime_cache[request_id]["estimator_cnn_cache"], + self.runtime_cache[request_id]["estimator_att_cache"], + self.runtime_cache[request_id]["mel"], + self.runtime_cache[request_id]["source"], + self.runtime_cache[request_id]["speech"], + ]) # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', - requested_output_names=['waveform'], + requested_output_names=[ + "waveform", + "conformer_cnn_cache", + "conformer_att_cache", + "estimator_cnn_cache", + "estimator_att_cache", + "mel", + "source", + "speech", + ], inputs=inputs_tensor, request_id=request_id, parameters={"priority": index+1}, ) - inference_response = inference_request.exec() + inference_response = await inference_request.async_exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) + self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache") + self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache") + self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache") + self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache") + self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel") + self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source") + self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech") + # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() @@ -297,6 +339,16 @@ class TritonPythonModel: async def _process_request(self, request): request_id = request.request_id() + if request_id not in self.runtime_cache: + self.runtime_cache[request_id] = { + "conformer_cnn_cache": None, + "conformer_att_cache": None, + "estimator_cnn_cache": None, + "estimator_att_cache": None, + "mel": None, + "source": None, + "speech": None, + } # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") @@ -308,6 +360,7 @@ class TritonPythonModel: wav_tensor = wav.as_numpy() wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] + print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) speech_feat = self._extract_speech_feat(prompt_speech_resample) token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) @@ -316,7 +369,7 @@ class TritonPythonModel: reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) # reference_text = self.default_spk_info["prompt_text"] # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE @@ -333,6 +386,7 @@ class TritonPythonModel: target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') + print(f"target_text: {target_text}, time: {datetime.now()}") if self.decoupled: response_sender = request.get_response_sender() @@ -341,7 +395,7 @@ class TritonPythonModel: token_offset, chunk_index = 0, 0 start_time = time.time() this_token_hop_len = self.token_hop_len - + print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") async for generated_ids in self.forward_llm_async( target_text=target_text, reference_text=reference_text, @@ -350,18 +404,18 @@ class TritonPythonModel: if not generated_ids: break semantic_token_ids_arr.append(generated_ids) - + print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: pending_num = len(semantic_token_ids_arr) - token_offset if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) - - sub_tts_speech = self.forward_token2wav( + print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") + sub_tts_speech = await self.forward_token2wav( chunk_index, this_tts_speech_token, request_id, wav, wav_len, False ) - + print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -371,6 +425,8 @@ class TritonPythonModel: if self.dynamic_chunk_strategy == "exponential": this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) + elif self.dynamic_chunk_strategy == "equal": + this_token_hop_len = self.token_hop_len elif self.dynamic_chunk_strategy == "time_based": # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 cost_time = time.time() - start_time @@ -393,29 +449,13 @@ class TritonPythonModel: break this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) + sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) - - ## debug - ## save semantic_token_ids_arr and reference_text, target_text to a single json file - # save into a torch .pt - # for i, item in enumerate(semantic_token_ids_arr): - # semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE - # import json - # data = { - # "semantic_token_ids_arr": semantic_token_ids_arr, - # "reference_text": reference_text, - # "target_text": target_text - # } - # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f: - # torch.save(data, f) - # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f: - # json.dump(data, f) - - # ## - + if request_id in self.runtime_cache: + del self.runtime_cache[request_id] + self.logger.log_info(f"Deleted cache for request_id: {request_id}") response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: @@ -436,3 +476,8 @@ class TritonPythonModel: ] await asyncio.gather(*tasks) return None + + def finalize(self): + self.logger.log_info("Finalizing CosyVoice DIT model") + if hasattr(self, "http_client"): + asyncio.run(self.http_client.aclose()) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt index e64647e..b119227 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -31,7 +31,7 @@ parameters [ value: {string_value:"${model_dir}"} } ] - +parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} input [ { name: "reference_wav" diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index 8f9ffba..e95ce99 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -103,39 +103,91 @@ class TritonPythonModel: List of inference responses containing generated waveforms """ responses = [] - # Process each request in batch for request in requests: - target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() - target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) - # shift the speech tokens according to the original vocab size - target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE + request_id = request.request_id() + + # Get inputs + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens") + target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack()) target_speech_tokens = target_speech_tokens.squeeze().tolist() - # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. - finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - - request_id = request.request_id() - - - wav_array = pb_utils.get_input_tensor_by_name( - request, "reference_wav").as_numpy() - wav_len = pb_utils.get_input_tensor_by_name( - request, "reference_wav_len").as_numpy().item() - - wav_array = torch.from_numpy(wav_array) - # Prepare inputs - wav = wav_array[:, :wav_len].squeeze(0) - + wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item() + wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0) spk_id = get_spk_id_from_prompt_audio(wav) - # wav = wav.to(self.device) - audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) + # Handle cache + conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache") + if conformer_cnn_cache is not None: + self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack()) + + conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache") + self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1) + + estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache") + self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0) - generated_wave = audio_hat.squeeze(0).cpu().numpy() + estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache") + self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0) + mel_np = pb_utils.get_input_tensor_by_name(request, "mel") + self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack()) + + source_np = pb_utils.get_input_tensor_by_name(request, "source") + self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack()) + + speech_np = pb_utils.get_input_tensor_by_name(request, "speech") + self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack()) + + # Forward pass + audio_hat = self.token2wav_model.forward_streaming( + target_speech_tokens, + finalize, + request_id=request_id, + speaker_id=f"{spk_id}", + prompt_audio=wav, + prompt_audio_sample_rate=16000 + ) + + # Prepare outputs + outputs = [] wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) - inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) - responses.append(inference_response) + outputs.append(wav_tensor) + + if request_id in self.token2wav_model.streaming_flow_cache: + cache = self.token2wav_model.streaming_flow_cache[request_id] + hifigan_cache = self.token2wav_model.hift_cache_dict[request_id] + conformer_cnn_cache = cache['conformer_cnn_cache'] + conformer_att_cache = cache['conformer_att_cache'].transpose(0,1) + estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0) + estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0) + mel = hifigan_cache['mel'] + source = hifigan_cache['source'] + speech = hifigan_cache['speech'] + outputs.extend([ + pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)), + pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)), + pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)), + pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)), + pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)), + pb_utils.Tensor.from_dlpack("source", to_dlpack(source)), + pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)), + ]) + else: + outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)), + pb_utils.Tensor("mel", np.array([], dtype=np.float32)), + pb_utils.Tensor("source", np.array([], dtype=np.float32)), + pb_utils.Tensor("speech", np.array([], dtype=np.float32)), + ]) + + inference_response = pb_utils.InferenceResponse(output_tensors=outputs) + responses.append(inference_response) return responses + + def finalize(self): + self.logger.log_info("Finalizing Token2WavDiT model") diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py index 3b696e9..63dce14 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py @@ -372,7 +372,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000 ): if speaker_id not in self.speaker_cache: - # if 1: assert prompt_audio is not None, "prompt_audio is required for new speaker" assert prompt_audio_sample_rate == 16000 @@ -384,20 +383,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module): prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow} - # if speaker_id not in self.speaker_cache: - # if 1: - cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow) self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict} print(f"speaker_id {speaker_id} added to cache") - # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change - att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone() - cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone() - conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone() - conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone() - - if request_id not in self.streaming_flow_cache: self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()} self.hift_cache_dict[request_id] = dict( @@ -405,6 +394,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module): source = torch.zeros(1, 1, 0, device='cuda'), speech = torch.zeros(1, 0, device='cuda'), ) + # else: + # for k, v in self.streaming_flow_cache[request_id].items(): + # print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}") + # for k, v in self.hift_cache_dict[request_id].items(): + # print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}") + # breakpoint() current_request_cache = self.streaming_flow_cache[request_id] @@ -420,33 +415,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module): n_timesteps=10, ) - # get the original att_cache - original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'] - original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'] - original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'] - original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'] - if not torch.allclose(original_att_cache, att_cache_clone): - print("att_cache changed") - # print the last 10 elements of original_att_cache and att_cache_clone - print(original_att_cache[:, :, :, -10:]) - print(att_cache_clone[:, :, :, -10:]) - breakpoint() - if not torch.allclose(original_cnn_cache, cnn_cache_clone): - print("cnn_cache changed") - print(original_cnn_cache[..., -10:]) - print(cnn_cache_clone[..., -10:]) - breakpoint() - if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone): - print("conformer_cnn_cache changed") - print(original_conformer_cnn_cache[..., -10:]) - print(conformer_cnn_cache_clone[..., -10:]) - breakpoint() - if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone): - print("conformer_att_cache changed") - print(original_conformer_att_cache[..., -10:]) - print(conformer_att_cache_clone[..., -10:]) - breakpoint() - self.streaming_flow_cache[request_id] = new_streaming_flow_cache @@ -482,7 +450,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module): assert request_id in self.streaming_flow_cache self.streaming_flow_cache.pop(request_id) self.hift_cache_dict.pop(request_id) - # breakpoint() + return speech def collate_fn(batch): diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt index 2040cfe..aed7561 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -15,11 +15,14 @@ name: "token2wav_dit" backend: "python" max_batch_size: ${triton_max_batch_size} + dynamic_batching { max_queue_delay_microseconds: ${max_queue_delay_microseconds} priority_levels: 10 default_priority_level: 10 } + +parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} parameters [ { key: "model_dir", @@ -49,6 +52,48 @@ input [ dims: [ 1 ] reshape: { shape: [ ] } optional: true + }, + { + name: "conformer_cnn_cache" + data_type: TYPE_FP16 + dims: [ 512, -1 ] + optional: true + }, + { + name: "conformer_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 8, -1, 128 ] + optional: true + }, + { + name: "estimator_cnn_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 1024, 2 ] + optional: true + }, + { + name: "estimator_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 8, -1, 128 ] + optional: true + }, + { + name: "mel" + data_type: TYPE_FP32 + dims: [ 80, -1 ] + optional: true + }, + { + name: "source" + data_type: TYPE_FP32 + dims: [ 1, -1 ] + optional: true + }, + { + name: "speech" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true } ] output [ @@ -56,6 +101,41 @@ output [ name: "waveform" data_type: TYPE_FP32 dims: [ -1 ] + }, + { + name: "conformer_cnn_cache" + data_type: TYPE_FP16 + dims: [ 512, -1 ] + }, + { + name: "conformer_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 8, -1, 128 ] + }, + { + name: "estimator_cnn_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 1024, 2 ] + }, + { + name: "estimator_att_cache" + data_type: TYPE_FP16 + dims: [ 10, 16, -1, 8, -1, 128 ] + }, + { + name: "mel" + data_type: TYPE_FP32 + dims: [ 80, -1 ] + }, + { + name: "source" + data_type: TYPE_FP32 + dims: [ 1, -1 ] + }, + { + name: "speech" + data_type: TYPE_FP32 + dims: [ -1 ] } ]