Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 09:59:23 +08:00)
fix lint
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
 #                2023 Nvidia (authors: Yuekai Zhang)
 #                2023 Recurrent.ai (authors: Songtao Shi)
@@ -46,7 +45,7 @@ import asyncio
 import json
 import queue  # Added
 import uuid  # Added
 import functools  # Added

 import os
 import time
@@ -56,9 +55,9 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
 import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
 import tritonclient.grpc as grpcclient_sync  # Added sync client import
 from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException


 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None

+
 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
         user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
@@ -206,8 +206,11 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts", "cosyvoice2"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )

     parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
         waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate

+
 def prepare_request_input_output(
     protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
     padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
         # Estimate target duration based on text length ratio (crude estimation)
         # Avoid division by zero if reference_text is empty
         if reference_text:
             estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
             estimated_target_duration = duration  # Assume target duration similar to reference if no text

         # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ def prepare_request_input_output(

     return inputs, outputs

+
 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
     user_data.record_start_time()  # Record start time for first chunk latency calculation

     # Establish stream
     sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,11 +365,11 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
             result = user_data._completed_requests.get()  # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
                 return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
@@ -372,15 +377,15 @@ def run_sync_streaming_inference(
                 break

             audio_chunk = result.as_numpy("waveform").reshape(-1)
             if audio_chunk.size > 0:  # Only append non-empty chunks
                 audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")

         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
             return None, None, None  # Indicate error

     sync_triton_client.stop_stream()
     end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
     # Simplified reconstruction based on client_grpc_streaming.py
     if not audios:
         print("Warning: No audio chunks received.")
         reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
     elif len(audios) == 1:
         reconstructed_audio = audios[0]
     else:
         reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
         for i in range(1, len(audios)):
             # Cross-fade section
             cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                    audios[i - 1][-cross_fade_samples:] * fade_out)
             # Middle section of the current chunk
             middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
             # Concatenate
             reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
         # Add the last part of the final chunk
         reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

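The reconstruction block above is touched only for formatting in this commit, but it is the heart of the streaming client: chunks are joined with a cross-fade over cross_fade_samples. A self-contained sketch of the same stitching follows; the helper name and the linear fade ramps are assumptions, since the ramps are defined elsewhere in the script.

    import numpy as np

    def stitch_chunks(chunks, cross_fade_samples):
        """Join streamed audio chunks, cross-fading neighbouring edges.

        Assumes every chunk is longer than 2 * cross_fade_samples.
        """
        if not chunks:
            return np.array([], dtype=np.float32)
        if len(chunks) == 1:
            return chunks[0]
        fade_in = np.linspace(0.0, 1.0, cross_fade_samples, dtype=np.float32)  # assumed linear ramp
        fade_out = 1.0 - fade_in
        out = chunks[0][:-cross_fade_samples]
        for i in range(1, len(chunks)):
            # Blend the trailing edge of the previous chunk with the leading edge of this one.
            overlap = (chunks[i][:cross_fade_samples] * fade_in
                       + chunks[i - 1][-cross_fade_samples:] * fade_out)
            out = np.concatenate([out, overlap, chunks[i][cross_fade_samples:-cross_fade_samples]])
        return np.concatenate([out, chunks[-1][-cross_fade_samples:]])
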
@@ -421,11 +426,11 @@ def run_sync_streaming_inference(
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
             actual_duration = 0  # Set duration to 0 if no audio

     else:
         print("Warning: No audio chunks received.")
         actual_duration = 0

     return total_request_latency, first_chunk_latency, actual_duration

@@ -433,7 +438,7 @@ def run_sync_streaming_inference(
 async def send_streaming(
     manifest_item_list: list,
     name: str,
     server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
     sync_triton_client = None  # Initialize client variable

     try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
         sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here

         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                     latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
                     print(f"{name}: Item {i} failed.")

-
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,8 +505,7 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()

-
-    finally: # Ensure client is closed
+    finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
@@ -510,10 +513,10 @@ async def send_streaming(
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")

     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data


 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):

     return result

+
 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
         # Use the sync client for streaming tasks, handled via asyncio.to_thread
         # We will create one sync client instance PER TASK inside send_streaming.
         # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False)  # REMOVED: Client created per task now
         protocol_client = grpcclient_sync  # protocol client for input prep
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
     # --- End Client Initialization ---
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
             task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
                     server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
             latency_data.extend(ans[1])  # Use extend for list of lists
         else:
             print("Warning: A task returned None, possibly due to an error.")

-
     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
         rtf = elapsed / total_duration

     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
         s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
         s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
     else:
         s += "No total request latency data collected.\n"

     s += "\n--- First Chunk Latency ---\n"
     if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
         s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
         s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
     else:
         s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
     else:
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
         name = "results"  # Default name if no manifest/split/audio provided
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)


@@ -29,6 +29,7 @@ import json
 import numpy as np
 import argparse

+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
         type=str,
         default="spark_tts",
         choices=[
-            "f5_tts", "spark_tts", "cosyvoice2"
-        ],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )

     parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
     )
     return parser.parse_args()

+
 def prepare_request(
     waveform,
     reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
             1,
             padding_duration
             * sample_rate
-            * ((int(duration) // padding_duration) + 1),
+            * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
         ),
         dtype=np.float32,
     )
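The last hunk above swaps the duration-based estimate for an explicit len(waveform) / sample_rate when sizing the zero-padded request buffer. A standalone sketch of that padding step, with assumed defaults for the sample rate and padding window:

    import numpy as np

    def pad_reference(waveform, sample_rate=16000, padding_duration=10):
        """Zero-pad a 1-D waveform up to the next multiple of padding_duration seconds, shaped (1, n_samples)."""
        duration_s = len(waveform) / sample_rate
        total_samples = padding_duration * sample_rate * ((int(duration_s) // padding_duration) + 1)
        samples = np.zeros((1, total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
        return samples
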
@@ -105,11 +108,11 @@ def prepare_request(
         samples[0, : len(waveform)] = waveform
     else:
         samples = waveform

     samples = samples.reshape(1, -1).astype(np.float32)

     data = {
-        "inputs":[
+        "inputs": [
             {
                 "name": "reference_wav",
                 "shape": samples.shape,
@@ -139,16 +142,17 @@ def prepare_request(

     return data

+
 if __name__ == "__main__":
     args = get_args()
     server_url = args.server_url
     if not server_url.startswith(("http://", "https://")):
         server_url = f"http://{server_url}"

     url = f"{server_url}/v2/models/{args.model_name}/infer"
     waveform, sr = sf.read(args.reference_audio)
     assert sr == 16000, "sample rate hardcoded in server"

     samples = np.array(waveform, dtype=np.float32)
     data = prepare_request(samples, args.reference_text, args.target_text)

@@ -166,4 +170,4 @@ if __name__ == "__main__":
         sample_rate = 16000
     else:
         sample_rate = 24000
     sf.write(args.output_audio, audio, sample_rate, "PCM_16")

@@ -35,33 +35,34 @@ import s3tokenizer

 ORIGINAL_VOCAB_SIZE = 151663

+
 class TritonPythonModel:
     """Triton Python model for audio tokenization.

     This model takes reference audio input and extracts semantic tokens
     using s3tokenizer.
     """

     def initialize(self, args):
         """Initialize the model.

         Args:
             args: Dictionary containing model configuration
         """
         # Parse model parameters
         parameters = json.loads(args['model_config'])['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}

         self.device = torch.device("cuda")
         model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
         self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)

     def execute(self, requests):
         """Execute inference on the batched requests.

         Args:
             requests: List of inference requests

         Returns:
             List of inference responses containing tokenized outputs
         """
@@ -79,18 +80,18 @@ class TritonPythonModel:
             # Prepare inputs
             wav = wav_array[:, :wav_len].squeeze(0)
             mels.append(s3tokenizer.log_mel_spectrogram(wav))

         mels, mels_lens = s3tokenizer.padding(mels)
         codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
         codes = codes.clone() + ORIGINAL_VOCAB_SIZE

         responses = []
         for i in range(len(requests)):
             prompt_speech_tokens = codes[i, :codes_lens[i].item()]
             prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
                 "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
             inference_response = pb_utils.InferenceResponse(
                 output_tensors=[prompt_speech_tokens_tensor])
             responses.append(inference_response)

         return responses
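The execute hunk above is only reflowed, but it spells out the tokenizer path end to end: log-mel extraction, batch padding, quantization, and the ORIGINAL_VOCAB_SIZE shift. A minimal offline sketch of the same calls outside Triton; the file paths, soundfile-based loading, and CUDA device are assumptions:

    import s3tokenizer
    import soundfile as sf
    import torch

    ORIGINAL_VOCAB_SIZE = 151663  # same offset as the model above

    device = torch.device("cuda")
    tokenizer = s3tokenizer.load_model("/models/speech_tokenizer_v2.onnx").to(device)  # path is illustrative

    wav, sr = sf.read("prompt.wav", dtype="float32")  # expected to be 16 kHz mono
    mel = s3tokenizer.log_mel_spectrogram(torch.from_numpy(wav))
    mels, mels_lens = s3tokenizer.padding([mel])
    codes, codes_lens = tokenizer.quantize(mels.to(device), mels_lens.to(device))
    prompt_speech_tokens = codes[0, :codes_lens[0].item()] + ORIGINAL_VOCAB_SIZE
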

@@ -42,16 +42,17 @@ import onnxruntime

 from matcha.utils.audio import mel_spectrogram

+
 class TritonPythonModel:
     """Triton Python model for Spark TTS.

     This model orchestrates the end-to-end TTS pipeline by coordinating
     between audio tokenizer, LLM, and vocoder components.
     """

     def initialize(self, args):
         """Initialize the model.

         Args:
             args: Dictionary containing model configuration
         """
@@ -116,58 +117,58 @@ class TritonPythonModel:
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }

         # Convert inputs to Triton tensors
         input_tensor_list = [
             pb_utils.Tensor(k, v) for k, v in input_dict.items()
         ]

         # Create and execute inference request
         llm_request = pb_utils.InferenceRequest(
             model_name="tensorrt_llm",
             requested_output_names=["output_ids", "sequence_length"],
             inputs=input_tensor_list,
         )

         llm_responses = llm_request.exec(decoupled=self.decoupled)
         if self.decoupled:
             for llm_response in llm_responses:
                 if llm_response.has_error():
                     raise pb_utils.TritonModelException(llm_response.error().message())

                 # Extract and process output
                 output_ids = pb_utils.get_output_tensor_by_name(
                     llm_response, "output_ids").as_numpy()
                 seq_lens = pb_utils.get_output_tensor_by_name(
                     llm_response, "sequence_length").as_numpy()

                 # Get actual output IDs up to the sequence length
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]

                 yield actual_output_ids
         else:
             llm_response = llm_responses
             if llm_response.has_error():
                 raise pb_utils.TritonModelException(llm_response.error().message())

             # Extract and process output
             output_ids = pb_utils.get_output_tensor_by_name(
                 llm_response, "output_ids").as_numpy()
             seq_lens = pb_utils.get_output_tensor_by_name(
                 llm_response, "sequence_length").as_numpy()

             # Get actual output IDs up to the sequence length
             actual_output_ids = output_ids[0][0][:seq_lens[0][0]]

             yield actual_output_ids

     def forward_audio_tokenizer(self, wav, wav_len):
         """Forward pass through the audio tokenizer component.

         Args:
             wav: Input waveform tensor
             wav_len: Waveform length tensor

         Returns:
             Tuple of global and semantic tokens
         """
@@ -176,26 +177,31 @@ class TritonPythonModel:
             requested_output_names=['prompt_speech_tokens'],
             inputs=[wav, wav_len]
         )

         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())

         # Extract and convert output tensors
         prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
         prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()

         return prompt_speech_tokens

-    def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
+    def forward_token2wav(
+            self,
+            prompt_speech_tokens: torch.Tensor,
+            prompt_speech_feat: torch.Tensor,
+            prompt_spk_embedding: torch.Tensor,
+            target_speech_tokens: torch.Tensor) -> torch.Tensor:
         """Forward pass through the vocoder component.

         Args:
             prompt_speech_tokens: Prompt speech tokens tensor
             prompt_speech_feat: Prompt speech feat tensor
             prompt_spk_embedding: Prompt spk embedding tensor
             target_speech_tokens: Target speech tokens tensor

         Returns:
             Generated waveform tensor
         """
@@ -203,22 +209,22 @@ class TritonPythonModel:
         prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
             inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
         )

         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())

         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()

         return waveform

     def parse_input(self, text, prompt_text, prompt_speech_tokens):
@@ -231,43 +237,53 @@ class TritonPythonModel:

     def _extract_spk_embedding(self, speech):
         feat = kaldi.fbank(speech,
                            num_mel_bins=80,
                            dither=0,
                            sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
         embedding = self.campplus_session.run(None,
                                               {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device).half()
         return embedding


     def _extract_speech_feat(self, speech):
-        speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+            self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat

     def execute(self, requests):
         """Execute inference on the batched requests.

         Args:
             requests: List of inference requests

         Returns:
             List of inference responses containing generated audio
         """
         responses = []

         for request in requests:
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

             # Process reference audio through audio tokenizer
-
             prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
             prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

-
             wav_tensor = wav.as_numpy()
             wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
@@ -275,20 +291,20 @@ class TritonPythonModel:
             token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')

             target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
             target_text = target_text[0][0].decode('utf-8')

             # Prepare prompt for LLM
             input_ids = self.parse_input(
                 text=target_text,
                 prompt_text=reference_text,
                 prompt_speech_tokens=prompt_speech_tokens,
             )

             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)

@@ -305,13 +321,13 @@ class TritonPythonModel:
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
                 prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                 audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)

                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-                self.logger.log_info(f"send tritonserver_response_complete_final to end")
+                self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
                 generated_ids = next(generated_ids_iter)
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
@@ -320,11 +336,11 @@ class TritonPythonModel:

                 prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                 audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)

                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 responses.append(inference_response)

         if not self.decoupled:
             return responses

@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)

 ORIGINAL_VOCAB_SIZE = 151663

+
 class CosyVoice2:

     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
                                               trt_concurrent,
                                               self.fp16)

+
 class CosyVoice2Model:

     def __init__(self,
@@ -109,16 +111,17 @@ class CosyVoice2Model:
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

+
 class TritonPythonModel:
     """Triton Python model for vocoder.

     This model takes global and semantic tokens as input and generates audio waveforms
     using the BiCodec vocoder.
     """

     def initialize(self, args):
         """Initialize the model.

         Args:
             args: Dictionary containing model configuration
         """
@@ -126,24 +129,23 @@ class TritonPythonModel:
         parameters = json.loads(args['model_config'])['parameters']
         model_params = {key: value["string_value"] for key, value in parameters.items()}
         model_dir = model_params["model_dir"]

         # Initialize device and vocoder
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Initializing vocoder from {model_dir} on {self.device}")

         self.token2wav_model = CosyVoice2(
             model_dir, load_jit=True, load_trt=True, fp16=True
         )

         logger.info("Token2Wav initialized successfully")

-
     def execute(self, requests):
         """Execute inference on the batched requests.

         Args:
             requests: List of inference requests

         Returns:
             List of inference responses containing generated waveforms
         """
@@ -163,7 +165,7 @@ class TritonPythonModel:
         # shift the speech tokens according to the original vocab size
         prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
         target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

         tts_mel, _ = self.token2wav_model.model.flow.inference(
             token=target_speech_tokens,
             token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
@@ -189,9 +191,5 @@ class TritonPythonModel:
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
             inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
             responses.append(inference_response)

         return responses
-
-
-
-

@@ -35,8 +35,7 @@ def parse_arguments():
         type=str,
         default='auto',
         choices=['auto', 'float16', 'bfloat16', 'float32'],
-        help=
-        "The data type for the model weights and activations if not quantized. "
+        help="The data type for the model weights and activations if not quantized. "
         "If 'auto', the data type is automatically inferred from the source model; "
         "however, if the source dtype is float32, it is converted to float16.")
     parser.add_argument(
@@ -49,8 +48,7 @@ def parse_arguments():
         '--disable_weight_only_quant_plugin',
         default=False,
         action="store_true",
-        help=
-        'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
+        help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
@@ -60,16 +58,14 @@ def parse_arguments():
         nargs='?',
         default='int8',
         choices=['int8', 'int4', 'int4_gptq'],
-        help=
-        'Define the precision for the weights when using weight-only quantization.'
+        help='Define the precision for the weights when using weight-only quantization.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
         '--calib_dataset',
         type=str,
         default='ccdv/cnn_dailymail',
-        help=
-        "The huggingface dataset name or the local directory of the dataset for calibration."
+        help="The huggingface dataset name or the local directory of the dataset for calibration."
     )
     parser.add_argument(
         "--smoothquant",
@@ -83,31 +79,27 @@ def parse_arguments():
         '--per_channel',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor for the GEMM\'s result. '
+        help='By default, we use a single static scaling factor for the GEMM\'s result. '
         'per_channel instead uses a different static scaling factor for each channel. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--per_token',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor to scale activations in the int8 range. '
+        help='By default, we use a single static scaling factor to scale activations in the int8 range. '
         'per_token chooses at run time, and for each token, a custom scaling factor. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--int8_kv_cache',
         default=False,
         action="store_true",
-        help=
-        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
+        help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
     )
     parser.add_argument(
         '--per_group',
         default=False,
         action="store_true",
-        help=
-        'By default, we use a single static scaling factor to scale weights in the int4 range. '
+        help='By default, we use a single static scaling factor to scale weights in the int4 range. '
         'per_group chooses at run time, and for each group, a custom scaling factor. '
         'The flag is built for GPTQ/AWQ quantization.')

@@ -121,16 +113,14 @@ def parse_arguments():
         '--use_parallel_embedding',
         action="store_true",
         default=False,
-        help=
-        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+        help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
     )
     parser.add_argument(
         '--embedding_sharding_dim',
         type=int,
         default=0,
         choices=[0, 1],
-        help=
-        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+        help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
         'To shard it along hidden dimension, set embedding_sharding_dim=1'
         'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
     )
@@ -147,15 +137,13 @@ def parse_arguments():
         '--moe_tp_size',
         type=int,
         default=-1,
-        help=
-        'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
+        help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
     )
     parser.add_argument(
         '--moe_ep_size',
         type=int,
         default=-1,
-        help=
-        'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
+        help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
     )
     args = parser.parse_args()
     return args
@@ -249,7 +237,7 @@ def convert_and_save_hf(args):
             trust_remote_code=True)
         quant_config, override_fields = update_quant_config_from_hf(
             quant_config, hf_config, override_fields)
-    except:
+    except BaseException:
         logger.warning("AutoConfig cannot load the huggingface config.")

     if args.smoothquant is not None or args.int8_kv_cache:
@@ -339,4 +327,4 @@ def main():


 if __name__ == '__main__':
     main()

@@ -1,4 +1,4 @@
-#! /usr/bin/env python3
+# /usr/bin/env python3
 from argparse import ArgumentParser
 from string import Template

@@ -59,8 +59,7 @@ if __name__ == "__main__":
     parser.add_argument("file_path", help="path of the .pbtxt to modify")
     parser.add_argument(
         "substitutions",
-        help=
-        "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
+        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
     )
     parser.add_argument("--in_place",
                         "-i",

@@ -46,7 +46,6 @@ def parse_arguments(args=None):
     parser.add_argument('--top_k', type=int, default=50)
     parser.add_argument('--top_p', type=float, default=0.95)

-
     return parser.parse_args(args=args)


@@ -60,7 +59,7 @@ def parse_input(tokenizer,
         input_ids = tokenizer.encode(
             curr_text)
         batch_input_ids.append(input_ids)

     batch_input_ids = [
         torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
     ]