commit 07cbc51cd1
parent 1b8d194b67
Author: root
Date: 2025-07-29 08:39:41 +00:00

8 changed files with 165 additions and 157 deletions

View File

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
 #           2023 Nvidia (authors: Yuekai Zhang)
 #           2023 Recurrent.ai (authors: Songtao Shi)
@@ -46,7 +45,7 @@ import asyncio
 import json
 import queue  # Added
 import uuid  # Added
 import functools  # Added
 import os
 import time
@@ -56,9 +55,9 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
 import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
 import tritonclient.grpc as grpcclient_sync  # Added sync client import
 from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException

 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None


 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
         user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
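For context, the UserData/callback pair touched here is the usual pattern for Triton gRPC streaming: the callback timestamps the first chunk and pushes results or errors onto a queue that the worker thread drains. A minimal self-contained sketch (the accessor name and any internals beyond what the hunks show are assumptions):

import queue
import time


class UserData:
    """Collects streaming responses plus first-chunk timing (sketch only)."""

    def __init__(self):
        self._completed_requests = queue.Queue()
        self._start_time = None
        self._first_chunk_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None


def callback(user_data, result, error):
    # Record when the first successful chunk arrives, then enqueue either
    # the error or the result for the consumer loop to drain.
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time()
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)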
@@ -206,8 +206,11 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts", "cosyvoice2"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
     parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
         waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate


 def prepare_request_input_output(
     protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
     padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
         # Estimate target duration based on text length ratio (crude estimation)
         # Avoid division by zero if reference_text is empty
         if reference_text:
             estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
             estimated_target_duration = duration  # Assume target duration similar to reference if no text

         # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ def prepare_request_input_output(
     return inputs, outputs


 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
     user_data.record_start_time()  # Record start time for first chunk latency calculation

     # Establish stream
     sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,11 +365,11 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
             result = user_data._completed_requests.get()  # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
                 return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
@@ -372,15 +377,15 @@ def run_sync_streaming_inference(
                 break
             audio_chunk = result.as_numpy("waveform").reshape(-1)
             if audio_chunk.size > 0:  # Only append non-empty chunks
                 audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
             return None, None, None  # Indicate error

     sync_triton_client.stop_stream()
     end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
         # Simplified reconstruction based on client_grpc_streaming.py
         if not audios:
             print("Warning: No audio chunks received.")
             reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
         elif len(audios) == 1:
             reconstructed_audio = audios[0]
         else:
             reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
             for i in range(1, len(audios)):
                 # Cross-fade section
                 cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                        audios[i - 1][-cross_fade_samples:] * fade_out)
                 # Middle section of the current chunk
                 middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                 # Concatenate
                 reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
             # Add the last part of the final chunk
             reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
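The cross-fade loop above relies on fade_in/fade_out ramps and a cross_fade_samples value defined elsewhere in the file; a self-contained sketch of the same overlap-add stitching, with a linear ramp assumed for illustration:

import numpy as np


def stitch_chunks(audios, cross_fade_samples):
    """Overlap-add a list of 1-D float32 chunks that share cross_fade_samples
    of overlap between consecutive chunks (linear fade, illustrative only)."""
    if not audios:
        return np.array([], dtype=np.float32)
    if len(audios) == 1:
        return audios[0]
    fade_out = np.linspace(1.0, 0.0, cross_fade_samples, dtype=np.float32)
    fade_in = 1.0 - fade_out
    out = audios[0][:-cross_fade_samples]
    for i in range(1, len(audios)):
        # Blend the tail of the previous chunk with the head of the current one.
        overlap = (audios[i][:cross_fade_samples] * fade_in
                   + audios[i - 1][-cross_fade_samples:] * fade_out)
        middle = audios[i][cross_fade_samples:-cross_fade_samples]
        out = np.concatenate([out, overlap, middle])
    # Append the un-faded tail of the final chunk.
    return np.concatenate([out, audios[-1][-cross_fade_samples:]])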
@@ -421,11 +426,11 @@ def run_sync_streaming_inference(
                 sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
             actual_duration = 0  # Set duration to 0 if no audio
     else:
         print("Warning: No audio chunks received.")
         actual_duration = 0

     return total_request_latency, first_chunk_latency, actual_duration
@@ -433,7 +438,7 @@ def run_sync_streaming_inference(
 async def send_streaming(
     manifest_item_list: list,
     name: str,
     server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])

     sync_triton_client = None  # Initialize client variable
     try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
         sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here

         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                 latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                 total_duration += actual_duration
             else:
                 print(f"{name}: Item {i} failed.")

         except FileNotFoundError:
             print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,8 +505,7 @@ async def send_streaming(
             import traceback
             traceback.print_exc()
-
     finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
@@ -510,10 +513,10 @@ async def send_streaming(
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")

     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data


 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):
     return result


 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
         # Use the sync client for streaming tasks, handled via asyncio.to_thread
         # We will create one sync client instance PER TASK inside send_streaming.
         # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
         protocol_client = grpcclient_sync  # protocol client for input prep
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
     # --- End Client Initialization ---
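As the comments above note, streaming mode now creates one blocking sync client per task and runs it off the event loop via asyncio.to_thread. Roughly, the pattern looks like this (the helper is a placeholder standing in for run_sync_streaming_inference, with its arguments abbreviated):

import asyncio

import tritonclient.grpc as grpcclient_sync


def blocking_streaming_call(client, item):
    # Placeholder for run_sync_streaming_inference(...): starts the stream,
    # sends the request, drains the response queue, stops the stream.
    ...


async def run_one_task(server_url, items):
    # One sync client per task; each blocking call is pushed onto a worker
    # thread so the asyncio event loop stays responsive.
    client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
    try:
        for item in items:
            await asyncio.to_thread(blocking_streaming_call, client, item)
    finally:
        client.close()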
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
             task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
                     server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
             latency_data.extend(ans[1])  # Use extend for list of lists
         else:
             print("Warning: A task returned None, possibly due to an error.")

     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
         rtf = elapsed / total_duration

     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
             s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
             s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
         else:
             s += "No total request latency data collected.\n"

         s += "\n--- First Chunk Latency ---\n"
         if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
             s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
             s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
         else:
             s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
     else:
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
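For reference, the report assembled in this block reduces to percentile math over the collected (total_latency, first_chunk_latency, duration) tuples; a compact sketch of the same arithmetic, with field names mirroring the strings above:

import numpy as np


def summarize(latency_data, elapsed):
    """latency_data: list of (total_latency_s, first_chunk_latency_s, duration_s)."""
    total_duration = sum(d for _, _, d in latency_data)
    rtf = elapsed / total_duration if total_duration > 0 else float('inf')
    first_chunk = [f for _, f, _ in latency_data if f is not None]
    lines = [f"RTF: {rtf:.4f}"]
    if first_chunk:
        lines.append(f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk, 50) * 1000.0:.2f}")
        lines.append(f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk, 99) * 1000.0:.2f}")
        lines.append(f"average_first_chunk_latency_ms: {np.mean(first_chunk) * 1000.0:.2f}")
    return "\n".join(lines)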
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
         name = "results"  # Default name if no manifest/split/audio provided

     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)

View File

@@ -29,6 +29,7 @@ import json
 import numpy as np
 import argparse


 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
type=str, type=str,
default="spark_tts", default="spark_tts",
choices=[ choices=[
"f5_tts", "spark_tts", "cosyvoice2" "f5_tts",
], "spark_tts",
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", "cosyvoice2"],
help="triton model_repo module name to request",
) )
parser.add_argument( parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
     )
     return parser.parse_args()


 def prepare_request(
     waveform,
     reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
                 1,
                 padding_duration
                 * sample_rate
-                * ((int(duration) // padding_duration) + 1),
+                * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
             ),
             dtype=np.float32,
         )
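The changed line simply inlines the duration computation; the surrounding shape logic zero-pads the reference clip up to the next multiple of padding_duration seconds, roughly as follows (the default of 10 seconds here is illustrative):

import numpy as np


def pad_reference(waveform, sample_rate=16000, padding_duration=10):
    # Round the clip length up to the next multiple of padding_duration
    # seconds and zero-pad, mirroring the shape computation in prepare_request.
    duration = len(waveform) / sample_rate
    total_samples = padding_duration * sample_rate * ((int(duration) // padding_duration) + 1)
    samples = np.zeros((1, total_samples), dtype=np.float32)
    samples[0, :len(waveform)] = waveform
    return samples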
@@ -109,7 +112,7 @@ def prepare_request(
     samples = samples.reshape(1, -1).astype(np.float32)

     data = {
-        "inputs":[
+        "inputs": [
             {
                 "name": "reference_wav",
                 "shape": samples.shape,
@@ -139,6 +142,7 @@ def prepare_request(
     return data


 if __name__ == "__main__":
     args = get_args()
     server_url = args.server_url

View File

@@ -35,6 +35,7 @@ import s3tokenizer
 ORIGINAL_VOCAB_SIZE = 151663


 class TritonPythonModel:
     """Triton Python model for audio tokenization.

View File

@@ -42,6 +42,7 @@ import onnxruntime
 from matcha.utils.audio import mel_spectrogram


 class TritonPythonModel:
     """Triton Python model for Spark TTS.
@@ -187,7 +188,12 @@ class TritonPythonModel:
         return prompt_speech_tokens

-    def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
+    def forward_token2wav(
+            self,
+            prompt_speech_tokens: torch.Tensor,
+            prompt_speech_feat: torch.Tensor,
+            prompt_spk_embedding: torch.Tensor,
+            target_speech_tokens: torch.Tensor) -> torch.Tensor:
         """Forward pass through the vocoder component.

         Args:
@@ -231,18 +237,29 @@ class TritonPythonModel:
     def _extract_spk_embedding(self, speech):
         feat = kaldi.fbank(speech,
                            num_mel_bins=80,
                            dither=0,
                            sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
         embedding = self.campplus_session.run(None,
                                               {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device).half()
         return embedding

     def _extract_speech_feat(self, speech):
-        speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+            self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat
@@ -267,7 +284,6 @@ class TritonPythonModel:
             prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
             prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
             wav_tensor = wav.as_numpy()
             wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
@@ -311,7 +327,7 @@ class TritonPythonModel:
                     inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                     response_sender.send(inference_response)
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-                self.logger.log_info(f"send tritonserver_response_complete_final to end")
+                self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
                 generated_ids = next(generated_ids_iter)
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
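For context on the response_sender lines above: in decoupled mode a Triton Python backend may send any number of responses per request and must then signal completion with a final flag. Schematically (helper name hypothetical, chunk generation elided):

import numpy as np
import triton_python_backend_utils as pb_utils


def stream_audio(request, chunks):
    # chunks: iterable of 1-D float32 numpy arrays produced by the model.
    sender = request.get_response_sender()
    for chunk in chunks:
        audio_tensor = pb_utils.Tensor("waveform", chunk.astype(np.float32))
        sender.send(pb_utils.InferenceResponse(output_tensors=[audio_tensor]))
    # Tell the client the stream is done; this pairs with the
    # triton_final_response parameter checked on the client side.
    sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)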

View File

@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
 ORIGINAL_VOCAB_SIZE = 151663


 class CosyVoice2:
     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
trt_concurrent, trt_concurrent,
self.fp16) self.fp16)
class CosyVoice2Model: class CosyVoice2Model:
def __init__(self, def __init__(self,
@@ -109,6 +111,7 @@ class CosyVoice2Model:
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}


 class TritonPythonModel:
     """Triton Python model for vocoder.
@@ -137,7 +140,6 @@ class TritonPythonModel:
         logger.info("Token2Wav initialized successfully")

     def execute(self, requests):
         """Execute inference on the batched requests.
@@ -191,7 +193,3 @@ class TritonPythonModel:
             responses.append(inference_response)

         return responses

View File

@@ -35,8 +35,7 @@ def parse_arguments():
         type=str,
         default='auto',
         choices=['auto', 'float16', 'bfloat16', 'float32'],
-        help=
-        "The data type for the model weights and activations if not quantized. "
+        help="The data type for the model weights and activations if not quantized. "
         "If 'auto', the data type is automatically inferred from the source model; "
         "however, if the source dtype is float32, it is converted to float16.")
     parser.add_argument(
@@ -49,8 +48,7 @@ def parse_arguments():
'--disable_weight_only_quant_plugin', '--disable_weight_only_quant_plugin',
default=False, default=False,
action="store_true", action="store_true",
help= help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
'You must also use --use_weight_only for that argument to have an impact.' 'You must also use --use_weight_only for that argument to have an impact.'
) )
parser.add_argument( parser.add_argument(
@@ -60,16 +58,14 @@ def parse_arguments():
         nargs='?',
         default='int8',
         choices=['int8', 'int4', 'int4_gptq'],
-        help=
-        'Define the precision for the weights when using weight-only quantization.'
+        help='Define the precision for the weights when using weight-only quantization.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
         '--calib_dataset',
         type=str,
         default='ccdv/cnn_dailymail',
-        help=
-        "The huggingface dataset name or the local directory of the dataset for calibration."
+        help="The huggingface dataset name or the local directory of the dataset for calibration."
     )
     parser.add_argument(
         "--smoothquant",
@@ -83,31 +79,27 @@ def parse_arguments():
         '--per_channel',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor for the GEMM\'s result. '
+        help='By default, we use a single static scaling factor for the GEMM\'s result. '
         'per_channel instead uses a different static scaling factor for each channel. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--per_token',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor to scale activations in the int8 range. '
+        help='By default, we use a single static scaling factor to scale activations in the int8 range. '
         'per_token chooses at run time, and for each token, a custom scaling factor. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--int8_kv_cache',
         default=False,
         action="store_true",
-        help=
-        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
+        help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
     )
     parser.add_argument(
         '--per_group',
         default=False,
         action="store_true",
-        help=
-        'By default, we use a single static scaling factor to scale weights in the int4 range. '
+        help='By default, we use a single static scaling factor to scale weights in the int4 range. '
         'per_group chooses at run time, and for each group, a custom scaling factor. '
         'The flag is built for GPTQ/AWQ quantization.')
@@ -121,16 +113,14 @@ def parse_arguments():
         '--use_parallel_embedding',
         action="store_true",
         default=False,
-        help=
-        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+        help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
     )
     parser.add_argument(
         '--embedding_sharding_dim',
         type=int,
         default=0,
         choices=[0, 1],
-        help=
-        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+        help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
         'To shard it along hidden dimension, set embedding_sharding_dim=1'
         'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
     )
@@ -147,15 +137,13 @@ def parse_arguments():
         '--moe_tp_size',
         type=int,
         default=-1,
-        help=
-        'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
+        help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
     )
     parser.add_argument(
         '--moe_ep_size',
         type=int,
         default=-1,
-        help=
-        'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
+        help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
     )
     args = parser.parse_args()
     return args
@@ -249,7 +237,7 @@ def convert_and_save_hf(args):
                                                 trust_remote_code=True)
         quant_config, override_fields = update_quant_config_from_hf(
             quant_config, hf_config, override_fields)
-    except:
+    except BaseException:
         logger.warning("AutoConfig cannot load the huggingface config.")

     if args.smoothquant is not None or args.int8_kv_cache:

View File

@@ -1,4 +1,4 @@
-#! /usr/bin/env python3
+# /usr/bin/env python3
 from argparse import ArgumentParser
 from string import Template
@@ -59,8 +59,7 @@ if __name__ == "__main__":
parser.add_argument("file_path", help="path of the .pbtxt to modify") parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument( parser.add_argument(
"substitutions", "substitutions",
help= help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
"substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
) )
parser.add_argument("--in_place", parser.add_argument("--in_place",
"-i", "-i",

View File

@@ -46,7 +46,6 @@ def parse_arguments(args=None):
     parser.add_argument('--top_k', type=int, default=50)
     parser.add_argument('--top_p', type=float, default=0.95)

     return parser.parse_args(args=args)