remove cache router

This commit is contained in:
root
2025-09-26 15:14:31 +08:00
parent 31a0adc73d
commit 79116ac32e
7 changed files with 219 additions and 243 deletions

View File

@@ -59,12 +59,14 @@ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
import tritonclient.grpc as grpcclient_sync # Added sync client import
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
from datetime import datetime
# --- Added UserData and callback ---
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
self._first_chunk_time = None
self._second_chunk_time = None
self._start_time = None
def record_start_time(self):
@@ -75,14 +77,44 @@ class UserData:
return self._first_chunk_time - self._start_time
return None
def get_second_chunk_latency(self):
if self._first_chunk_time and self._second_chunk_time:
return self._second_chunk_time - self._first_chunk_time
return None
def callback(user_data, result, error):
if user_data._first_chunk_time is None and not error:
user_data._first_chunk_time = time.time() # Record time of first successful chunk
if not error:
if user_data._first_chunk_time is None:
user_data._first_chunk_time = time.time() # Record time of first successful chunk
elif user_data._second_chunk_time is None:
user_data._second_chunk_time = time.time()
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)
def stream_callback(user_data_map, result, error):
request_id = None
if error:
# Note: InferenceServerException doesn't have a public request_id() method in all versions.
# This part might need adjustment depending on the tritonclient library version.
# A more robust way would be to wrap the error with the request_id if possible.
# For now, we assume we can't get request_id from error and it will timeout on the client side.
print(f"An error occurred in the stream callback: {error}")
else:
request_id = result.get_response().id
if request_id:
user_data = user_data_map.get(request_id)
if user_data:
callback(user_data, result, error)
else:
print(f"Warning: Could not find user_data for request_id {request_id}")
# --- End Added UserData and callback ---
@@ -142,6 +174,68 @@ def write_triton_stats(stats, summary_file):
)
def subtract_stats(stats_after, stats_before):
"""Subtracts two Triton inference statistics objects."""
# Deep copy to avoid modifying the original stats_after
stats_diff = json.loads(json.dumps(stats_after))
model_stats_before_map = {
s["name"]: {
"version": s["version"],
"last_inference": s.get("last_inference", 0),
"inference_count": s.get("inference_count", 0),
"execution_count": s.get("execution_count", 0),
"inference_stats": s.get("inference_stats", {}),
"batch_stats": s.get("batch_stats", []),
}
for s in stats_before["model_stats"]
}
for model_stat_after in stats_diff["model_stats"]:
model_name = model_stat_after["name"]
if model_name in model_stats_before_map:
model_stat_before = model_stats_before_map[model_name]
# Subtract counts
model_stat_after["inference_count"] = str(
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
)
model_stat_after["execution_count"] = str(
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
)
# Subtract aggregate stats (like queue, compute times)
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
if "ns" in model_stat_after["inference_stats"][key]:
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
if "count" in model_stat_after["inference_stats"][key]:
count_after = int(model_stat_after["inference_stats"][key]["count"])
count_before = int(model_stat_before["inference_stats"][key]["count"])
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
# Subtract batch execution stats
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
for batch_stat_after in model_stat_after["batch_stats"]:
bs = batch_stat_after["batch_size"]
if bs in batch_stats_before_map:
batch_stat_before = batch_stats_before_map[bs]
for key in ["compute_input", "compute_infer", "compute_output"]:
if key in batch_stat_after and key in batch_stat_before:
count_after = int(batch_stat_after[key]["count"])
count_before = int(batch_stat_before[key]["count"])
batch_stat_after[key]["count"] = str(count_after - count_before)
ns_after = int(batch_stat_after[key]["ns"])
ns_before = int(batch_stat_before[key]["ns"])
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
return stats_diff
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -357,10 +451,10 @@ def run_sync_streaming_inference(
"""Helper function to run the blocking sync streaming call."""
start_time_total = time.time()
user_data.record_start_time() # Record start time for first chunk latency calculation
# e.g. 08:47:34.827758
# Establish stream
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
print(f"Record start time in human readable: {datetime.now()}")
# input()
# Send request
sync_triton_client.async_stream_infer(
model_name,
@@ -374,11 +468,11 @@ def run_sync_streaming_inference(
audios = []
while True:
try:
result = user_data._completed_requests.get() # Add timeout
result = user_data._completed_requests.get(timeout=20) # Add timeout
if isinstance(result, InferenceServerException):
print(f"Received InferenceServerException: {result}")
sync_triton_client.stop_stream()
return None, None, None # Indicate error
# Don't stop the stream here, just return error
return None, None, None, None
# Get response metadata
response = result.get_response()
final = response.parameters["triton_final_response"].bool_param
@@ -393,13 +487,13 @@ def run_sync_streaming_inference(
except queue.Empty:
print(f"Timeout waiting for response for request id {request_id}")
sync_triton_client.stop_stream()
return None, None, None # Indicate error
# Don't stop stream here, just return error
return None, None, None, None
sync_triton_client.stop_stream()
end_time_total = time.time()
total_request_latency = end_time_total - start_time_total
first_chunk_latency = user_data.get_first_chunk_latency()
second_chunk_latency = user_data.get_second_chunk_latency()
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
actual_duration = 0
@@ -448,7 +542,7 @@ def run_sync_streaming_inference(
print("Warning: No audio chunks received.")
actual_duration = 0
return total_request_latency, first_chunk_latency, actual_duration
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
async def send_streaming(
@@ -468,10 +562,12 @@ async def send_streaming(
latency_data = []
task_id = int(name[5:])
sync_triton_client = None # Initialize client variable
user_data_map = {}
try: # Wrap in try...finally to ensure client closing
print(f"{name}: Initializing sync client for streaming...")
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
for i, item in enumerate(manifest_item_list):
@@ -494,10 +590,11 @@ async def send_streaming(
request_id = str(uuid.uuid4())
user_data = UserData()
user_data_map[request_id] = user_data
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
print("target_text: ", target_text, "time: ", datetime.now())
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
run_sync_streaming_inference,
sync_triton_client,
model_name,
@@ -511,12 +608,18 @@ async def send_streaming(
)
if total_request_latency is not None:
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
print(
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
)
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
total_duration += actual_duration
else:
print(f"{name}: Item {i} failed.")
del user_data_map[request_id]
except FileNotFoundError:
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
except Exception as e:
@@ -527,7 +630,8 @@ async def send_streaming(
finally: # Ensure client is closed
if sync_triton_client:
try:
print(f"{name}: Closing sync client...")
print(f"{name}: Closing stream and sync client...")
sync_triton_client.stop_stream()
sync_triton_client.close()
except Exception as e:
print(f"{name}: Error closing sync client: {e}")
@@ -685,9 +789,22 @@ async def main():
"target_text": dataset[i]["target_text"],
}
)
# manifest_item_list = manifest_item_list[:4]
else:
manifest_item_list = load_manifests(args.manifest_path)
# --- Statistics Fetching (Before) ---
stats_client = None
stats_before = None
try:
print("Initializing temporary async client for fetching stats...")
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
print("Fetching inference statistics before running tasks...")
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
except Exception as e:
print(f"Could not retrieve statistics before running tasks: {e}")
# --- End Statistics Fetching (Before) ---
num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, num_tasks)
@@ -776,8 +893,9 @@ async def main():
elif args.mode == "streaming":
# Calculate stats for total request latency and first chunk latency
total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
s += "\n--- Total Request Latency ---\n"
if total_latency_list:
@@ -804,6 +922,19 @@ async def main():
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
else:
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
s += "\n--- Second Chunk Latency ---\n"
if second_chunk_latency_list:
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
else:
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
else:
s += "No latency data collected.\n"
# --- End Statistics Reporting ---
@@ -822,20 +953,23 @@ async def main():
# --- Statistics Fetching using temporary Async Client ---
# Use a separate async client for fetching stats regardless of mode
stats_client = None
try:
print("Initializing temporary async client for fetching stats...")
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
print("Fetching inference statistics...")
# Fetching for all models, filtering might be needed depending on server setup
stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
print("Fetching model config...")
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
if stats_client and stats_before:
print("Fetching inference statistics after running tasks...")
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
print("Calculating statistics difference...")
stats = subtract_stats(stats_after, stats_before)
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
print("Fetching model config...")
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
else:
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
except Exception as e:
print(f"Could not retrieve statistics or config: {e}")

View File

@@ -109,7 +109,6 @@ class TritonPythonModel:
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
self.http_client = httpx.AsyncClient()
self.runtime_cache = {}
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
@@ -264,38 +263,11 @@ class TritonPythonModel:
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# optional cache inputs
if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
# inputs_tensor.extend([
# pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
# pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
# pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
# pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
# pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
# pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
# pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
# ])
inputs_tensor.extend([
self.runtime_cache[request_id]["conformer_cnn_cache"],
self.runtime_cache[request_id]["conformer_att_cache"],
self.runtime_cache[request_id]["estimator_cnn_cache"],
self.runtime_cache[request_id]["estimator_att_cache"],
self.runtime_cache[request_id]["mel"],
self.runtime_cache[request_id]["source"],
self.runtime_cache[request_id]["speech"],
])
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=[
"waveform",
"conformer_cnn_cache",
"conformer_att_cache",
"estimator_cnn_cache",
"estimator_att_cache",
"mel",
"source",
"speech",
],
inputs=inputs_tensor,
request_id=request_id,
@@ -306,14 +278,6 @@ class TritonPythonModel:
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
@@ -339,16 +303,6 @@ class TritonPythonModel:
async def _process_request(self, request):
request_id = request.request_id()
if request_id not in self.runtime_cache:
self.runtime_cache[request_id] = {
"conformer_cnn_cache": None,
"conformer_att_cache": None,
"estimator_cnn_cache": None,
"estimator_att_cache": None,
"mel": None,
"source": None,
"speech": None,
}
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
@@ -369,7 +323,7 @@ class TritonPythonModel:
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
# reference_text = self.default_spk_info["prompt_text"]
# prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -453,9 +407,7 @@ class TritonPythonModel:
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
if request_id in self.runtime_cache:
del self.runtime_cache[request_id]
self.logger.log_info(f"Deleted cache for request_id: {request_id}")
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:

View File

@@ -31,7 +31,7 @@ parameters [
value: {string_value:"${model_dir}"}
}
]
parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
input [
{
name: "reference_wav"

View File

@@ -103,91 +103,47 @@ class TritonPythonModel:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
request_id = request.request_id()
# Get inputs
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
# shift the speech tokens according to the original vocab size
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens.squeeze().tolist()
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
request_id = request.request_id()
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array)
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
spk_id = get_spk_id_from_prompt_audio(wav)
# wav = wav.to(self.device)
# Handle cache
conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
if conformer_cnn_cache is not None:
self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())
conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1)
estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
# update cache before forward
# self.token2wav_model.streaming_flow_cache[request_id]
# self.token2wav_model.hift_cache_dict[request_id]
estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())
source_np = pb_utils.get_input_tensor_by_name(request, "source")
self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())
speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
# Forward pass
audio_hat = self.token2wav_model.forward_streaming(
target_speech_tokens,
finalize,
request_id=request_id,
speaker_id=f"{spk_id}",
prompt_audio=wav,
prompt_audio_sample_rate=16000
)
# Prepare outputs
# get the cache after forward
outputs = []
generated_wave = audio_hat.squeeze(0).cpu().numpy()
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
outputs.append(wav_tensor)
if request_id in self.token2wav_model.streaming_flow_cache:
cache = self.token2wav_model.streaming_flow_cache[request_id]
hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
conformer_cnn_cache = cache['conformer_cnn_cache']
conformer_att_cache = cache['conformer_att_cache'].transpose(0,1)
estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
mel = hifigan_cache['mel']
source = hifigan_cache['source']
speech = hifigan_cache['speech']
outputs.extend([
pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
])
else:
outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
pb_utils.Tensor("source", np.array([], dtype=np.float32)),
pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
])
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
responses.append(inference_response)
return responses
def finalize(self):
self.logger.log_info("Finalizing Token2WavDiT model")
return responses

View File

@@ -22,7 +22,6 @@ dynamic_batching {
default_priority_level: 10
}
parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
parameters [
{
key: "model_dir",
@@ -52,48 +51,6 @@ input [
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "conformer_cnn_cache"
data_type: TYPE_FP16
dims: [ 512, -1 ]
optional: true
},
{
name: "conformer_att_cache"
data_type: TYPE_FP16
dims: [ 10, 8, -1, 128 ]
optional: true
},
{
name: "estimator_cnn_cache"
data_type: TYPE_FP16
dims: [ 10, 16, -1, 1024, 2 ]
optional: true
},
{
name: "estimator_att_cache"
data_type: TYPE_FP16
dims: [ 10, 16, -1, 8, -1, 128 ]
optional: true
},
{
name: "mel"
data_type: TYPE_FP32
dims: [ 80, -1 ]
optional: true
},
{
name: "source"
data_type: TYPE_FP32
dims: [ 1, -1 ]
optional: true
},
{
name: "speech"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
}
]
output [
@@ -101,41 +58,6 @@ output [
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "conformer_cnn_cache"
data_type: TYPE_FP16
dims: [ 512, -1 ]
},
{
name: "conformer_att_cache"
data_type: TYPE_FP16
dims: [ 10, 8, -1, 128 ]
},
{
name: "estimator_cnn_cache"
data_type: TYPE_FP16
dims: [ 10, 16, -1, 1024, 2 ]
},
{
name: "estimator_att_cache"
data_type: TYPE_FP16
dims: [ 10, 16, -1, 8, -1, 128 ]
},
{
name: "mel"
data_type: TYPE_FP32
dims: [ 80, -1 ]
},
{
name: "source"
data_type: TYPE_FP32
dims: [ 1, -1 ]
},
{
name: "speech"
data_type: TYPE_FP32
dims: [ -1 ]
}
]

View File

@@ -1,6 +1,6 @@
#!/bin/bash
# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=0
cosyvoice_path=/workspace/CosyVoice
cosyvoice_path=/workspace_yuekai/tts/CosyVoice
stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
@@ -112,7 +112,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
MODEL_DIR=$model_scope_model_local_dir
LLM_TOKENIZER_DIR=$huggingface_model_local_dir
BLS_INSTANCE_NUM=4
TRITON_MAX_BATCH_SIZE=32
TRITON_MAX_BATCH_SIZE=1
DECOUPLED_MODE=True # True for streaming, False for offline
STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
@@ -154,7 +154,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--num-tasks $num_task \
--mode $mode \
--huggingface-dataset yuekai/seed_tts_cosy2 \
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new
--log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -185,14 +185,14 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
python3 streaming_inference.py
CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16
CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4
fi

View File

@@ -31,6 +31,7 @@ def get_args():
parser.add_argument("--output-dir", type=str, default="generated_wavs")
parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
return parser.parse_args()
@@ -53,12 +54,14 @@ if __name__ == "__main__":
token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
flow_pre_lookahead_len = 3
CHUNK_SIZE = 25
CHUNK_SIZE = 15
token_frame_rate = 25
OVERLAP_SIZE = 0
warmup_times = 3
for _ in range(warmup_times):
start_time = time.time()
total_forward_count = 0
for batch in data_loader:
tts_speech_list = []
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
@@ -83,17 +86,26 @@ if __name__ == "__main__":
buffer = generated_speech_tokens
output_wavs = []
chunk_index = 0
while True:
if args.strategy == "equal":
this_chunk_size = CHUNK_SIZE
elif args.strategy == "exponential":
this_chunk_size = token_frame_rate * (2 ** chunk_index)
if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len:
wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:]
if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
output_wavs.append(wavs)
total_forward_count += 1
chunk_index += 1
else:
wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
output_wavs.append(wavs)
total_forward_count += 1
# chunk_index += 1
break
for i, wav in enumerate(output_wavs):
@@ -112,4 +124,4 @@ if __name__ == "__main__":
if _ == 0:
token2wav_model.speaker_cache = {}
print(f"Warmup time: {end_time - start_time} seconds")
print(f"Total forward count: {total_forward_count}")