add prompt audio cache

This commit is contained in:
yuekaiz
2025-09-05 13:54:39 +08:00
parent 86e7c2d731
commit 6971536358
7 changed files with 112 additions and 53 deletions

View File

@@ -257,7 +257,13 @@ def get_args():
default=0.1,
help="Chunk overlap duration for streaming reconstruction (in seconds)."
)
# --- End Added arguments ---
# argparse hands the raw command-line string to `type`. bool() is True for any
# non-empty string, so "--use-spk2info-cache False" used to ENABLE the cache.
# Parse the text explicitly instead: only the recognised negative spellings
# (and the empty string) disable the option; every other value keeps enabling
# it exactly as before, so existing "--use-spk2info-cache True" calls still work.
parser.add_argument(
    "--use-spk2info-cache",
    type=lambda value: str(value).strip().lower()
    not in ("", "0", "false", "f", "no", "n", "off"),
    default=False,
    help="Use spk2info cache for reference audio.",
)
return parser.parse_args()
@@ -283,7 +289,8 @@ def prepare_request_input_output(
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None # Optional padding for offline mode
padding_duration: int = None, # Optional padding for offline mode
use_spk2info_cache: bool = False
):
"""Prepares inputs for Triton inference (offline or streaming)."""
assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -330,7 +337,8 @@ def prepare_request_input_output(
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
if use_spk2info_cache:
inputs = inputs[-1:]
return inputs, outputs
@@ -453,6 +461,7 @@ async def send_streaming(
save_sample_rate: int = 16000,
chunk_overlap_duration: float = 0.1,
padding_duration: int = None,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
@@ -478,7 +487,8 @@ async def send_streaming(
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
request_id = str(uuid.uuid4())
user_data = UserData()
@@ -534,6 +544,7 @@ async def send(
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
@@ -552,7 +563,8 @@ async def send(
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
sequence_id = 100000000 + i + task_id * 10
start = time.time()
@@ -691,6 +703,7 @@ async def main():
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
use_spk2info_cache=args.use_spk2info_cache,
)
)
elif args.mode == "streaming":
@@ -706,6 +719,7 @@ async def main():
padding_duration=10,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
chunk_overlap_duration=args.chunk_overlap_duration,
use_spk2info_cache=args.use_spk2info_cache,
)
)
# --- End Task Creation ---