mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -29,6 +29,7 @@ import json
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@@ -67,9 +68,10 @@ def get_args():
|
||||
type=str,
|
||||
default="spark_tts",
|
||||
choices=[
|
||||
"f5_tts", "spark_tts", "cosyvoice2"
|
||||
],
|
||||
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
||||
"f5_tts",
|
||||
"spark_tts",
|
||||
"cosyvoice2"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -80,6 +82,7 @@ def get_args():
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prepare_request(
|
||||
waveform,
|
||||
reference_text,
|
||||
@@ -97,7 +100,7 @@ def prepare_request(
|
||||
1,
|
||||
padding_duration
|
||||
* sample_rate
|
||||
* ((int(duration) // padding_duration) + 1),
|
||||
* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
@@ -105,11 +108,11 @@ def prepare_request(
|
||||
samples[0, : len(waveform)] = waveform
|
||||
else:
|
||||
samples = waveform
|
||||
|
||||
|
||||
samples = samples.reshape(1, -1).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"inputs":[
|
||||
"inputs": [
|
||||
{
|
||||
"name": "reference_wav",
|
||||
"shape": samples.shape,
|
||||
@@ -139,16 +142,17 @@ def prepare_request(
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
server_url = args.server_url
|
||||
if not server_url.startswith(("http://", "https://")):
|
||||
server_url = f"http://{server_url}"
|
||||
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
waveform, sr = sf.read(args.reference_audio)
|
||||
assert sr == 16000, "sample rate hardcoded in server"
|
||||
|
||||
|
||||
samples = np.array(waveform, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
|
||||
@@ -166,4 +170,4 @@ if __name__ == "__main__":
|
||||
sample_rate = 16000
|
||||
else:
|
||||
sample_rate = 24000
|
||||
sf.write(args.output_audio, audio, sample_rate, "PCM_16")
|
||||
sf.write(args.output_audio, audio, sample_rate, "PCM_16")
|
||||
|
||||
Reference in New Issue
Block a user