diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py
index 7aa8d7d..718fe86 100644
--- a/runtime/triton_trtllm/client_grpc.py
+++ b/runtime/triton_trtllm/client_grpc.py
@@ -43,9 +43,9 @@ python3 client_grpc.py \
 import argparse
 import asyncio
 import json
-import queue # Added
-import uuid # Added
-import functools # Added
+import queue
+import uuid
+import functools
 import os
 import time
@@ -55,13 +55,11 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
-import tritonclient.grpc as grpcclient_sync # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio
+import tritonclient.grpc as grpcclient_sync
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException
-from datetime import datetime
-# --- Added UserData and callback ---
 class UserData:
     def __init__(self):
         self._completed_requests = queue.Queue()
@@ -86,7 +84,7 @@ class UserData:
 def callback(user_data, result, error):
     if not error:
         if user_data._first_chunk_time is None:
-            user_data._first_chunk_time = time.time() # Record time of first successful chunk
+            user_data._first_chunk_time = time.time()
         elif user_data._second_chunk_time is None:
             user_data._second_chunk_time = time.time()
@@ -99,10 +97,6 @@ def callback(user_data, result, error):
 def stream_callback(user_data_map, result, error):
     request_id = None
     if error:
-        # Note: InferenceServerException doesn't have a public request_id() method in all versions.
-        # This part might need adjustment depending on the tritonclient library version.
-        # A more robust way would be to wrap the error with the request_id if possible.
-        # For now, we assume we can't get request_id from error and it will timeout on the client side.
         print(f"An error occurred in the stream callback: {error}")
     else:
         request_id = result.get_response().id
@@ -115,31 +109,9 @@ def stream_callback(user_data_map, result, error):
             print(f"Warning: Could not find user_data for request_id {request_id}")
-# --- End Added UserData and callback ---
-
-
 def write_triton_stats(stats, summary_file):
     with open(summary_file, "w") as summary_f:
         model_stats = stats["model_stats"]
-        # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
-        summary_f.write(
-            "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
-        )
-        summary_f.write("To learn more about the log, please refer to: \n")
-        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
-        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
-        summary_f.write(
-            "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
-        )
-        summary_f.write(
-            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
-        )
-        summary_f.write(
-            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
-        )
-        summary_f.write(
-            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
-        )
         for model_state in model_stats:
             if "last_inference" not in model_state:
                 continue
@@ -150,7 +122,7 @@ def write_triton_stats(stats, summary_file):
             total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
             total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
             summary_f.write(
-                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
+                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
             )
             model_batch_stats = model_state["batch_stats"]
             for batch in model_batch_stats:
@@ -164,19 +136,18 @@ def write_triton_stats(stats, summary_file):
                 compute_input_time_ms = int(compute_input["ns"]) / 1e6
                 compute_output_time_ms = int(compute_output["ns"]) / 1e6
                 summary_f.write(
-                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
+                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                 )
                 summary_f.write(
-                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
+                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
                 )
                 summary_f.write(
-                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
+                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
                 )


 def subtract_stats(stats_after, stats_before):
     """Subtracts two Triton inference statistics objects."""
-    # Deep copy to avoid modifying the original stats_after
     stats_diff = json.loads(json.dumps(stats_after))

     model_stats_before_map = {
@@ -196,7 +167,6 @@ def subtract_stats(stats_after, stats_before):

         if model_name in model_stats_before_map:
             model_stat_before = model_stats_before_map[model_name]
-            # Subtract counts
             model_stat_after["inference_count"] = str(
                 int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
             )
@@ -204,7 +174,6 @@
                 int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
             )

-            # Subtract aggregate stats (like queue, compute times)
             if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
                 for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
                     if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
@@ -217,7 +186,6 @@ def subtract_stats(stats_after, stats_before):
                         count_before = int(model_stat_before["inference_stats"][key]["count"])
                         model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)

-            # Subtract batch execution stats
             if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
                 batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
                 for batch_stat_after in model_stat_after["batch_stats"]:
@@ -338,7 +306,6 @@ def get_args():
         help="log directory",
     )

-    # --- Added arguments ---
    parser.add_argument(
         "--mode",
         type=str,
@@ -379,39 +346,33 @@ def load_audio(wav_path, target_sample_rate=16000):


 def prepare_request_input_output(
-    protocol_client, # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None, # Optional padding for offline mode
+    padding_duration: int = None,
     use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
     lengths = np.array([[len(waveform)]], dtype=np.int32)

-    # Apply padding only if padding_duration is provided (for offline)
     if padding_duration:
         duration = len(waveform) / sample_rate
-        # Estimate target duration based on text length ratio (crude estimation)
-        # Avoid division by zero if reference_text is empty
         if reference_text:
             estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-            estimated_target_duration = duration # Assume target duration similar to reference if no text
+            estimated_target_duration = duration

-        # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
             (int(estimated_target_duration + duration) // padding_duration) + 1
         )
         samples = np.zeros((1, required_total_samples), dtype=np.float32)
         samples[0, : len(waveform)] = waveform
     else:
-        # No padding for streaming or if padding_duration is None
         samples = waveform.reshape(1, -1).astype(np.float32)

-    # Common input creation logic
     inputs = [
         protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
         protocol_client.InferInput(
@@ -450,12 +411,8 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time() # Record start time for first chunk latency calculation
-    # e.g. 08:47:34.827758
+    user_data.record_start_time()
-    print(f"Record start time in human readable: {datetime.now()}")
-    # input()
-    # Send request
     sync_triton_client.async_stream_infer(
         model_name,
         inputs,
@@ -464,30 +421,26 @@ def run_sync_streaming_inference(
         enable_empty_final_response=True,
     )

-    # Process results
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get(timeout=20) # Add timeout
+            result = user_data._completed_requests.get(timeout=20)
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
-                # Don't stop the stream here, just return error
                 return None, None, None, None
-            # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
             if final is True:
                 break

             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0: # Only append non-empty chunks
+            if audio_chunk.size > 0:
                 audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")

         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
-            # Don't stop stream here, just return error
             return None, None, None, None

     end_time_total = time.time()
@@ -495,47 +448,36 @@ def run_sync_streaming_inference(
     first_chunk_latency = user_data.get_first_chunk_latency()
     second_chunk_latency = user_data.get_second_chunk_latency()

-    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
-    actual_duration = 0
     if audios:
-        # Only spark_tts model uses cross-fade
         if model_name == "spark_tts":
             cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
             fade_out = np.linspace(1, 0, cross_fade_samples)
             fade_in = np.linspace(0, 1, cross_fade_samples)
             reconstructed_audio = None

-            # Simplified reconstruction based on client_grpc_streaming.py
             if not audios:
                 print("Warning: No audio chunks received.")
-                reconstructed_audio = np.array([], dtype=np.float32) # Empty array
+                reconstructed_audio = np.array([], dtype=np.float32)
             elif len(audios) == 1:
                 reconstructed_audio = audios[0]
             else:
-                reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
+                reconstructed_audio = audios[0][:-cross_fade_samples]
                 for i in range(1, len(audios)):
-                    # Cross-fade section
                     cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out)
-                    # Middle section of the current chunk
                     middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                    # Concatenate
                     reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])

-                # Add the last part of the final chunk
                 reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

             if reconstructed_audio is not None and reconstructed_audio.size > 0:
                 actual_duration = len(reconstructed_audio) / save_sample_rate
-                # Save reconstructed audio
                 sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
             else:
                 print("Warning: No audio chunks received or reconstructed.")
-                actual_duration = 0 # Set duration to 0 if no audio
+                actual_duration = 0
         else:
             reconstructed_audio = np.concatenate(audios)
-            print(f"reconstructed_audio: {reconstructed_audio.shape}")
             actual_duration = len(reconstructed_audio) / save_sample_rate
-            # Save reconstructed audio
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")

     else:
@@ -548,7 +490,7 @@
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str, # Changed from sync_triton_client
+    server_url: str,
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -561,12 +503,12 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None # Initialize client variable
+    sync_triton_client = None
     user_data_map = {}

-    try: # Wrap in try...finally to ensure client closing
+    try:
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
         sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))

         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
@@ -593,7 +535,6 @@ async def send_streaming(
             user_data_map[request_id] = user_data
             audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")

-            print("target_text: ", target_text, "time: ", datetime.now())
             total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                 run_sync_streaming_inference,
                 sync_triton_client,
@@ -627,7 +568,7 @@ async def send_streaming(
             import traceback
             traceback.print_exc()

-    finally: # Ensure client is closed
+    finally:
         if sync_triton_client:
             try:
                 print(f"{name}: Closing stream and sync client...")
@@ -656,7 +597,6 @@ async def send(
     latency_data = []
     task_id = int(name[5:])

-    print(f"manifest_item_list: {manifest_item_list}")
     for i, item in enumerate(manifest_item_list):
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(manifest_item_list)}")
@@ -697,7 +637,6 @@ def load_manifests(manifest_path):
             assert len(line.strip().split("|")) == 4
             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
             utt = Path(utt).stem
-            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
             if not os.path.isabs(prompt_wav):
                 prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
             manifest_list.append(
@@ -738,23 +677,17 @@ async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"

-    # --- Client Initialization based on mode ---
     triton_client = None
     protocol_client = None
     if args.mode == "offline":
         print("Initializing gRPC client for offline mode...")
-        # Use the async client for offline tasks
         triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
         protocol_client = grpcclient_aio
     elif args.mode == "streaming":
         print("Initializing gRPC client for streaming mode...")
-        # Use the sync client for streaming tasks, handled via asyncio.to_thread
-        # We will create one sync client instance PER TASK inside send_streaming.
-        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync # protocol client for input prep
+        protocol_client = grpcclient_sync
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
-    # --- End Client Initialization ---

     if args.reference_audio:
         args.num_tasks = 1
@@ -776,24 +709,18 @@ async def main():
             trust_remote_code=True,
         )
         manifest_item_list = []
-        tmp_audio_path="./asset_zero_shot_prompt.wav"
-        tmp_audio_text="希望你以后能够做的比我还好呦。"
         for i in range(len(dataset)):
             manifest_item_list.append(
                 {
                     "audio_filepath": dataset[i]["prompt_audio"],
                     "reference_text": dataset[i]["prompt_text"],
-                    # "audio_filepath": tmp_audio_path,
-                    # "reference_text": tmp_audio_text,
                     "target_audio_path": dataset[i]["id"],
                     "target_text": dataset[i]["target_text"],
                 }
             )
-        # manifest_item_list = manifest_item_list[:4]
     else:
         manifest_item_list = load_manifests(args.manifest_path)

-    # --- Statistics Fetching (Before) ---
     stats_client = None
     stats_before = None
     try:
@@ -803,7 +730,6 @@ async def main():
         stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
     except Exception as e:
         print(f"Could not retrieve statistics before running tasks: {e}")
-    # --- End Statistics Fetching (Before) ---

     num_tasks = min(args.num_tasks, len(manifest_item_list))
     manifest_item_list = split_data(manifest_item_list, num_tasks)
@@ -813,7 +739,6 @@ async def main():
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):
-        # --- Task Creation based on mode ---
         if args.mode == "offline":
             task = asyncio.create_task(
                 send(
@@ -834,7 +759,7 @@ async def main():
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url, # Pass URL instead of client
+                    server_url=url,
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -845,7 +770,6 @@ async def main():
                     use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
-        # --- End Task Creation ---
         tasks.append(task)

     ans_list = await asyncio.gather(*tasks)
@@ -858,7 +782,7 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1]) # Use extend for list of lists
+            latency_data.extend(ans[1])
         else:
             print("Warning: A task returned None, possibly due to an error.")
@@ -874,10 +798,8 @@ async def main():
     s += f"({total_duration / 3600:.2f} hours)\n"
     s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

-    # --- Statistics Reporting based on mode ---
     if latency_data:
         if args.mode == "offline":
-            # Original offline latency calculation
             latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
             if latency_list:
                 latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
@@ -892,7 +814,6 @@ async def main():
                 s += "No latency data collected for offline mode.\n"

         elif args.mode == "streaming":
-            # Calculate stats for total request latency and first chunk latency
             total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
             first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
             second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
@@ -937,7 +858,6 @@ async def main():
                 s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
     else:
         s += "No latency data collected.\n"
-    # --- End Statistics Reporting ---

     print(s)
     if args.manifest_path:
@@ -947,12 +867,10 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results" # Default name if no manifest/split/audio provided
+        name = "results"
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)

-    # --- Statistics Fetching using temporary Async Client ---
-    # Use a separate async client for fetching stats regardless of mode
     try:
         if stats_client and stats_before:
             print("Fetching inference statistics after running tasks...")
@@ -980,11 +898,9 @@ async def main():
             await stats_client.close()
         except Exception as e:
             print(f"Error closing async stats client: {e}")
-    # --- End Statistics Fetching ---


 if __name__ == "__main__":
-    # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
     async def run_main():
         try:
             await main()
diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
index 2f81786..827925c 100644
--- a/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py
@@ -43,7 +43,7 @@ import torchaudio

 from matcha.utils.audio import mel_spectrogram

-from datetime import datetime
+

 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
@@ -85,9 +85,7 @@ class TritonPythonModel:
         self.model_config = json.loads(args['model_config'])
         parameters = self.model_config['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
-        self.logger.log_info(f"model_params:{model_params}")
         self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
-        # self.dynamic_chunk_strategy = "equal"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

         # Initialize tokenizer
@@ -103,12 +101,8 @@ class TritonPythonModel:
         self.flow_pre_lookahead_len = 3
         self.token_hop_len = 15

-        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
-        if not os.path.exists(spk_info_path):
-            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
-        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        self.default_spk_info = spk_info["001"]
         self.http_client = httpx.AsyncClient()
+        self.api_base = "http://localhost:8000/v1/chat/completions"

     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -147,12 +141,8 @@ class TritonPythonModel:
             "stream": True,
         }

-        api_base = "http://localhost:8000/v1/chat/completions"
-        buffer = ""
-        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
-            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
-            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
+        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
             response.raise_for_status()
             async for line in response.aiter_lines():
                 if line.startswith("data: "):
@@ -164,7 +154,6 @@ class TritonPythonModel:
                         content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                         if content:
                             buffer += content
-                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
                             while True:
                                 match = re.search(r"<\|s_(\d+)\|>", buffer)
                                 if not match:
@@ -307,40 +296,24 @@ class TritonPythonModel:
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

         # Process reference audio through audio tokenizer
-        if wav is not None:
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-            wav_tensor = wav.as_numpy()
-            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-            speech_feat = self._extract_speech_feat(prompt_speech_resample)
-            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

-            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-            reference_text = reference_text[0][0].decode('utf-8')
-            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+        wav_tensor = wav.as_numpy()
+        wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+        speech_feat = self._extract_speech_feat(prompt_speech_resample)
+        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

-            # reference_text = self.default_spk_info["prompt_text"]
-            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-            # prompt_speech_feat = None
-            # prompt_spk_embedding = None
-
-        else:
-            # using pre-cached reference text
-            assert False, "using pre-cached reference text is not supported"
-            reference_text = self.default_spk_info["prompt_text"]
-            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-            prompt_speech_feat = None
-            prompt_spk_embedding = None
+        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+        reference_text = reference_text[0][0].decode('utf-8')

         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')
-        print(f"target_text: {target_text}, time: {datetime.now()}")

         if self.decoupled:
             response_sender = request.get_response_sender()
@@ -349,7 +322,6 @@ class TritonPythonModel:
             token_offset, chunk_index = 0, 0
             start_time = time.time()
             this_token_hop_len = self.token_hop_len
-            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
             async for generated_ids in self.forward_llm_async(
                 target_text=target_text,
                 reference_text=reference_text,
@@ -358,24 +330,20 @@ class TritonPythonModel:
                 if not generated_ids:
                     break
                 semantic_token_ids_arr.append(generated_ids)
-                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                 while True:
                     pending_num = len(semantic_token_ids_arr) - token_offset
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                         this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
                         sub_tts_speech = await self.forward_token2wav(
                             chunk_index, this_tts_speech_token, request_id, wav, wav_len, False
                         )
-                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                         response_sender.send(inference_response)

                         token_offset += this_token_hop_len
-                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")

                         if self.dynamic_chunk_strategy == "exponential":
                             this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
@@ -389,7 +357,6 @@ class TritonPythonModel:
                             avg_chunk_processing_time = cost_time / (chunk_index + 1)
                             if avg_chunk_processing_time > 0:
                                 multiples = (duration - cost_time) / avg_chunk_processing_time
-                                self.logger.log_info(f"multiples: {multiples}")
                                 next_pending_num = len(semantic_token_ids_arr) - token_offset
                                 if multiples > 4:
                                     this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
@@ -409,9 +376,8 @@ class TritonPythonModel:
             response_sender.send(inference_response)

             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-            self.logger.log_info("send tritonserver_response_complete_final to end")
         else:
-            raise NotImplementedError("Decoupled mode is not supported")
+            raise NotImplementedError("Offline TTS mode is not supported")

     async def execute(self, requests):
         """Execute inference on the batched requests.
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
index 230bad0..1f6b591 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py
@@ -106,13 +106,10 @@ class TritonPythonModel:
         # Process each request in batch
         for request in requests:
             target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
-            # shift the speech tokens according to the original vocab size
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens.squeeze().tolist()

-            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
-
             finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()

             request_id = request.request_id()
@@ -124,23 +121,14 @@ class TritonPythonModel:
                     request, "reference_wav_len").as_numpy().item()
                 wav_array = torch.from_numpy(wav_array)

-                # Prepare inputs
                 wav = wav_array[:, :wav_len].squeeze(0)
                 spk_id = get_spk_id_from_prompt_audio(wav)
-                # wav = wav.to(self.device)
-
-                # update cache before forward
-                # self.token2wav_model.streaming_flow_cache[request_id]
-                # self.token2wav_model.hift_cache_dict[request_id]
                 audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
-                # get the cache after forward

                 outputs = []
-                generated_wave = audio_hat.squeeze(0).cpu().numpy()
-
                 wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
                 outputs.append(wav_tensor)
                 inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
diff --git a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
index 63dce14..bda4cb1 100644
--- a/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
+++ b/runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py
@@ -320,7 +320,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def forward(
         self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
-        # assert all item in prompt_audios_sample_rate is 16000
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)


@@ -335,7 +334,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
     def prepare_prompt_audio(
         self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
     ):
-        # assert all item in prompt_audios_sample_rate is 16000
         assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)


@@ -385,7 +383,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)

             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
-            print(f"speaker_id {speaker_id} added to cache")

         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
@@ -394,12 +391,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
                 source = torch.zeros(1, 1, 0, device='cuda'),
                 speech = torch.zeros(1, 0, device='cuda'),
             )
-        # else:
-        #     for k, v in self.streaming_flow_cache[request_id].items():
-        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
-        #     for k, v in self.hift_cache_dict[request_id].items():
-        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
-        # breakpoint()

         current_request_cache = self.streaming_flow_cache[request_id]
@@ -477,7 +468,6 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()
     model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
-    # mkdir output_dir if not exists
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
     dataset_name = "yuekai/seed_tts_cosy2"
diff --git a/runtime/triton_trtllm/streaming_inference.py b/runtime/triton_trtllm/streaming_inference.py
index 026feb5..a5404e2 100644
--- a/runtime/triton_trtllm/streaming_inference.py
+++ b/runtime/triton_trtllm/streaming_inference.py
@@ -35,12 +35,6 @@ def get_args():
     return parser.parse_args()

-def fake_generated_id_iter(generated_speech_tokens_list):
-    for i in range(len(generated_speech_tokens_list)):
-        yield generated_speech_tokens_list[i]
-
-
-
 if __name__ == "__main__":
     args = get_args()
@@ -53,7 +47,6 @@ if __name__ == "__main__":

     token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)

-    flow_pre_lookahead_len = 3
     CHUNK_SIZE = 25
     token_frame_rate = 25
     OVERLAP_SIZE = 0