root
2025-07-29 08:39:41 +00:00
parent 1b8d194b67
commit 07cbc51cd1
8 changed files with 165 additions and 157 deletions

View File

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
@@ -46,7 +45,7 @@ import asyncio
import json
import queue # Added
import uuid # Added
import functools # Added
import os
import time
@@ -56,9 +55,9 @@ from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
import tritonclient.grpc as grpcclient_sync # Added sync client import
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
# --- Added UserData and callback ---
@@ -76,9 +75,10 @@ class UserData:
return self._first_chunk_time - self._start_time
return None
def callback(user_data, result, error):
if user_data._first_chunk_time is None and not error:
user_data._first_chunk_time = time.time() # Record time of first successful chunk
if error:
user_data._completed_requests.put(error)
else:
@@ -206,8 +206,11 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
-choices=["f5_tts", "spark_tts", "cosyvoice2"],
-help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+choices=[
+"f5_tts",
+"spark_tts",
+"cosyvoice2"],
+help="triton model_repo module name to request",
)
parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate
def prepare_request_input_output(
protocol_client, # Can be grpcclient_aio or grpcclient_sync
waveform,
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None # Optional padding for offline mode
):
"""Prepares inputs for Triton inference (offline or streaming)."""
assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
# Estimate target duration based on text length ratio (crude estimation)
# Avoid division by zero if reference_text is empty
if reference_text:
estimated_target_duration = duration / len(reference_text) * len(target_text)
else:
estimated_target_duration = duration # Assume target duration similar to reference if no text
# Calculate required samples based on estimated total duration
required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ def prepare_request_input_output(
return inputs, outputs
def run_sync_streaming_inference(
sync_triton_client: tritonclient.grpc.InferenceServerClient,
model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
):
"""Helper function to run the blocking sync streaming call."""
start_time_total = time.time()
user_data.record_start_time() # Record start time for first chunk latency calculation
# Establish stream
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,11 +365,11 @@ def run_sync_streaming_inference(
audios = []
while True:
try:
result = user_data._completed_requests.get() # Add timeout
if isinstance(result, InferenceServerException):
print(f"Received InferenceServerException: {result}")
sync_triton_client.stop_stream()
return None, None, None # Indicate error
# Get response metadata
response = result.get_response()
final = response.parameters["triton_final_response"].bool_param
@@ -372,15 +377,15 @@ def run_sync_streaming_inference(
break
audio_chunk = result.as_numpy("waveform").reshape(-1)
if audio_chunk.size > 0: # Only append non-empty chunks
audios.append(audio_chunk)
else:
print("Warning: received empty audio chunk.")
except queue.Empty:
print(f"Timeout waiting for response for request id {request_id}")
sync_triton_client.stop_stream()
return None, None, None # Indicate error
sync_triton_client.stop_stream()
end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
# Simplified reconstruction based on client_grpc_streaming.py
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
elif len(audios) == 1:
reconstructed_audio = audios[0]
else:
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
for i in range(1, len(audios)):
# Cross-fade section
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
# Middle section of the current chunk
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
# Concatenate
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
# Add the last part of the final chunk
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
@@ -421,11 +426,11 @@ def run_sync_streaming_inference(
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0 # Set duration to 0 if no audio
else:
print("Warning: No audio chunks received.")
actual_duration = 0
return total_request_latency, first_chunk_latency, actual_duration
@@ -433,7 +438,7 @@ def run_sync_streaming_inference(
async def send_streaming(
manifest_item_list: list,
name: str,
server_url: str, # Changed from sync_triton_client
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
sync_triton_client = None # Initialize client variable
try: # Wrap in try...finally to ensure client closing
print(f"{name}: Initializing sync client for streaming...")
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
total_duration += actual_duration
else:
print(f"{name}: Item {i} failed.")
except FileNotFoundError:
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,8 +505,7 @@ async def send_streaming(
import traceback
traceback.print_exc()
finally: # Ensure client is closed
if sync_triton_client:
try:
print(f"{name}: Closing sync client...")
@@ -510,10 +513,10 @@ async def send_streaming(
except Exception as e:
print(f"{name}: Error closing sync client: {e}")
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
return total_duration, latency_data
async def send(
manifest_item_list: list,
name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):
return result
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
# Use the sync client for streaming tasks, handled via asyncio.to_thread
# We will create one sync client instance PER TASK inside send_streaming.
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
protocol_client = grpcclient_sync # protocol client for input prep
else:
raise ValueError(f"Invalid mode: {args.mode}")
# --- End Client Initialization ---
@@ -682,11 +686,11 @@ async def main():
)
)
elif args.mode == "streaming":
task = asyncio.create_task(
send_streaming(
manifest_item_list[i],
name=f"task-{i}",
server_url=url, # Pass URL instead of client
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
for ans in ans_list:
if ans:
total_duration += ans[0]
latency_data.extend(ans[1]) # Use extend for list of lists
else:
print("Warning: A task returned None, possibly due to an error.")
if total_duration == 0:
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
rtf = float('inf')
else:
rtf = elapsed / total_duration
s = f"Mode: {args.mode}\n"
s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
else:
s += "No total request latency data collected.\n"
s += "\n--- First Chunk Latency ---\n"
if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
else:
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
else:
s += "No latency data collected.\n"
# --- End Statistics Reporting ---
@@ -785,7 +788,7 @@ async def main():
elif args.reference_audio:
name = Path(args.reference_audio).stem
else:
name = "results" # Default name if no manifest/split/audio provided
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
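
Note (illustrative, not part of the commit): the streaming path above joins the decoupled response chunks with a linear cross-fade. A minimal standalone sketch of that reconstruction follows; the reconstruct helper name, the 16 kHz rate, and the 50 ms overlap are assumptions for the example, while the actual script derives cross_fade_samples, fade_in, and fade_out from its own settings.

import numpy as np

def reconstruct(audios, cross_fade_samples):
    # Overlap-add a list of 1-D float32 chunks with a linear cross-fade,
    # mirroring the logic in run_sync_streaming_inference above.
    if not audios:
        return np.array([], dtype=np.float32)
    if len(audios) == 1:
        return audios[0]
    fade_in = np.linspace(0.0, 1.0, cross_fade_samples, dtype=np.float32)
    fade_out = 1.0 - fade_in
    out = audios[0][:-cross_fade_samples]
    for i in range(1, len(audios)):
        overlap = audios[i][:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out
        middle = audios[i][cross_fade_samples:-cross_fade_samples]
        out = np.concatenate([out, overlap, middle])
    return np.concatenate([out, audios[-1][-cross_fade_samples:]])

# Example: three 0.5 s chunks at an assumed 16 kHz with a 50 ms (800-sample) overlap.
chunks = [np.random.randn(8000).astype(np.float32) for _ in range(3)]
print(reconstruct(chunks, cross_fade_samples=800).shape)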

View File

@@ -29,6 +29,7 @@ import json
import numpy as np
import argparse
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
type=str,
default="spark_tts",
choices=[
-"f5_tts", "spark_tts", "cosyvoice2"
-],
-help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+"f5_tts",
+"spark_tts",
+"cosyvoice2"],
+help="triton model_repo module name to request",
)
parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
)
return parser.parse_args()
def prepare_request(
waveform,
reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
1,
padding_duration
* sample_rate
-* ((int(duration) // padding_duration) + 1),
+* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
),
dtype=np.float32,
)
@@ -105,11 +108,11 @@ def prepare_request(
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
data = {
-"inputs":[
+"inputs": [
{
"name": "reference_wav",
"shape": samples.shape,
@@ -139,16 +142,17 @@ def prepare_request(
return data
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
waveform, sr = sf.read(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(waveform, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
@@ -166,4 +170,4 @@ if __name__ == "__main__":
sample_rate = 16000
else:
sample_rate = 24000
sf.write(args.output_audio, audio, sample_rate, "PCM_16")
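
Note (illustrative, not part of the commit): the data payload built by prepare_request targets Triton's KServe v2 HTTP endpoint at the url shown above. A minimal sketch of the request/response round trip follows; the parsing assumes the server returns the synthesized waveform as a flat FP32 data list under the first entry of outputs, which should be checked against the actual model configuration.

import numpy as np
import requests

def infer(server_url: str, model_name: str, data: dict) -> np.ndarray:
    # POST the JSON payload to Triton's v2 inference endpoint and
    # decode the returned waveform values into a numpy array.
    url = f"{server_url}/v2/models/{model_name}/infer"
    rsp = requests.post(url, json=data, headers={"Content-Type": "application/json"}, timeout=300)
    rsp.raise_for_status()
    result = rsp.json()
    return np.asarray(result["outputs"][0]["data"], dtype=np.float32)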

View File

@@ -35,33 +35,34 @@ import s3tokenizer
ORIGINAL_VOCAB_SIZE = 151663
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
@@ -79,18 +80,18 @@ class TritonPythonModel:
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
mels.append(s3tokenizer.log_mel_spectrogram(wav))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
codes = codes.clone() + ORIGINAL_VOCAB_SIZE
responses = []
for i in range(len(requests)):
prompt_speech_tokens = codes[i, :codes_lens[i].item()]
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
"prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_speech_tokens_tensor])
responses.append(inference_response)
return responses
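
Note (illustrative, not part of the commit): a sketch of the same s3tokenizer calls this model makes, run outside Triton. The ONNX path, the input file, and the mono 16 kHz assumption are hypothetical; inside the server these come from the model's config parameters and request tensors.

import torch
import torchaudio
import s3tokenizer

ORIGINAL_VOCAB_SIZE = 151663  # same shift applied by the model above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = s3tokenizer.load_model("/path/to/speech_tokenizer_v2.onnx").to(device)

wav, sr = torchaudio.load("reference.wav")  # assumed mono, 16 kHz
mels, mels_lens = s3tokenizer.padding([s3tokenizer.log_mel_spectrogram(wav.squeeze(0))])
codes, codes_lens = tokenizer.quantize(mels.to(device), mels_lens.to(device))
prompt_speech_tokens = codes[0, :codes_lens[0].item()] + ORIGINAL_VOCAB_SIZE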

View File

@@ -42,16 +42,17 @@ import onnxruntime
from matcha.utils.audio import mel_spectrogram
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
@@ -116,58 +117,58 @@ class TritonPythonModel:
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
# Convert inputs to Triton tensors
input_tensor_list = [
pb_utils.Tensor(k, v) for k, v in input_dict.items()
]
# Create and execute inference request
llm_request = pb_utils.InferenceRequest(
model_name="tensorrt_llm",
requested_output_names=["output_ids", "sequence_length"],
inputs=input_tensor_list,
)
llm_responses = llm_request.exec(decoupled=self.decoupled)
if self.decoupled:
for llm_response in llm_responses:
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
else:
llm_response = llm_responses
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
@@ -176,26 +177,31 @@ class TritonPythonModel:
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
-def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
+def forward_token2wav(
+self,
+prompt_speech_tokens: torch.Tensor,
+prompt_speech_feat: torch.Tensor,
+prompt_spk_embedding: torch.Tensor,
+target_speech_tokens: torch.Tensor) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
@@ -203,22 +209,22 @@ class TritonPythonModel:
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
requested_output_names=['waveform'],
inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def parse_input(self, text, prompt_text, prompt_speech_tokens):
@@ -231,43 +237,53 @@ class TritonPythonModel:
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device).half()
return embedding
def _extract_speech_feat(self, speech):
-speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
+speech_feat = mel_spectrogram(
+speech,
+n_fft=1920,
+num_mels=80,
+sampling_rate=24000,
+hop_size=480,
+win_size=1920,
+fmin=0,
+fmax=8000).squeeze(
+dim=0).transpose(
+0,
+1).to(
+self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
responses = []
for request in requests:
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
# Process reference audio through audio tokenizer
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
@@ -275,20 +291,20 @@ class TritonPythonModel:
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
# Prepare prompt for LLM
input_ids = self.parse_input(
text=target_text,
prompt_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
)
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
@@ -305,13 +321,13 @@ class TritonPythonModel:
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-self.logger.log_info(f"send tritonserver_response_complete_final to end")
+self.logger.log_info("send tritonserver_response_complete_final to end")
else:
generated_ids = next(generated_ids_iter)
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
@@ -320,11 +336,11 @@ class TritonPythonModel:
prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
responses.append(inference_response)
if not self.decoupled:
return responses

View File

@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
trt_concurrent,
self.fp16)
class CosyVoice2Model:
def __init__(self,
@@ -109,16 +111,17 @@ class CosyVoice2Model:
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
@@ -126,24 +129,23 @@ class TritonPythonModel:
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
)
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
@@ -163,7 +165,7 @@ class TritonPythonModel:
# shift the speech tokens according to the original vocab size
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
@@ -189,9 +191,5 @@ class TritonPythonModel:
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
responses.append(inference_response)
return responses

View File

@@ -35,8 +35,7 @@ def parse_arguments():
type=str,
default='auto',
choices=['auto', 'float16', 'bfloat16', 'float32'],
-help=
-"The data type for the model weights and activations if not quantized. "
+help="The data type for the model weights and activations if not quantized. "
"If 'auto', the data type is automatically inferred from the source model; "
"however, if the source dtype is float32, it is converted to float16.")
parser.add_argument(
@@ -49,8 +48,7 @@ def parse_arguments():
'--disable_weight_only_quant_plugin',
default=False,
action="store_true",
-help=
-'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
+help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
@@ -60,16 +58,14 @@ def parse_arguments():
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_gptq'],
-help=
-'Define the precision for the weights when using weight-only quantization.'
+help='Define the precision for the weights when using weight-only quantization.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--calib_dataset',
type=str,
default='ccdv/cnn_dailymail',
-help=
-"The huggingface dataset name or the local directory of the dataset for calibration."
+help="The huggingface dataset name or the local directory of the dataset for calibration."
)
parser.add_argument(
"--smoothquant",
@@ -83,31 +79,27 @@ def parse_arguments():
'--per_channel',
action="store_true",
default=False,
-help=
-'By default, we use a single static scaling factor for the GEMM\'s result. '
+help='By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
action="store_true",
default=False,
-help=
-'By default, we use a single static scaling factor to scale activations in the int8 range. '
+help='By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--int8_kv_cache',
default=False,
action="store_true",
-help=
-'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
+help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
)
parser.add_argument(
'--per_group',
default=False,
action="store_true",
-help=
-'By default, we use a single static scaling factor to scale weights in the int4 range. '
+help='By default, we use a single static scaling factor to scale weights in the int4 range. '
'per_group chooses at run time, and for each group, a custom scaling factor. '
'The flag is built for GPTQ/AWQ quantization.')
@@ -121,16 +113,14 @@ def parse_arguments():
'--use_parallel_embedding',
action="store_true",
default=False,
-help=
-'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
-help=
-'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
'To shard it along hidden dimension, set embedding_sharding_dim=1'
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
)
@@ -147,15 +137,13 @@ def parse_arguments():
'--moe_tp_size',
type=int,
default=-1,
-help=
-'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
+help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
)
parser.add_argument(
'--moe_ep_size',
type=int,
default=-1,
-help=
-'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
+help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
)
args = parser.parse_args()
return args
@@ -249,7 +237,7 @@ def convert_and_save_hf(args):
trust_remote_code=True)
quant_config, override_fields = update_quant_config_from_hf(
quant_config, hf_config, override_fields)
-except:
+except BaseException:
logger.warning("AutoConfig cannot load the huggingface config.")
if args.smoothquant is not None or args.int8_kv_cache:
@@ -339,4 +327,4 @@ def main():
if __name__ == '__main__':
main()

View File

@@ -1,4 +1,4 @@
-#! /usr/bin/env python3
+# /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
@@ -59,8 +59,7 @@ if __name__ == "__main__":
parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument(
"substitutions",
-help=
-"substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
+help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
)
parser.add_argument("--in_place",
"-i",

View File

@@ -46,7 +46,6 @@ def parse_arguments(args=None):
parser.add_argument('--top_k', type=int, default=50)
parser.add_argument('--top_p', type=float, default=0.95)
return parser.parse_args(args=args)
@@ -60,7 +59,7 @@ def parse_input(tokenizer,
input_ids = tokenizer.encode(
curr_text)
batch_input_ids.append(input_ids)
batch_input_ids = [
torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
]