diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index afbab68..7aa8d7d 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -59,12 +59,14 @@ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import import tritonclient.grpc as grpcclient_sync # Added sync client import from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException +from datetime import datetime # --- Added UserData and callback --- class UserData: def __init__(self): self._completed_requests = queue.Queue() self._first_chunk_time = None + self._second_chunk_time = None self._start_time = None def record_start_time(self): @@ -75,14 +77,44 @@ class UserData: return self._first_chunk_time - self._start_time return None + def get_second_chunk_latency(self): + if self._first_chunk_time and self._second_chunk_time: + return self._second_chunk_time - self._first_chunk_time + return None + def callback(user_data, result, error): - if user_data._first_chunk_time is None and not error: - user_data._first_chunk_time = time.time() # Record time of first successful chunk + if not error: + if user_data._first_chunk_time is None: + user_data._first_chunk_time = time.time() # Record time of first successful chunk + elif user_data._second_chunk_time is None: + user_data._second_chunk_time = time.time() + if error: user_data._completed_requests.put(error) else: user_data._completed_requests.put(result) + + +def stream_callback(user_data_map, result, error): + request_id = None + if error: + # Note: InferenceServerException doesn't have a public request_id() method in all versions. + # This part might need adjustment depending on the tritonclient library version. + # A more robust way would be to wrap the error with the request_id if possible. + # For now, we assume we can't get request_id from error and it will timeout on the client side. + print(f"An error occurred in the stream callback: {error}") + else: + request_id = result.get_response().id + + if request_id: + user_data = user_data_map.get(request_id) + if user_data: + callback(user_data, result, error) + else: + print(f"Warning: Could not find user_data for request_id {request_id}") + + # --- End Added UserData and callback --- @@ -142,6 +174,68 @@ def write_triton_stats(stats, summary_file): ) +def subtract_stats(stats_after, stats_before): + """Subtracts two Triton inference statistics objects.""" + # Deep copy to avoid modifying the original stats_after + stats_diff = json.loads(json.dumps(stats_after)) + + model_stats_before_map = { + s["name"]: { + "version": s["version"], + "last_inference": s.get("last_inference", 0), + "inference_count": s.get("inference_count", 0), + "execution_count": s.get("execution_count", 0), + "inference_stats": s.get("inference_stats", {}), + "batch_stats": s.get("batch_stats", []), + } + for s in stats_before["model_stats"] + } + + for model_stat_after in stats_diff["model_stats"]: + model_name = model_stat_after["name"] + if model_name in model_stats_before_map: + model_stat_before = model_stats_before_map[model_name] + + # Subtract counts + model_stat_after["inference_count"] = str( + int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0)) + ) + model_stat_after["execution_count"] = str( + int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0)) + ) + + # Subtract aggregate stats (like queue, compute times) + if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before: + for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]: + if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]: + if "ns" in model_stat_after["inference_stats"][key]: + ns_after = int(model_stat_after["inference_stats"][key]["ns"]) + ns_before = int(model_stat_before["inference_stats"][key]["ns"]) + model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before) + if "count" in model_stat_after["inference_stats"][key]: + count_after = int(model_stat_after["inference_stats"][key]["count"]) + count_before = int(model_stat_before["inference_stats"][key]["count"]) + model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before) + + # Subtract batch execution stats + if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before: + batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]} + for batch_stat_after in model_stat_after["batch_stats"]: + bs = batch_stat_after["batch_size"] + if bs in batch_stats_before_map: + batch_stat_before = batch_stats_before_map[bs] + for key in ["compute_input", "compute_infer", "compute_output"]: + if key in batch_stat_after and key in batch_stat_before: + count_after = int(batch_stat_after[key]["count"]) + count_before = int(batch_stat_before[key]["count"]) + batch_stat_after[key]["count"] = str(count_after - count_before) + + ns_after = int(batch_stat_after[key]["ns"]) + ns_before = int(batch_stat_before[key]["ns"]) + batch_stat_after[key]["ns"] = str(ns_after - ns_before) + return stats_diff + + def get_args(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -357,10 +451,10 @@ def run_sync_streaming_inference( """Helper function to run the blocking sync streaming call.""" start_time_total = time.time() user_data.record_start_time() # Record start time for first chunk latency calculation + # e.g. 08:47:34.827758 - # Establish stream - sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) - + print(f"Record start time in human readable: {datetime.now()}") + # input() # Send request sync_triton_client.async_stream_infer( model_name, @@ -374,11 +468,11 @@ def run_sync_streaming_inference( audios = [] while True: try: - result = user_data._completed_requests.get() # Add timeout + result = user_data._completed_requests.get(timeout=20) # Add timeout if isinstance(result, InferenceServerException): print(f"Received InferenceServerException: {result}") - sync_triton_client.stop_stream() - return None, None, None # Indicate error + # Don't stop the stream here, just return error + return None, None, None, None # Get response metadata response = result.get_response() final = response.parameters["triton_final_response"].bool_param @@ -393,13 +487,13 @@ def run_sync_streaming_inference( except queue.Empty: print(f"Timeout waiting for response for request id {request_id}") - sync_triton_client.stop_stream() - return None, None, None # Indicate error + # Don't stop stream here, just return error + return None, None, None, None - sync_triton_client.stop_stream() end_time_total = time.time() total_request_latency = end_time_total - start_time_total first_chunk_latency = user_data.get_first_chunk_latency() + second_chunk_latency = user_data.get_second_chunk_latency() # Reconstruct audio using cross-fade (from client_grpc_streaming.py) actual_duration = 0 @@ -448,7 +542,7 @@ def run_sync_streaming_inference( print("Warning: No audio chunks received.") actual_duration = 0 - return total_request_latency, first_chunk_latency, actual_duration + return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration async def send_streaming( @@ -468,10 +562,12 @@ async def send_streaming( latency_data = [] task_id = int(name[5:]) sync_triton_client = None # Initialize client variable + user_data_map = {} try: # Wrap in try...finally to ensure client closing print(f"{name}: Initializing sync client for streaming...") sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map)) print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") for i, item in enumerate(manifest_item_list): @@ -494,10 +590,11 @@ async def send_streaming( request_id = str(uuid.uuid4()) user_data = UserData() + user_data_map[request_id] = user_data audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") - - total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread( + print("target_text: ", target_text, "time: ", datetime.now()) + total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread( run_sync_streaming_inference, sync_triton_client, model_name, @@ -511,12 +608,18 @@ async def send_streaming( ) if total_request_latency is not None: - print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s") - latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) + print( + f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, " + f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, " + f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s" + ) + latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration)) total_duration += actual_duration else: print(f"{name}: Item {i} failed.") + del user_data_map[request_id] + except FileNotFoundError: print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") except Exception as e: @@ -527,7 +630,8 @@ async def send_streaming( finally: # Ensure client is closed if sync_triton_client: try: - print(f"{name}: Closing sync client...") + print(f"{name}: Closing stream and sync client...") + sync_triton_client.stop_stream() sync_triton_client.close() except Exception as e: print(f"{name}: Error closing sync client: {e}") @@ -685,9 +789,22 @@ async def main(): "target_text": dataset[i]["target_text"], } ) + # manifest_item_list = manifest_item_list[:4] else: manifest_item_list = load_manifests(args.manifest_path) + # --- Statistics Fetching (Before) --- + stats_client = None + stats_before = None + try: + print("Initializing temporary async client for fetching stats...") + stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) + print("Fetching inference statistics before running tasks...") + stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True) + except Exception as e: + print(f"Could not retrieve statistics before running tasks: {e}") + # --- End Statistics Fetching (Before) --- + num_tasks = min(args.num_tasks, len(manifest_item_list)) manifest_item_list = split_data(manifest_item_list, num_tasks) @@ -776,8 +893,9 @@ async def main(): elif args.mode == "streaming": # Calculate stats for total request latency and first chunk latency - total_latency_list = [total for (total, first, duration) in latency_data if total is not None] - first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None] + total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None] + first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None] + second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None] s += "\n--- Total Request Latency ---\n" if total_latency_list: @@ -804,6 +922,19 @@ async def main(): s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" else: s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" + + s += "\n--- Second Chunk Latency ---\n" + if second_chunk_latency_list: + avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0 + variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0 + s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n" + s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n" + s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n" + s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n" + s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n" + s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n" + else: + s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n" else: s += "No latency data collected.\n" # --- End Statistics Reporting --- @@ -822,20 +953,23 @@ async def main(): # --- Statistics Fetching using temporary Async Client --- # Use a separate async client for fetching stats regardless of mode - stats_client = None try: - print("Initializing temporary async client for fetching stats...") - stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) - print("Fetching inference statistics...") - # Fetching for all models, filtering might be needed depending on server setup - stats = await stats_client.get_inference_statistics(model_name="", as_json=True) - print("Fetching model config...") - metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) + if stats_client and stats_before: + print("Fetching inference statistics after running tasks...") + stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True) - write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + print("Calculating statistics difference...") + stats = subtract_stats(stats_after, stats_before) - with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: - json.dump(metadata, f, indent=4) + print("Fetching model config...") + metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) + + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) + else: + print("Stats client not available or initial stats were not fetched. Skipping stats reporting.") except Exception as e: print(f"Could not retrieve statistics or config: {e}") diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py index c472968..2f81786 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py @@ -109,7 +109,6 @@ class TritonPythonModel: spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) self.default_spk_info = spk_info["001"] self.http_client = httpx.AsyncClient() - self.runtime_cache = {} def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" @@ -264,38 +263,11 @@ class TritonPythonModel: finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] - # optional cache inputs - if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None: - # inputs_tensor.extend([ - # pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()), - # pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()), - # pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()), - # pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()), - # pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()), - # pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()), - # pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()), - # ]) - inputs_tensor.extend([ - self.runtime_cache[request_id]["conformer_cnn_cache"], - self.runtime_cache[request_id]["conformer_att_cache"], - self.runtime_cache[request_id]["estimator_cnn_cache"], - self.runtime_cache[request_id]["estimator_att_cache"], - self.runtime_cache[request_id]["mel"], - self.runtime_cache[request_id]["source"], - self.runtime_cache[request_id]["speech"], - ]) # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', requested_output_names=[ "waveform", - "conformer_cnn_cache", - "conformer_att_cache", - "estimator_cnn_cache", - "estimator_att_cache", - "mel", - "source", - "speech", ], inputs=inputs_tensor, request_id=request_id, @@ -306,14 +278,6 @@ class TritonPythonModel: if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) - self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache") - self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache") - self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache") - self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache") - self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel") - self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source") - self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech") - # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() @@ -339,16 +303,6 @@ class TritonPythonModel: async def _process_request(self, request): request_id = request.request_id() - if request_id not in self.runtime_cache: - self.runtime_cache[request_id] = { - "conformer_cnn_cache": None, - "conformer_att_cache": None, - "estimator_cnn_cache": None, - "estimator_att_cache": None, - "mel": None, - "source": None, - "speech": None, - } # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") @@ -369,7 +323,7 @@ class TritonPythonModel: reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) + # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) # reference_text = self.default_spk_info["prompt_text"] # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE @@ -453,9 +407,7 @@ class TritonPythonModel: audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) - if request_id in self.runtime_cache: - del self.runtime_cache[request_id] - self.logger.log_info(f"Deleted cache for request_id: {request_id}") + response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt index b119227..e64647e 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt @@ -31,7 +31,7 @@ parameters [ value: {string_value:"${model_dir}"} } ] -parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} + input [ { name: "reference_wav" diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py index e95ce99..230bad0 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py @@ -103,91 +103,47 @@ class TritonPythonModel: List of inference responses containing generated waveforms """ responses = [] + # Process each request in batch for request in requests: - request_id = request.request_id() - - # Get inputs - target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens") - target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack()) + target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy() + target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device) + # shift the speech tokens according to the original vocab size + target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens.squeeze().tolist() + # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts. + finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item() - wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy() - wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item() - wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0) + + request_id = request.request_id() + + + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() + + wav_array = torch.from_numpy(wav_array) + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + spk_id = get_spk_id_from_prompt_audio(wav) + # wav = wav.to(self.device) - # Handle cache - conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache") - if conformer_cnn_cache is not None: - self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack()) - - conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache") - self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1) - - estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache") - self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0) + # update cache before forward + # self.token2wav_model.streaming_flow_cache[request_id] + # self.token2wav_model.hift_cache_dict[request_id] - estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache") - self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0) + audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000) - mel_np = pb_utils.get_input_tensor_by_name(request, "mel") - self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack()) - - source_np = pb_utils.get_input_tensor_by_name(request, "source") - self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack()) - - speech_np = pb_utils.get_input_tensor_by_name(request, "speech") - self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack()) - - # Forward pass - audio_hat = self.token2wav_model.forward_streaming( - target_speech_tokens, - finalize, - request_id=request_id, - speaker_id=f"{spk_id}", - prompt_audio=wav, - prompt_audio_sample_rate=16000 - ) - - # Prepare outputs + # get the cache after forward outputs = [] + + generated_wave = audio_hat.squeeze(0).cpu().numpy() + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) outputs.append(wav_tensor) - - if request_id in self.token2wav_model.streaming_flow_cache: - cache = self.token2wav_model.streaming_flow_cache[request_id] - hifigan_cache = self.token2wav_model.hift_cache_dict[request_id] - conformer_cnn_cache = cache['conformer_cnn_cache'] - conformer_att_cache = cache['conformer_att_cache'].transpose(0,1) - estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0) - estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0) - mel = hifigan_cache['mel'] - source = hifigan_cache['source'] - speech = hifigan_cache['speech'] - - outputs.extend([ - pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)), - pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)), - pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)), - pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)), - pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)), - pb_utils.Tensor.from_dlpack("source", to_dlpack(source)), - pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)), - ]) - else: - outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)), - pb_utils.Tensor("mel", np.array([], dtype=np.float32)), - pb_utils.Tensor("source", np.array([], dtype=np.float32)), - pb_utils.Tensor("speech", np.array([], dtype=np.float32)), - ]) - inference_response = pb_utils.InferenceResponse(output_tensors=outputs) responses.append(inference_response) - return responses - def finalize(self): - self.logger.log_info("Finalizing Token2WavDiT model") + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt index aed7561..3f579aa 100644 --- a/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt +++ b/runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt @@ -22,7 +22,6 @@ dynamic_batching { default_priority_level: 10 } -parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}} parameters [ { key: "model_dir", @@ -52,48 +51,6 @@ input [ dims: [ 1 ] reshape: { shape: [ ] } optional: true - }, - { - name: "conformer_cnn_cache" - data_type: TYPE_FP16 - dims: [ 512, -1 ] - optional: true - }, - { - name: "conformer_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 8, -1, 128 ] - optional: true - }, - { - name: "estimator_cnn_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 1024, 2 ] - optional: true - }, - { - name: "estimator_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 8, -1, 128 ] - optional: true - }, - { - name: "mel" - data_type: TYPE_FP32 - dims: [ 80, -1 ] - optional: true - }, - { - name: "source" - data_type: TYPE_FP32 - dims: [ 1, -1 ] - optional: true - }, - { - name: "speech" - data_type: TYPE_FP32 - dims: [ -1 ] - optional: true } ] output [ @@ -101,41 +58,6 @@ output [ name: "waveform" data_type: TYPE_FP32 dims: [ -1 ] - }, - { - name: "conformer_cnn_cache" - data_type: TYPE_FP16 - dims: [ 512, -1 ] - }, - { - name: "conformer_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 8, -1, 128 ] - }, - { - name: "estimator_cnn_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 1024, 2 ] - }, - { - name: "estimator_att_cache" - data_type: TYPE_FP16 - dims: [ 10, 16, -1, 8, -1, 128 ] - }, - { - name: "mel" - data_type: TYPE_FP32 - dims: [ 80, -1 ] - }, - { - name: "source" - data_type: TYPE_FP32 - dims: [ 1, -1 ] - }, - { - name: "speech" - data_type: TYPE_FP32 - dims: [ -1 ] } ] diff --git a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh index ad3407e..2eabcf4 100644 --- a/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh +++ b/runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh @@ -1,6 +1,6 @@ #!/bin/bash # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=0 cosyvoice_path=/workspace/CosyVoice cosyvoice_path=/workspace_yuekai/tts/CosyVoice stepaudio2_path=/workspace_yuekai/tts/Step-Audio2 @@ -112,7 +112,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then MODEL_DIR=$model_scope_model_local_dir LLM_TOKENIZER_DIR=$huggingface_model_local_dir BLS_INSTANCE_NUM=4 - TRITON_MAX_BATCH_SIZE=32 + TRITON_MAX_BATCH_SIZE=1 DECOUPLED_MODE=True # True for streaming, False for offline STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav @@ -154,7 +154,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --num-tasks $num_task \ --mode $mode \ --huggingface-dataset yuekai/seed_tts_cosy2 \ - --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new + --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then @@ -185,14 +185,14 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - python3 streaming_inference.py + CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 + CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4 fi diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py index 863358c..93c6758 100644 --- a/runtime/triton_trtllm/streaming_inference.py +++ b/runtime/triton_trtllm/streaming_inference.py @@ -31,6 +31,7 @@ def get_args(): parser.add_argument("--output-dir", type=str, default="generated_wavs") parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts") parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2") + parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"]) return parser.parse_args() @@ -53,12 +54,14 @@ if __name__ == "__main__": token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True) flow_pre_lookahead_len = 3 - CHUNK_SIZE = 25 + CHUNK_SIZE = 15 + token_frame_rate = 25 OVERLAP_SIZE = 0 warmup_times = 3 for _ in range(warmup_times): start_time = time.time() + total_forward_count = 0 for batch in data_loader: tts_speech_list = [] ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch @@ -83,17 +86,26 @@ if __name__ == "__main__": buffer = generated_speech_tokens output_wavs = [] + chunk_index = 0 while True: + if args.strategy == "equal": + this_chunk_size = CHUNK_SIZE + elif args.strategy == "exponential": + this_chunk_size = token_frame_rate * (2 ** chunk_index) - if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len: - wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) - buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:] + if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len: + wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) + buffer = buffer[this_chunk_size - OVERLAP_SIZE:] output_wavs.append(wavs) + total_forward_count += 1 + chunk_index += 1 else: wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate) output_wavs.append(wavs) + total_forward_count += 1 + # chunk_index += 1 break for i, wav in enumerate(output_wavs): @@ -112,4 +124,4 @@ if __name__ == "__main__": if _ == 0: token2wav_model.speaker_cache = {} print(f"Warmup time: {end_time - start_time} seconds") - + print(f"Total forward count: {total_forward_count}")