diff --git a/runtime/triton_trtllm/client_grpc.py b/runtime/triton_trtllm/client_grpc.py index 7dba493..881b519 100644 --- a/runtime/triton_trtllm/client_grpc.py +++ b/runtime/triton_trtllm/client_grpc.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) # 2023 Nvidia (authors: Yuekai Zhang) # 2023 Recurrent.ai (authors: Songtao Shi) @@ -46,7 +45,7 @@ import asyncio import json import queue # Added import uuid # Added -import functools # Added +import functools # Added import os import time @@ -56,9 +55,9 @@ from pathlib import Path import numpy as np import soundfile as sf import tritonclient -import tritonclient.grpc.aio as grpcclient_aio # Renamed original import -import tritonclient.grpc as grpcclient_sync # Added sync client import -from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException +import tritonclient.grpc.aio as grpcclient_aio # Renamed original import +import tritonclient.grpc as grpcclient_sync # Added sync client import +from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException # --- Added UserData and callback --- @@ -76,9 +75,10 @@ class UserData: return self._first_chunk_time - self._start_time return None + def callback(user_data, result, error): if user_data._first_chunk_time is None and not error: - user_data._first_chunk_time = time.time() # Record time of first successful chunk + user_data._first_chunk_time = time.time() # Record time of first successful chunk if error: user_data._completed_requests.put(error) else: @@ -206,8 +206,11 @@ def get_args(): "--model-name", type=str, default="f5_tts", - choices=["f5_tts", "spark_tts", "cosyvoice2"], - help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + choices=[ + "f5_tts", + "spark_tts", + "cosyvoice2"], + help="triton model_repo module name to request", ) parser.add_argument( @@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000): waveform = resample(waveform, num_samples) return waveform, target_sample_rate + def prepare_request_input_output( - protocol_client, # Can be grpcclient_aio or grpcclient_sync + protocol_client, # Can be grpcclient_aio or grpcclient_sync waveform, reference_text, target_text, sample_rate=16000, - padding_duration: int = None # Optional padding for offline mode + padding_duration: int = None # Optional padding for offline mode ): """Prepares inputs for Triton inference (offline or streaming).""" assert len(waveform.shape) == 1, "waveform should be 1D" @@ -291,9 +295,9 @@ def prepare_request_input_output( # Estimate target duration based on text length ratio (crude estimation) # Avoid division by zero if reference_text is empty if reference_text: - estimated_target_duration = duration / len(reference_text) * len(target_text) + estimated_target_duration = duration / len(reference_text) * len(target_text) else: - estimated_target_duration = duration # Assume target duration similar to reference if no text + estimated_target_duration = duration # Assume target duration similar to reference if no text # Calculate required samples based on estimated total duration required_total_samples = padding_duration * sample_rate * ( @@ -329,6 +333,7 @@ def prepare_request_input_output( return inputs, outputs + def run_sync_streaming_inference( sync_triton_client: tritonclient.grpc.InferenceServerClient, 
model_name: str, @@ -342,7 +347,7 @@ def run_sync_streaming_inference( ): """Helper function to run the blocking sync streaming call.""" start_time_total = time.time() - user_data.record_start_time() # Record start time for first chunk latency calculation + user_data.record_start_time() # Record start time for first chunk latency calculation # Establish stream sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) @@ -360,11 +365,11 @@ def run_sync_streaming_inference( audios = [] while True: try: - result = user_data._completed_requests.get() # Add timeout + result = user_data._completed_requests.get() # Add timeout if isinstance(result, InferenceServerException): print(f"Received InferenceServerException: {result}") sync_triton_client.stop_stream() - return None, None, None # Indicate error + return None, None, None # Indicate error # Get response metadata response = result.get_response() final = response.parameters["triton_final_response"].bool_param @@ -372,15 +377,15 @@ def run_sync_streaming_inference( break audio_chunk = result.as_numpy("waveform").reshape(-1) - if audio_chunk.size > 0: # Only append non-empty chunks - audios.append(audio_chunk) + if audio_chunk.size > 0: # Only append non-empty chunks + audios.append(audio_chunk) else: print("Warning: received empty audio chunk.") except queue.Empty: print(f"Timeout waiting for response for request id {request_id}") sync_triton_client.stop_stream() - return None, None, None # Indicate error + return None, None, None # Indicate error sync_triton_client.stop_stream() end_time_total = time.time() @@ -398,19 +403,19 @@ def run_sync_streaming_inference( # Simplified reconstruction based on client_grpc_streaming.py if not audios: print("Warning: No audio chunks received.") - reconstructed_audio = np.array([], dtype=np.float32) # Empty array + reconstructed_audio = np.array([], dtype=np.float32) # Empty array elif len(audios) == 1: reconstructed_audio = audios[0] else: - reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap + reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap for i in range(1, len(audios)): - # Cross-fade section - cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + - audios[i - 1][-cross_fade_samples:] * fade_out) - # Middle section of the current chunk - middle_part = audios[i][cross_fade_samples:-cross_fade_samples] - # Concatenate - reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) + # Cross-fade section + cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + + audios[i - 1][-cross_fade_samples:] * fade_out) + # Middle section of the current chunk + middle_part = audios[i][cross_fade_samples:-cross_fade_samples] + # Concatenate + reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) # Add the last part of the final chunk reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) @@ -421,11 +426,11 @@ def run_sync_streaming_inference( sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") else: print("Warning: No audio chunks received or reconstructed.") - actual_duration = 0 # Set duration to 0 if no audio + actual_duration = 0 # Set duration to 0 if no audio else: - print("Warning: No audio chunks received.") - actual_duration = 0 + print("Warning: No audio chunks received.") + actual_duration = 0 return total_request_latency, first_chunk_latency, 
actual_duration @@ -433,7 +438,7 @@ def run_sync_streaming_inference( async def send_streaming( manifest_item_list: list, name: str, - server_url: str, # Changed from sync_triton_client + server_url: str, # Changed from sync_triton_client protocol_client: types.ModuleType, log_interval: int, model_name: str, @@ -445,11 +450,11 @@ async def send_streaming( total_duration = 0.0 latency_data = [] task_id = int(name[5:]) - sync_triton_client = None # Initialize client variable + sync_triton_client = None # Initialize client variable - try: # Wrap in try...finally to ensure client closing + try: # Wrap in try...finally to ensure client closing print(f"{name}: Initializing sync client for streaming...") - sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here + sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") for i, item in enumerate(manifest_item_list): @@ -491,8 +496,7 @@ async def send_streaming( latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) total_duration += actual_duration else: - print(f"{name}: Item {i} failed.") - + print(f"{name}: Item {i} failed.") except FileNotFoundError: print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") @@ -501,8 +505,7 @@ async def send_streaming( import traceback traceback.print_exc() - - finally: # Ensure client is closed + finally: # Ensure client is closed if sync_triton_client: try: print(f"{name}: Closing sync client...") @@ -510,10 +513,10 @@ async def send_streaming( except Exception as e: print(f"{name}: Error closing sync client: {e}") - print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s") return total_duration, latency_data + async def send( manifest_item_list: list, name: str, @@ -605,6 +608,7 @@ def split_data(data, k): return result + async def main(): args = get_args() url = f"{args.server_addr}:{args.server_port}" @@ -622,7 +626,7 @@ async def main(): # Use the sync client for streaming tasks, handled via asyncio.to_thread # We will create one sync client instance PER TASK inside send_streaming. # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now - protocol_client = grpcclient_sync # protocol client for input prep + protocol_client = grpcclient_sync # protocol client for input prep else: raise ValueError(f"Invalid mode: {args.mode}") # --- End Client Initialization --- @@ -682,11 +686,11 @@ async def main(): ) ) elif args.mode == "streaming": - task = asyncio.create_task( + task = asyncio.create_task( send_streaming( manifest_item_list[i], name=f"task-{i}", - server_url=url, # Pass URL instead of client + server_url=url, # Pass URL instead of client protocol_client=protocol_client, log_interval=args.log_interval, model_name=args.model_name, @@ -709,16 +713,15 @@ async def main(): for ans in ans_list: if ans: total_duration += ans[0] - latency_data.extend(ans[1]) # Use extend for list of lists + latency_data.extend(ans[1]) # Use extend for list of lists else: - print("Warning: A task returned None, possibly due to an error.") - + print("Warning: A task returned None, possibly due to an error.") if total_duration == 0: print("Total synthesized duration is zero. 
Cannot calculate RTF or latency percentiles.") rtf = float('inf') else: - rtf = elapsed / total_duration + rtf = elapsed / total_duration s = f"Mode: {args.mode}\n" s += f"RTF: {rtf:.4f}\n" @@ -759,7 +762,7 @@ async def main(): s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n" s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n" else: - s += "No total request latency data collected.\n" + s += "No total request latency data collected.\n" s += "\n--- First Chunk Latency ---\n" if first_chunk_latency_list: @@ -772,7 +775,7 @@ async def main(): s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n" s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" else: - s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" + s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" else: s += "No latency data collected.\n" # --- End Statistics Reporting --- @@ -785,7 +788,7 @@ async def main(): elif args.reference_audio: name = Path(args.reference_audio).stem else: - name = "results" # Default name if no manifest/split/audio provided + name = "results" # Default name if no manifest/split/audio provided with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: f.write(s) diff --git a/runtime/triton_trtllm/client_http.py b/runtime/triton_trtllm/client_http.py index e22f4eb..4d73e0b 100644 --- a/runtime/triton_trtllm/client_http.py +++ b/runtime/triton_trtllm/client_http.py @@ -29,6 +29,7 @@ import json import numpy as np import argparse + def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -67,9 +68,10 @@ def get_args(): type=str, default="spark_tts", choices=[ - "f5_tts", "spark_tts", "cosyvoice2" - ], - help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + "f5_tts", + "spark_tts", + "cosyvoice2"], + help="triton model_repo module name to request", ) parser.add_argument( @@ -80,6 +82,7 @@ def get_args(): ) return parser.parse_args() + def prepare_request( waveform, reference_text, @@ -97,7 +100,7 @@ def prepare_request( 1, padding_duration * sample_rate - * ((int(duration) // padding_duration) + 1), + * ((int(len(waveform) / sample_rate) // padding_duration) + 1), ), dtype=np.float32, ) @@ -105,11 +108,11 @@ def prepare_request( samples[0, : len(waveform)] = waveform else: samples = waveform - + samples = samples.reshape(1, -1).astype(np.float32) data = { - "inputs":[ + "inputs": [ { "name": "reference_wav", "shape": samples.shape, @@ -139,16 +142,17 @@ def prepare_request( return data + if __name__ == "__main__": args = get_args() server_url = args.server_url if not server_url.startswith(("http://", "https://")): server_url = f"http://{server_url}" - + url = f"{server_url}/v2/models/{args.model_name}/infer" waveform, sr = sf.read(args.reference_audio) assert sr == 16000, "sample rate hardcoded in server" - + samples = np.array(waveform, dtype=np.float32) data = prepare_request(samples, args.reference_text, args.target_text) @@ -166,4 +170,4 @@ if __name__ == "__main__": sample_rate = 16000 else: sample_rate = 24000 - sf.write(args.output_audio, audio, sample_rate, "PCM_16") \ No newline at end of file + sf.write(args.output_audio, audio, sample_rate, "PCM_16") diff 
--git a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py index 105ffa1..47383e2 100644 --- a/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +++ b/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -35,33 +35,34 @@ import s3tokenizer ORIGINAL_VOCAB_SIZE = 151663 + class TritonPythonModel: """Triton Python model for audio tokenization. - + This model takes reference audio input and extracts semantic tokens using s3tokenizer. """ def initialize(self, args): """Initialize the model. - + Args: args: Dictionary containing model configuration """ # Parse model parameters parameters = json.loads(args['model_config'])['parameters'] model_params = {k: v["string_value"] for k, v in parameters.items()} - + self.device = torch.device("cuda") model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx") self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device) def execute(self, requests): """Execute inference on the batched requests. - + Args: requests: List of inference requests - + Returns: List of inference responses containing tokenized outputs """ @@ -79,18 +80,18 @@ class TritonPythonModel: # Prepare inputs wav = wav_array[:, :wav_len].squeeze(0) mels.append(s3tokenizer.log_mel_spectrogram(wav)) - + mels, mels_lens = s3tokenizer.padding(mels) codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device)) codes = codes.clone() + ORIGINAL_VOCAB_SIZE - + responses = [] for i in range(len(requests)): - prompt_speech_tokens = codes[i, :codes_lens[i].item()] + prompt_speech_tokens = codes[i, :codes_lens[i].item()] prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack( "prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) inference_response = pb_utils.InferenceResponse( output_tensors=[prompt_speech_tokens_tensor]) responses.append(inference_response) - - return responses \ No newline at end of file + + return responses diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index cb91677..77a440b 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -42,16 +42,17 @@ import onnxruntime from matcha.utils.audio import mel_spectrogram + class TritonPythonModel: """Triton Python model for Spark TTS. - + This model orchestrates the end-to-end TTS pipeline by coordinating between audio tokenizer, LLM, and vocoder components. """ - + def initialize(self, args): """Initialize the model. 
- + Args: args: Dictionary containing model configuration """ @@ -116,58 +117,58 @@ class TritonPythonModel: "input_ids": input_ids, "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), } - + # Convert inputs to Triton tensors input_tensor_list = [ pb_utils.Tensor(k, v) for k, v in input_dict.items() ] - + # Create and execute inference request llm_request = pb_utils.InferenceRequest( model_name="tensorrt_llm", requested_output_names=["output_ids", "sequence_length"], inputs=input_tensor_list, ) - + llm_responses = llm_request.exec(decoupled=self.decoupled) if self.decoupled: for llm_response in llm_responses: if llm_response.has_error(): raise pb_utils.TritonModelException(llm_response.error().message()) - + # Extract and process output output_ids = pb_utils.get_output_tensor_by_name( llm_response, "output_ids").as_numpy() seq_lens = pb_utils.get_output_tensor_by_name( llm_response, "sequence_length").as_numpy() - + # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] - + yield actual_output_ids else: llm_response = llm_responses if llm_response.has_error(): raise pb_utils.TritonModelException(llm_response.error().message()) - + # Extract and process output output_ids = pb_utils.get_output_tensor_by_name( llm_response, "output_ids").as_numpy() seq_lens = pb_utils.get_output_tensor_by_name( llm_response, "sequence_length").as_numpy() - + # Get actual output IDs up to the sequence length actual_output_ids = output_ids[0][0][:seq_lens[0][0]] - - yield actual_output_ids - + + yield actual_output_ids + def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. - + Args: wav: Input waveform tensor wav_len: Waveform length tensor - + Returns: Tuple of global and semantic tokens """ @@ -176,26 +177,31 @@ class TritonPythonModel: requested_output_names=['prompt_speech_tokens'], inputs=[wav, wav_len] ) - + inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) - + # Extract and convert output tensors prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() return prompt_speech_tokens - def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor: + def forward_token2wav( + self, + prompt_speech_tokens: torch.Tensor, + prompt_speech_feat: torch.Tensor, + prompt_spk_embedding: torch.Tensor, + target_speech_tokens: torch.Tensor) -> torch.Tensor: """Forward pass through the vocoder component. 
- + Args: prompt_speech_tokens: Prompt speech tokens tensor prompt_speech_feat: Prompt speech feat tensor prompt_spk_embedding: Prompt spk embedding tensor target_speech_tokens: Target speech tokens tensor - + Returns: Generated waveform tensor """ @@ -203,22 +209,22 @@ class TritonPythonModel: prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat)) prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding)) target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) - + # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav', requested_output_names=['waveform'], inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor] ) - + inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) - + # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() - + return waveform def parse_input(self, text, prompt_text, prompt_speech_tokens): @@ -231,43 +237,53 @@ class TritonPythonModel: def _extract_spk_embedding(self, speech): feat = kaldi.fbank(speech, - num_mel_bins=80, - dither=0, - sample_frequency=16000) + num_mel_bins=80, + dither=0, + sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) embedding = self.campplus_session.run(None, - {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() + {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() embedding = torch.tensor([embedding]).to(self.device).half() return embedding - def _extract_speech_feat(self, speech): - speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device) + speech_feat = mel_spectrogram( + speech, + n_fft=1920, + num_mels=80, + sampling_rate=24000, + hop_size=480, + win_size=1920, + fmin=0, + fmax=8000).squeeze( + dim=0).transpose( + 0, + 1).to( + self.device) speech_feat = speech_feat.unsqueeze(dim=0) return speech_feat def execute(self, requests): """Execute inference on the batched requests. 
- + Args: requests: List of inference requests - + Returns: List of inference responses containing generated audio """ responses = [] - + for request in requests: # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") - + # Process reference audio through audio tokenizer prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) - wav_tensor = wav.as_numpy() wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) @@ -275,20 +291,20 @@ class TritonPythonModel: token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() - + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') - + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') - + # Prepare prompt for LLM input_ids = self.parse_input( text=target_text, prompt_text=reference_text, prompt_speech_tokens=prompt_speech_tokens, ) - + # Generate semantic tokens with LLM generated_ids_iter = self.forward_llm(input_ids) @@ -305,13 +321,13 @@ class TritonPythonModel: generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device) prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) - + # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.logger.log_info(f"send tritonserver_response_complete_final to end") + self.logger.log_info("send tritonserver_response_complete_final to end") else: generated_ids = next(generated_ids_iter) generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device) @@ -320,11 +336,11 @@ class TritonPythonModel: prompt_spk_embedding = self._extract_spk_embedding(wav_tensor) audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids) - + # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) responses.append(inference_response) - + if not self.decoupled: - return responses \ No newline at end of file + return responses diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index d6735a1..d38f8a4 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -44,6 +44,7 @@ logger = logging.getLogger(__name__) ORIGINAL_VOCAB_SIZE = 151663 + class CosyVoice2: def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): @@ -66,6 +67,7 @@ class CosyVoice2: trt_concurrent, self.fp16) + class CosyVoice2Model: def __init__(self, @@ -109,16 +111,17 @@ class CosyVoice2Model: input_names = ["x", "mask", "mu", 
"cond"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} + class TritonPythonModel: """Triton Python model for vocoder. - + This model takes global and semantic tokens as input and generates audio waveforms using the BiCodec vocoder. """ def initialize(self, args): """Initialize the model. - + Args: args: Dictionary containing model configuration """ @@ -126,24 +129,23 @@ class TritonPythonModel: parameters = json.loads(args['model_config'])['parameters'] model_params = {key: value["string_value"] for key, value in parameters.items()} model_dir = model_params["model_dir"] - + # Initialize device and vocoder self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Initializing vocoder from {model_dir} on {self.device}") - + self.token2wav_model = CosyVoice2( model_dir, load_jit=True, load_trt=True, fp16=True ) logger.info("Token2Wav initialized successfully") - def execute(self, requests): """Execute inference on the batched requests. - + Args: requests: List of inference requests - + Returns: List of inference responses containing generated waveforms """ @@ -163,7 +165,7 @@ class TritonPythonModel: # shift the speech tokens according to the original vocab size prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE - + tts_mel, _ = self.token2wav_model.model.flow.inference( token=target_speech_tokens, token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to( @@ -189,9 +191,5 @@ class TritonPythonModel: wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat)) inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) responses.append(inference_response) - + return responses - - - - diff --git a/runtime/triton_trtllm/scripts/convert_checkpoint.py b/runtime/triton_trtllm/scripts/convert_checkpoint.py index 932cdf8..7cd166f 100644 --- a/runtime/triton_trtllm/scripts/convert_checkpoint.py +++ b/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -35,8 +35,7 @@ def parse_arguments(): type=str, default='auto', choices=['auto', 'float16', 'bfloat16', 'float32'], - help= - "The data type for the model weights and activations if not quantized. " + help="The data type for the model weights and activations if not quantized. " "If 'auto', the data type is automatically inferred from the source model; " "however, if the source dtype is float32, it is converted to float16.") parser.add_argument( @@ -49,8 +48,7 @@ def parse_arguments(): '--disable_weight_only_quant_plugin', default=False, action="store_true", - help= - 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' 'You must also use --use_weight_only for that argument to have an impact.' ) parser.add_argument( @@ -60,16 +58,14 @@ def parse_arguments(): nargs='?', default='int8', choices=['int8', 'int4', 'int4_gptq'], - help= - 'Define the precision for the weights when using weight-only quantization.' + help='Define the precision for the weights when using weight-only quantization.' 'You must also use --use_weight_only for that argument to have an impact.' 
) parser.add_argument( '--calib_dataset', type=str, default='ccdv/cnn_dailymail', - help= - "The huggingface dataset name or the local directory of the dataset for calibration." + help="The huggingface dataset name or the local directory of the dataset for calibration." ) parser.add_argument( "--smoothquant", @@ -83,31 +79,27 @@ def parse_arguments(): '--per_channel', action="store_true", default=False, - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' + help='By default, we use a single static scaling factor for the GEMM\'s result. ' 'per_channel instead uses a different static scaling factor for each channel. ' 'The latter is usually more accurate, but a little slower.') parser.add_argument( '--per_token', action="store_true", default=False, - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + help='By default, we use a single static scaling factor to scale activations in the int8 range. ' 'per_token chooses at run time, and for each token, a custom scaling factor. ' 'The latter is usually more accurate, but a little slower.') parser.add_argument( '--int8_kv_cache', default=False, action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' ) parser.add_argument( '--per_group', default=False, action="store_true", - help= - 'By default, we use a single static scaling factor to scale weights in the int4 range. ' + help='By default, we use a single static scaling factor to scale weights in the int4 range. ' 'per_group chooses at run time, and for each group, a custom scaling factor. ' 'The flag is built for GPTQ/AWQ quantization.') @@ -121,16 +113,14 @@ def parse_arguments(): '--use_parallel_embedding', action="store_true", default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' ) parser.add_argument( '--embedding_sharding_dim', type=int, default=0, choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). 
' 'To shard it along hidden dimension, set embedding_sharding_dim=1' 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' ) @@ -147,15 +137,13 @@ def parse_arguments(): '--moe_tp_size', type=int, default=-1, - help= - 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' + help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' ) parser.add_argument( '--moe_ep_size', type=int, default=-1, - help= - 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' + help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' ) args = parser.parse_args() return args @@ -249,7 +237,7 @@ def convert_and_save_hf(args): trust_remote_code=True) quant_config, override_fields = update_quant_config_from_hf( quant_config, hf_config, override_fields) - except: + except BaseException: logger.warning("AutoConfig cannot load the huggingface config.") if args.smoothquant is not None or args.int8_kv_cache: @@ -339,4 +327,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/runtime/triton_trtllm/scripts/fill_template.py b/runtime/triton_trtllm/scripts/fill_template.py index 5c629f7..6e6a2bc 100644 --- a/runtime/triton_trtllm/scripts/fill_template.py +++ b/runtime/triton_trtllm/scripts/fill_template.py @@ -1,4 +1,4 @@ -#! /usr/bin/env python3 +# /usr/bin/env python3 from argparse import ArgumentParser from string import Template @@ -59,8 +59,7 @@ if __name__ == "__main__": parser.add_argument("file_path", help="path of the .pbtxt to modify") parser.add_argument( "substitutions", - help= - "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." + help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." ) parser.add_argument("--in_place", "-i", diff --git a/runtime/triton_trtllm/scripts/test_llm.py b/runtime/triton_trtllm/scripts/test_llm.py index 9ffe9cf..d52d724 100644 --- a/runtime/triton_trtllm/scripts/test_llm.py +++ b/runtime/triton_trtllm/scripts/test_llm.py @@ -46,7 +46,6 @@ def parse_arguments(args=None): parser.add_argument('--top_k', type=int, default=50) parser.add_argument('--top_p', type=float, default=0.95) - return parser.parse_args(args=args) @@ -60,7 +59,7 @@ def parse_input(tokenizer, input_ids = tokenizer.encode( curr_text) batch_input_ids.append(input_ids) - + batch_input_ids = [ torch.tensor(x, dtype=torch.int32) for x in batch_input_ids ]
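
For orientation, a minimal sketch of the raw HTTP inference call that client_http.py assembles, assuming the standard Triton KServe JSON protocol. The tensor names ("reference_wav", "reference_wav_len", "reference_text", "target_text", "waveform") and the 16 kHz prompt / 24 kHz cosyvoice2 output rates come from the files touched in this patch; the INT32 length dtype, BYTES text handling, server address, and response parsing are assumptions rather than code from the patch.

```python
import numpy as np
import requests
import soundfile as sf

server = "http://localhost:8000"   # assumed default Triton HTTP port
url = f"{server}/v2/models/cosyvoice2/infer"

# 16 kHz mono prompt audio, as asserted in client_http.py
waveform, sr = sf.read("prompt.wav")
samples = np.asarray(waveform, dtype=np.float32).reshape(1, -1)

payload = {
    "inputs": [
        {"name": "reference_wav", "shape": list(samples.shape),
         "datatype": "FP32", "data": samples.flatten().tolist()},
        {"name": "reference_wav_len", "shape": [1, 1],
         "datatype": "INT32", "data": [samples.shape[1]]},
        {"name": "reference_text", "shape": [1, 1],
         "datatype": "BYTES", "data": ["transcript of the prompt audio"]},
        {"name": "target_text", "shape": [1, 1],
         "datatype": "BYTES", "data": ["text to synthesize"]},
    ]
}

rsp = requests.post(url, json=payload)
rsp.raise_for_status()
# Assumes the default KServe JSON response layout with a single output tensor
audio = np.array(rsp.json()["outputs"][0]["data"], dtype=np.float32)
sf.write("output.wav", audio, 24000, "PCM_16")  # non-f5_tts models write 24 kHz audio
```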