mirror of https://github.com/FunAudioLLM/CosyVoice.git

commit: remove cache router
@@ -59,12 +59,14 @@ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
 import tritonclient.grpc as grpcclient_sync # Added sync client import
 from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
 
+from datetime import datetime
 
 # --- Added UserData and callback ---
 class UserData:
     def __init__(self):
         self._completed_requests = queue.Queue()
         self._first_chunk_time = None
+        self._second_chunk_time = None
         self._start_time = None
 
     def record_start_time(self):
@@ -75,14 +77,44 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None
 
+    def get_second_chunk_latency(self):
+        if self._first_chunk_time and self._second_chunk_time:
+            return self._second_chunk_time - self._first_chunk_time
+        return None
+
 
 def callback(user_data, result, error):
-    if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time() # Record time of first successful chunk
+    if not error:
+        if user_data._first_chunk_time is None:
+            user_data._first_chunk_time = time.time() # Record time of first successful chunk
+        elif user_data._second_chunk_time is None:
+            user_data._second_chunk_time = time.time()
+
     if error:
         user_data._completed_requests.put(error)
     else:
         user_data._completed_requests.put(result)
 
 
+def stream_callback(user_data_map, result, error):
+    request_id = None
+    if error:
+        # Note: InferenceServerException doesn't have a public request_id() method in all versions.
+        # This part might need adjustment depending on the tritonclient library version.
+        # A more robust way would be to wrap the error with the request_id if possible.
+        # For now, we assume we can't get request_id from error and it will timeout on the client side.
+        print(f"An error occurred in the stream callback: {error}")
+    else:
+        request_id = result.get_response().id
+
+    if request_id:
+        user_data = user_data_map.get(request_id)
+        if user_data:
+            callback(user_data, result, error)
+        else:
+            print(f"Warning: Could not find user_data for request_id {request_id}")
+
+
 # --- End Added UserData and callback ---
 
 
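Note on the added stream_callback: instead of opening a new gRPC stream per request, one long-lived stream per worker is shared and each response is routed back to its request by looking up the request id in user_data_map. A minimal runnable sketch of that dispatch pattern follows; the dispatch helper, the literal request ids, and the plain-string results are illustrative stand-ins, and the request id is passed explicitly rather than read from result.get_response().id as the patch does.

import queue
import time

class UserData:
    """Per-request state: a result queue plus first/second chunk timestamps."""
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._second_chunk_time = None

def dispatch(user_data_map, request_id, result, error):
    # Route a shared-stream callback to the per-request UserData.
    user_data = user_data_map.get(request_id)
    if user_data is None:
        print(f"Warning: no user_data for request_id {request_id}")
        return
    if not error:
        if user_data._first_chunk_time is None:
            user_data._first_chunk_time = time.time()
        elif user_data._second_chunk_time is None:
            user_data._second_chunk_time = time.time()
    user_data._completed_requests.put(error if error else result)

# Register each request before sending it on the shared stream,
# then block on that request's own queue for its responses.
user_data_map = {"req-1": UserData()}
dispatch(user_data_map, "req-1", result="audio-chunk-0", error=None)
print(user_data_map["req-1"]._completed_requests.get())  # -> audio-chunk-0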
@@ -142,6 +174,68 @@ def write_triton_stats(stats, summary_file):
     )
 
 
+def subtract_stats(stats_after, stats_before):
+    """Subtracts two Triton inference statistics objects."""
+    # Deep copy to avoid modifying the original stats_after
+    stats_diff = json.loads(json.dumps(stats_after))
+
+    model_stats_before_map = {
+        s["name"]: {
+            "version": s["version"],
+            "last_inference": s.get("last_inference", 0),
+            "inference_count": s.get("inference_count", 0),
+            "execution_count": s.get("execution_count", 0),
+            "inference_stats": s.get("inference_stats", {}),
+            "batch_stats": s.get("batch_stats", []),
+        }
+        for s in stats_before["model_stats"]
+    }
+
+    for model_stat_after in stats_diff["model_stats"]:
+        model_name = model_stat_after["name"]
+        if model_name in model_stats_before_map:
+            model_stat_before = model_stats_before_map[model_name]
+
+            # Subtract counts
+            model_stat_after["inference_count"] = str(
+                int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
+            )
+            model_stat_after["execution_count"] = str(
+                int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
+            )
+
+            # Subtract aggregate stats (like queue, compute times)
+            if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
+                for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
+                    if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
+                        if "ns" in model_stat_after["inference_stats"][key]:
+                            ns_after = int(model_stat_after["inference_stats"][key]["ns"])
+                            ns_before = int(model_stat_before["inference_stats"][key]["ns"])
+                            model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
+                        if "count" in model_stat_after["inference_stats"][key]:
+                            count_after = int(model_stat_after["inference_stats"][key]["count"])
+                            count_before = int(model_stat_before["inference_stats"][key]["count"])
+                            model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
+
+            # Subtract batch execution stats
+            if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
+                batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
+                for batch_stat_after in model_stat_after["batch_stats"]:
+                    bs = batch_stat_after["batch_size"]
+                    if bs in batch_stats_before_map:
+                        batch_stat_before = batch_stats_before_map[bs]
+                        for key in ["compute_input", "compute_infer", "compute_output"]:
+                            if key in batch_stat_after and key in batch_stat_before:
+                                count_after = int(batch_stat_after[key]["count"])
+                                count_before = int(batch_stat_before[key]["count"])
+                                batch_stat_after[key]["count"] = str(count_after - count_before)
+
+                                ns_after = int(batch_stat_after[key]["ns"])
+                                ns_before = int(batch_stat_before[key]["ns"])
+                                batch_stat_after[key]["ns"] = str(ns_after - ns_before)
+    return stats_diff
+
+
 def get_args():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
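Usage note for subtract_stats: the benchmark now snapshots the server's cumulative inference statistics before the run and again after it, then reports only the difference, so counters from warmup or earlier runs do not leak into the summary. A rough sketch of that flow, assuming the same tritonclient async API used elsewhere in this file (run_benchmark and the URL are placeholders):

import asyncio
import tritonclient.grpc.aio as grpcclient_aio

async def stats_window(run_benchmark, url="localhost:8001"):
    # Snapshot cumulative server-side statistics, run the workload,
    # snapshot again, and keep only the delta accrued during the run.
    client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
    stats_before = await client.get_inference_statistics(model_name="", as_json=True)
    await run_benchmark()
    stats_after = await client.get_inference_statistics(model_name="", as_json=True)
    return subtract_stats(stats_after, stats_before)  # helper defined in this patch

# e.g. stats_diff = asyncio.run(stats_window(my_benchmark_coroutine))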
@@ -357,10 +451,10 @@ def run_sync_streaming_inference(
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
     user_data.record_start_time() # Record start time for first chunk latency calculation
-
-    # Establish stream
-    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
+    # e.g. 08:47:34.827758
+    print(f"Record start time in human readable: {datetime.now()}")
+    # input()
 
     # Send request
     sync_triton_client.async_stream_infer(
         model_name,
@@ -374,11 +468,11 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get() # Add timeout
+            result = user_data._completed_requests.get(timeout=20) # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
-                sync_triton_client.stop_stream()
-                return None, None, None # Indicate error
+                # Don't stop the stream here, just return error
+                return None, None, None, None
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
@@ -393,13 +487,13 @@ def run_sync_streaming_inference(
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
-            sync_triton_client.stop_stream()
-            return None, None, None # Indicate error
+            # Don't stop stream here, just return error
+            return None, None, None, None
 
-    sync_triton_client.stop_stream()
     end_time_total = time.time()
     total_request_latency = end_time_total - start_time_total
     first_chunk_latency = user_data.get_first_chunk_latency()
+    second_chunk_latency = user_data.get_second_chunk_latency()
 
     # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
     actual_duration = 0
@@ -448,7 +542,7 @@ def run_sync_streaming_inference(
         print("Warning: No audio chunks received.")
         actual_duration = 0
 
-    return total_request_latency, first_chunk_latency, actual_duration
+    return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
 
 
 async def send_streaming(
@@ -468,10 +562,12 @@ async def send_streaming(
     latency_data = []
     task_id = int(name[5:])
     sync_triton_client = None # Initialize client variable
+    user_data_map = {}
 
     try: # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
         sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -494,10 +590,11 @@ async def send_streaming(
 
             request_id = str(uuid.uuid4())
             user_data = UserData()
+            user_data_map[request_id] = user_data
 
             audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
-            total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
+            print("target_text: ", target_text, "time: ", datetime.now())
+            total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                 run_sync_streaming_inference,
                 sync_triton_client,
                 model_name,
@@ -511,12 +608,18 @@ async def send_streaming(
             )
 
             if total_request_latency is not None:
-                print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
-                latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
+                print(
+                    f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
+                    f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
+                    f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
+                )
+                latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
                 total_duration += actual_duration
             else:
                 print(f"{name}: Item {i} failed.")
 
+            del user_data_map[request_id]
+
     except FileNotFoundError:
         print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
     except Exception as e:
@@ -527,7 +630,8 @@ async def send_streaming(
     finally: # Ensure client is closed
         if sync_triton_client:
             try:
-                print(f"{name}: Closing sync client...")
+                print(f"{name}: Closing stream and sync client...")
+                sync_triton_client.stop_stream()
                 sync_triton_client.close()
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
@@ -685,9 +789,22 @@ async def main():
                     "target_text": dataset[i]["target_text"],
                 }
             )
+        # manifest_item_list = manifest_item_list[:4]
     else:
        manifest_item_list = load_manifests(args.manifest_path)
 
+    # --- Statistics Fetching (Before) ---
+    stats_client = None
+    stats_before = None
+    try:
+        print("Initializing temporary async client for fetching stats...")
+        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
+        print("Fetching inference statistics before running tasks...")
+        stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
+    except Exception as e:
+        print(f"Could not retrieve statistics before running tasks: {e}")
+    # --- End Statistics Fetching (Before) ---
+
     num_tasks = min(args.num_tasks, len(manifest_item_list))
     manifest_item_list = split_data(manifest_item_list, num_tasks)
 
@@ -776,8 +893,9 @@ async def main():
 
     elif args.mode == "streaming":
         # Calculate stats for total request latency and first chunk latency
-        total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
-        first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
+        total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
+        first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
+        second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
 
         s += "\n--- Total Request Latency ---\n"
         if total_latency_list:
@@ -804,6 +922,19 @@ async def main():
             s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
         else:
             s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+
+        s += "\n--- Second Chunk Latency ---\n"
+        if second_chunk_latency_list:
+            avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
+            variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
+            s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
+            s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
+            s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
+            s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
+            s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
+            s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
+        else:
+            s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
     else:
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
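The second-chunk block mirrors the first-chunk reporting: latencies are collected in seconds and scaled by 1000.0 for the millisecond average and percentiles. (Note the variance line multiplies np.var, which is in seconds squared, by 1000.0; a true ms^2 figure would need a factor of 1e6, so it is best read as a relative spread indicator.) A small illustrative helper in the same spirit, with placeholder field names:

import numpy as np

def summarize_latency(name, latencies_s):
    # latencies_s: per-request latencies in seconds; returns a report string.
    if not latencies_s:
        return f"No {name} latency data collected.\n"
    ms = np.array(latencies_s, dtype=np.float64) * 1000.0
    s = f"\n--- {name} latency ---\n"
    for p in (50, 90, 95, 99):
        s += f"{name}_latency_{p}_percentile_ms: {np.percentile(ms, p):.2f}\n"
    s += f"average_{name}_latency_ms: {ms.mean():.2f}\n"
    return s

print(summarize_latency("second_chunk", [0.12, 0.15, 0.11]))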
@@ -822,20 +953,23 @@ async def main():
 
     # --- Statistics Fetching using temporary Async Client ---
     # Use a separate async client for fetching stats regardless of mode
-    stats_client = None
     try:
-        print("Initializing temporary async client for fetching stats...")
-        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
-        print("Fetching inference statistics...")
-        # Fetching for all models, filtering might be needed depending on server setup
-        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
-        print("Fetching model config...")
-        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
-
-        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
-
-        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
-            json.dump(metadata, f, indent=4)
+        if stats_client and stats_before:
+            print("Fetching inference statistics after running tasks...")
+            stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
+
+            print("Calculating statistics difference...")
+            stats = subtract_stats(stats_after, stats_before)
+
+            print("Fetching model config...")
+            metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
+
+            write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
+
+            with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
+                json.dump(metadata, f, indent=4)
+        else:
+            print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
+
     except Exception as e:
         print(f"Could not retrieve statistics or config: {e}")
@@ -109,7 +109,6 @@ class TritonPythonModel:
         spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
         self.default_spk_info = spk_info["001"]
         self.http_client = httpx.AsyncClient()
-        self.runtime_cache = {}
 
     def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
         """Converts a tensor or list of speech token IDs to a string representation."""
@@ -264,38 +263,11 @@ class TritonPythonModel:
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
 
-        # optional cache inputs
-        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
-            # inputs_tensor.extend([
-            # pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
-            # pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
-            # pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
-            # pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
-            # pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
-            # pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
-            # pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
-            # ])
-            inputs_tensor.extend([
-                self.runtime_cache[request_id]["conformer_cnn_cache"],
-                self.runtime_cache[request_id]["conformer_att_cache"],
-                self.runtime_cache[request_id]["estimator_cnn_cache"],
-                self.runtime_cache[request_id]["estimator_att_cache"],
-                self.runtime_cache[request_id]["mel"],
-                self.runtime_cache[request_id]["source"],
-                self.runtime_cache[request_id]["speech"],
-            ])
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
             requested_output_names=[
                 "waveform",
-                "conformer_cnn_cache",
-                "conformer_att_cache",
-                "estimator_cnn_cache",
-                "estimator_att_cache",
-                "mel",
-                "source",
-                "speech",
             ],
             inputs=inputs_tensor,
             request_id=request_id,
@@ -306,14 +278,6 @@ class TritonPythonModel:
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
 
-        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
-        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
-        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
-        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
-        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
-        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
-        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")
-
         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
@@ -339,16 +303,6 @@ class TritonPythonModel:
 
     async def _process_request(self, request):
         request_id = request.request_id()
-        if request_id not in self.runtime_cache:
-            self.runtime_cache[request_id] = {
-                "conformer_cnn_cache": None,
-                "conformer_att_cache": None,
-                "estimator_cnn_cache": None,
-                "estimator_att_cache": None,
-                "mel": None,
-                "source": None,
-                "speech": None,
-            }
         # Extract input tensors
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
 
@@ -369,7 +323,7 @@ class TritonPythonModel:
 
         reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
         reference_text = reference_text[0][0].decode('utf-8')
-        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+        # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
 
         # reference_text = self.default_spk_info["prompt_text"]
         # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -453,9 +407,7 @@ class TritonPythonModel:
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
-                if request_id in self.runtime_cache:
-                    del self.runtime_cache[request_id]
-                    self.logger.log_info(f"Deleted cache for request_id: {request_id}")
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
@@ -31,7 +31,7 @@ parameters [
     value: {string_value:"${model_dir}"}
   }
 ]
-parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
 input [
   {
     name: "reference_wav"
@@ -103,91 +103,47 @@ class TritonPythonModel:
             List of inference responses containing generated waveforms
         """
         responses = []
+        # Process each request in batch
         for request in requests:
-            request_id = request.request_id()
-
-            # Get inputs
-            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens")
-            target_speech_tokens = torch.utils.dlpack.from_dlpack(target_speech_tokens_tensor.to_dlpack())
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
+            # shift the speech tokens according to the original vocab size
+            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens.squeeze().tolist()
 
+            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
+
             finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
-            wav_array = pb_utils.get_input_tensor_by_name(request, "reference_wav").as_numpy()
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len").as_numpy().item()
-            wav = torch.from_numpy(wav_array)[:, :wav_len].squeeze(0)
+            request_id = request.request_id()
+
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array)
+            # Prepare inputs
+            wav = wav_array[:, :wav_len].squeeze(0)
 
             spk_id = get_spk_id_from_prompt_audio(wav)
+            # wav = wav.to(self.device)
 
-            # Handle cache
-            conformer_cnn_cache = pb_utils.get_input_tensor_by_name(request, "conformer_cnn_cache")
-            if conformer_cnn_cache is not None:
-                self.token2wav_model.streaming_flow_cache[request_id]['conformer_cnn_cache'] = torch.utils.dlpack.from_dlpack(conformer_cnn_cache.to_dlpack())
-
-                conformer_att_cache_np = pb_utils.get_input_tensor_by_name(request, "conformer_att_cache")
-                self.token2wav_model.streaming_flow_cache[request_id]['conformer_att_cache'] = torch.utils.dlpack.from_dlpack(conformer_att_cache_np.to_dlpack()).transpose(0,1)
-
-                estimator_cnn_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_cnn_cache")
-                self.token2wav_model.streaming_flow_cache[request_id]['estimator_cnn_cache'] = torch.utils.dlpack.from_dlpack(estimator_cnn_cache_np.to_dlpack()).squeeze(0)
-
-                estimator_att_cache_np = pb_utils.get_input_tensor_by_name(request, "estimator_att_cache")
-                self.token2wav_model.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.utils.dlpack.from_dlpack(estimator_att_cache_np.to_dlpack()).squeeze(0)
-
-                mel_np = pb_utils.get_input_tensor_by_name(request, "mel")
-                self.token2wav_model.streaming_flow_cache[request_id]['mel'] = torch.utils.dlpack.from_dlpack(mel_np.to_dlpack())
-
-                source_np = pb_utils.get_input_tensor_by_name(request, "source")
-                self.token2wav_model.hift_cache_dict[request_id]['source'] = torch.utils.dlpack.from_dlpack(source_np.to_dlpack())
-
-                speech_np = pb_utils.get_input_tensor_by_name(request, "speech")
-                self.token2wav_model.hift_cache_dict[request_id]['speech'] = torch.utils.dlpack.from_dlpack(speech_np.to_dlpack())
-
-            # Forward pass
-            audio_hat = self.token2wav_model.forward_streaming(
-                target_speech_tokens,
-                finalize,
-                request_id=request_id,
-                speaker_id=f"{spk_id}",
-                prompt_audio=wav,
-                prompt_audio_sample_rate=16000
-            )
+            # update cache before forward
+            # self.token2wav_model.streaming_flow_cache[request_id]
+            # self.token2wav_model.hift_cache_dict[request_id]
+
+            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+
+            # get the cache after forward
 
-            # Prepare outputs
             outputs = []
+            generated_wave = audio_hat.squeeze(0).cpu().numpy()
+
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
             outputs.append(wav_tensor)
 
-            if request_id in self.token2wav_model.streaming_flow_cache:
-                cache = self.token2wav_model.streaming_flow_cache[request_id]
-                hifigan_cache = self.token2wav_model.hift_cache_dict[request_id]
-                conformer_cnn_cache = cache['conformer_cnn_cache']
-                conformer_att_cache = cache['conformer_att_cache'].transpose(0,1)
-                estimator_cnn_cache = cache['estimator_cnn_cache'].unsqueeze(0)
-                estimator_att_cache = cache['estimator_att_cache'].unsqueeze(0)
-                mel = hifigan_cache['mel']
-                source = hifigan_cache['source']
-                speech = hifigan_cache['speech']
-
-                outputs.extend([
-                    pb_utils.Tensor.from_dlpack("conformer_cnn_cache", to_dlpack(conformer_cnn_cache)),
-                    pb_utils.Tensor.from_dlpack("conformer_att_cache", to_dlpack(conformer_att_cache)),
-                    pb_utils.Tensor.from_dlpack("estimator_cnn_cache", to_dlpack(estimator_cnn_cache)),
-                    pb_utils.Tensor.from_dlpack("estimator_att_cache", to_dlpack(estimator_att_cache)),
-                    pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)),
-                    pb_utils.Tensor.from_dlpack("source", to_dlpack(source)),
-                    pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)),
-                ])
-            else:
-                outputs.extend([pb_utils.Tensor("conformer_cnn_cache", np.array([], dtype=np.float16)),
-                                pb_utils.Tensor("conformer_att_cache", np.array([], dtype=np.float16)),
-                                pb_utils.Tensor("estimator_cnn_cache", np.array([], dtype=np.float16)),
-                                pb_utils.Tensor("estimator_att_cache", np.array([], dtype=np.float16)),
-                                pb_utils.Tensor("mel", np.array([], dtype=np.float32)),
-                                pb_utils.Tensor("source", np.array([], dtype=np.float32)),
-                                pb_utils.Tensor("speech", np.array([], dtype=np.float32)),
-                                ])
-
             inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
             responses.append(inference_response)
-            return responses
 
-    def finalize(self):
-        self.logger.log_info("Finalizing Token2WavDiT model")
+        return responses
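With the cache router removed, the BLS no longer declares or shuttles the conformer/estimator/HiFi-GAN cache tensors through Triton; per the new comments above, that state stays inside the token2wav model itself, keyed by request_id (streaming_flow_cache and hift_cache_dict), and forward_streaming receives the request_id on every chunk. A simplified, illustrative sketch of that per-request state pattern; only the two cache-dict names are taken from the diff, everything else is hypothetical:

class Token2WavStreamingSketch:
    """Illustrative only: per-request streaming state kept inside the model process."""

    def __init__(self):
        self.streaming_flow_cache = {}  # request_id -> flow / estimator caches
        self.hift_cache_dict = {}       # request_id -> vocoder (mel / source / speech) state

    def forward_streaming(self, tokens, finalize, request_id, **kwargs):
        # Lazily create state on the first chunk of a request.
        self.streaming_flow_cache.setdefault(request_id, {})
        self.hift_cache_dict.setdefault(request_id, {})
        audio = self._generate_chunk(tokens, request_id, **kwargs)
        if finalize:
            # Drop state once the final chunk has been produced.
            self.streaming_flow_cache.pop(request_id, None)
            self.hift_cache_dict.pop(request_id, None)
        return audio

    def _generate_chunk(self, tokens, request_id, **kwargs):
        return b""  # stands in for the real flow-matching + vocoder call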
@@ -22,7 +22,6 @@ dynamic_batching {
   default_priority_level: 10
 }
 
-parameters: { key: "FORCE_CPU_ONLY_INPUT_TENSORS" value: {string_value:"no"}}
 parameters [
   {
     key: "model_dir",
@@ -52,48 +51,6 @@ input [
     dims: [ 1 ]
     reshape: { shape: [ ] }
     optional: true
-  },
-  {
-    name: "conformer_cnn_cache"
-    data_type: TYPE_FP16
-    dims: [ 512, -1 ]
-    optional: true
-  },
-  {
-    name: "conformer_att_cache"
-    data_type: TYPE_FP16
-    dims: [ 10, 8, -1, 128 ]
-    optional: true
-  },
-  {
-    name: "estimator_cnn_cache"
-    data_type: TYPE_FP16
-    dims: [ 10, 16, -1, 1024, 2 ]
-    optional: true
-  },
-  {
-    name: "estimator_att_cache"
-    data_type: TYPE_FP16
-    dims: [ 10, 16, -1, 8, -1, 128 ]
-    optional: true
-  },
-  {
-    name: "mel"
-    data_type: TYPE_FP32
-    dims: [ 80, -1 ]
-    optional: true
-  },
-  {
-    name: "source"
-    data_type: TYPE_FP32
-    dims: [ 1, -1 ]
-    optional: true
-  },
-  {
-    name: "speech"
-    data_type: TYPE_FP32
-    dims: [ -1 ]
-    optional: true
   }
 ]
 output [
@@ -101,41 +58,6 @@ output [
     name: "waveform"
     data_type: TYPE_FP32
     dims: [ -1 ]
-  },
-  {
-    name: "conformer_cnn_cache"
-    data_type: TYPE_FP16
-    dims: [ 512, -1 ]
-  },
-  {
-    name: "conformer_att_cache"
-    data_type: TYPE_FP16
-    dims: [ 10, 8, -1, 128 ]
-  },
-  {
-    name: "estimator_cnn_cache"
-    data_type: TYPE_FP16
-    dims: [ 10, 16, -1, 1024, 2 ]
-  },
-  {
-    name: "estimator_att_cache"
-    data_type: TYPE_FP16
-    dims: [ 10, 16, -1, 8, -1, 128 ]
-  },
-  {
-    name: "mel"
-    data_type: TYPE_FP32
-    dims: [ 80, -1 ]
-  },
-  {
-    name: "source"
-    data_type: TYPE_FP32
-    dims: [ 1, -1 ]
-  },
-  {
-    name: "speech"
-    data_type: TYPE_FP32
-    dims: [ -1 ]
   }
 ]
 
@@ -1,6 +1,6 @@
 #!/bin/bash
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=0
 cosyvoice_path=/workspace/CosyVoice
 cosyvoice_path=/workspace_yuekai/tts/CosyVoice
 stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
@@ -112,7 +112,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   MODEL_DIR=$model_scope_model_local_dir
   LLM_TOKENIZER_DIR=$huggingface_model_local_dir
   BLS_INSTANCE_NUM=4
-  TRITON_MAX_BATCH_SIZE=32
+  TRITON_MAX_BATCH_SIZE=1
   DECOUPLED_MODE=True # True for streaming, False for offline
   STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
 
@@ -154,7 +154,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     --num-tasks $num_task \
     --mode $mode \
     --huggingface-dataset yuekai/seed_tts_cosy2 \
-    --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new
+    --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -185,14 +185,14 @@ fi
 
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
 
-  python3 streaming_inference.py
+  CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential
 
 
 fi
 
 
 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
-  mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16
+  CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 --kv_cache_free_gpu_memory_fraction 0.4
 
 fi
 
@@ -31,6 +31,7 @@ def get_args():
     parser.add_argument("--output-dir", type=str, default="generated_wavs")
     parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
+    parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
     return parser.parse_args()
 
 
@@ -53,12 +54,14 @@ if __name__ == "__main__":
     token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
 
     flow_pre_lookahead_len = 3
-    CHUNK_SIZE = 25
+    CHUNK_SIZE = 15
+    token_frame_rate = 25
     OVERLAP_SIZE = 0
 
     warmup_times = 3
     for _ in range(warmup_times):
         start_time = time.time()
+        total_forward_count = 0
         for batch in data_loader:
             tts_speech_list = []
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
@@ -83,17 +86,26 @@ if __name__ == "__main__":
 
             buffer = generated_speech_tokens
             output_wavs = []
+            chunk_index = 0
             while True:
+                if args.strategy == "equal":
+                    this_chunk_size = CHUNK_SIZE
+                elif args.strategy == "exponential":
+                    this_chunk_size = token_frame_rate * (2 ** chunk_index)
 
-                if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len:
-                    wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
-                    buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:]
+                if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
+                    wavs = token2wav_model.forward_streaming(buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                    buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
 
                     output_wavs.append(wavs)
+                    total_forward_count += 1
+                    chunk_index += 1
 
                 else:
                     wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
                     output_wavs.append(wavs)
+                    total_forward_count += 1
+                    # chunk_index += 1
                     break
 
             for i, wav in enumerate(output_wavs):
@@ -112,4 +124,4 @@ if __name__ == "__main__":
         if _ == 0:
             token2wav_model.speaker_cache = {}
         print(f"Warmup time: {end_time - start_time} seconds")
-
+        print(f"Total forward count: {total_forward_count}")
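The new --strategy exponential option grows the chunk size as token_frame_rate * (2 ** chunk_index), i.e. 25, 50, 100, ... tokens, presumably so the first audio chunk arrives quickly while later, larger chunks amortize per-call overhead; --strategy equal keeps the fixed CHUNK_SIZE. A standalone sketch of how a token buffer is consumed under both strategies (pre_lookahead_len and the token counts are illustrative, and OVERLAP_SIZE is taken as 0 as in the script):

def chunk_schedule(num_tokens, strategy="exponential", chunk_size=15,
                   token_frame_rate=25, pre_lookahead_len=3):
    """Return the chunk sizes a streaming token buffer would be split into."""
    buffer = list(range(num_tokens))
    sizes, chunk_index = [], 0
    while True:
        this_chunk_size = chunk_size if strategy == "equal" \
            else token_frame_rate * (2 ** chunk_index)
        if len(buffer) >= this_chunk_size + pre_lookahead_len:
            sizes.append(this_chunk_size)   # non-final chunk (keeps lookahead tokens)
            buffer = buffer[this_chunk_size:]
            chunk_index += 1
        else:
            sizes.append(len(buffer))       # final chunk flushes the remainder
            break
    return sizes

print(chunk_schedule(200))                    # -> [25, 50, 100, 25]
print(chunk_schedule(200, strategy="equal"))  # -> [15, 15, ..., 5]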