This commit is contained in:
root
2025-07-29 08:39:41 +00:00
parent 1b8d194b67
commit 07cbc51cd1
8 changed files with 165 additions and 157 deletions

View File

@@ -29,6 +29,7 @@ import json
import numpy as np
import argparse
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
type=str,
default="spark_tts",
choices=[
"f5_tts", "spark_tts", "cosyvoice2"
],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
"f5_tts",
"spark_tts",
"cosyvoice2"],
help="triton model_repo module name to request",
)
parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
)
return parser.parse_args()
def prepare_request(
waveform,
reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
1,
padding_duration
* sample_rate
* ((int(duration) // padding_duration) + 1),
* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
),
dtype=np.float32,
)
@@ -105,11 +108,11 @@ def prepare_request(
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs":[
"inputs": [
{
"name": "reference_wav",
"shape": samples.shape,
@@ -139,16 +142,17 @@ def prepare_request(
return data
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
waveform, sr = sf.read(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(waveform, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
@@ -166,4 +170,4 @@ if __name__ == "__main__":
sample_rate = 16000
else:
sample_rate = 24000
sf.write(args.output_audio, audio, sample_rate, "PCM_16")
sf.write(args.output_audio, audio, sample_rate, "PCM_16")