mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add prompt audio cache
This commit is contained in:
@@ -257,7 +257,13 @@ def get_args():
|
||||
default=0.1,
|
||||
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
||||
)
|
||||
# --- End Added arguments ---
|
||||
|
||||
parser.add_argument(
|
||||
"--use-spk2info-cache",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Use spk2info cache for reference audio.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -283,7 +289,8 @@ def prepare_request_input_output(
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
padding_duration: int = None # Optional padding for offline mode
|
||||
padding_duration: int = None, # Optional padding for offline mode
|
||||
use_spk2info_cache: bool = False
|
||||
):
|
||||
"""Prepares inputs for Triton inference (offline or streaming)."""
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
@@ -330,7 +337,8 @@ def prepare_request_input_output(
|
||||
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||
|
||||
if use_spk2info_cache:
|
||||
inputs = inputs[-1:]
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
@@ -453,6 +461,7 @@ async def send_streaming(
|
||||
save_sample_rate: int = 16000,
|
||||
chunk_overlap_duration: float = 0.1,
|
||||
padding_duration: int = None,
|
||||
use_spk2info_cache: bool = False,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -478,7 +487,8 @@ async def send_streaming(
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration
|
||||
padding_duration=padding_duration,
|
||||
use_spk2info_cache=use_spk2info_cache
|
||||
)
|
||||
request_id = str(uuid.uuid4())
|
||||
user_data = UserData()
|
||||
@@ -534,6 +544,7 @@ async def send(
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
use_spk2info_cache: bool = False,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -552,7 +563,8 @@ async def send(
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration
|
||||
padding_duration=padding_duration,
|
||||
use_spk2info_cache=use_spk2info_cache
|
||||
)
|
||||
sequence_id = 100000000 + i + task_id * 10
|
||||
start = time.time()
|
||||
@@ -691,6 +703,7 @@ async def main():
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
use_spk2info_cache=args.use_spk2info_cache,
|
||||
)
|
||||
)
|
||||
elif args.mode == "streaming":
|
||||
@@ -706,6 +719,7 @@ async def main():
|
||||
padding_duration=10,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
chunk_overlap_duration=args.chunk_overlap_duration,
|
||||
use_spk2info_cache=args.use_spk2info_cache,
|
||||
)
|
||||
)
|
||||
# --- End Task Creation ---
|
||||
|
||||
Reference in New Issue
Block a user