mirror of https://github.com/HumanAIGC/lite-avatar.git
synced 2026-02-05 18:09:20 +08:00
add files
191 funasr_local/runtime/triton_gpu/client/client.py Normal file
@@ -0,0 +1,191 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
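
# Triton gRPC client for the ASR models served by this runtime. It recognizes
# either a single wav file (--audio_file) or a wav.scp/trans list pair
# (--wavscp / --trans), and reports CER when transcripts are given. A typical
# invocation (paths are placeholders, not part of this commit) might be:
#
#   python3 client.py --url localhost:10086 --model_name infer_pipeline \
#       --data_dir /path/to/wavs --wavscp wav.scp --trans trans.txt
#
# Pass --streaming together with --model_name streaming_wenet to exercise the
# streaming client.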

import multiprocessing
from multiprocessing import Pool

import argparse
import os
import tritonclient.grpc as grpcclient
from utils import cal_cer
from speech_client import *
import numpy as np

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="localhost:10086",
        help="Inference server URL. Default is localhost:10086.",
    )
    parser.add_argument(
        "--model_name",
        required=False,
        default="attention_rescoring",
        choices=["attention_rescoring", "streaming_wenet", "infer_pipeline"],
        help="the model to send request to",
    )
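
    # The --wavscp and --trans files below are plain text with one entry per
    # line, tab-separated as "audio_id<TAB>wav_path" and "audio_id<TAB>text";
    # IDs in the two files are expected to match (see the loading code further
    # down). Example line (hypothetical): utt_0001<TAB>wavs/utt_0001.wav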
    parser.add_argument(
        "--wavscp",
        type=str,
        required=False,
        default=None,
        help="audio_id \t wav_path",
    )
    parser.add_argument(
        "--trans",
        type=str,
        required=False,
        default=None,
        help="audio_id \t text",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=False,
        default=None,
        help="path prefix for wav_path in wavscp/audio_file",
    )
    parser.add_argument(
        "--audio_file",
        type=str,
        required=False,
        default=None,
        help="single wav file path",
    )
    # below arguments are for streaming
    # Please check onnx_config.yaml and train.yaml
    parser.add_argument("--streaming", action="store_true", required=False)
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=False,
        default=16000,
        help="sample rate used in training",
    )
    parser.add_argument(
        "--frame_length_ms",
        type=int,
        required=False,
        default=25,
        help="frame length",
    )
    parser.add_argument(
        "--frame_shift_ms",
        type=int,
        required=False,
        default=10,
        help="frame shift length",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        required=False,
        default=16,
        help="chunk size default is 16",
    )
    parser.add_argument(
        "--context",
        type=int,
        required=False,
        default=7,
        help="subsampling context",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        required=False,
        default=4,
        help="subsampling rate",
    )
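
    # With the defaults above (chunk_size=16, subsampling=4, frame_shift_ms=10)
    # each streaming chunk covers roughly 16 * 4 * 10 ms = 640 ms of audio. The
    # exact chunking is implemented in speech_client.py, so treat this figure
    # as an illustration rather than a guarantee.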

    FLAGS = parser.parse_args()
    print(FLAGS)

    # load data
    filenames = []
    transcripts = []
    if FLAGS.audio_file is not None:
        path = FLAGS.audio_file
        if FLAGS.data_dir:
            path = os.path.join(FLAGS.data_dir, path)
        if os.path.exists(path):
            filenames = [path]
    elif FLAGS.wavscp is not None:
        audio_data = {}
        with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
            for line in f:
                aid, path = line.strip().split("\t")
                if FLAGS.data_dir:
                    path = os.path.join(FLAGS.data_dir, path)
                audio_data[aid] = {"path": path}
        with open(FLAGS.trans, "r", encoding="utf-8") as f:
            for line in f:
                aid, text = line.strip().split("\t")
                audio_data[aid]["text"] = text
        for key, value in audio_data.items():
            filenames.append(value["path"])
            transcripts.append(value["text"])

    num_workers = multiprocessing.cpu_count() // 2

    if FLAGS.streaming:
        speech_client_cls = StreamingSpeechClient
    else:
        speech_client_cls = OfflineSpeechClient

    def single_job(client_files):
        with grpcclient.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose
        ) as triton_client:
            protocol_client = grpcclient
            speech_client = speech_client_cls(
                triton_client, FLAGS.model_name, protocol_client, FLAGS
            )
            idx, audio_files = client_files
            predictions = []
            for li in audio_files:
                result = speech_client.recognize(li, idx)
                print("Recognized {}:{}".format(li, result[0]))
                predictions += result
        return predictions
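
    # Each worker below opens its own gRPC connection to the Triton server and
    # recognizes its share of the file list, so throughput should scale roughly
    # with num_workers.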
    # start to do inference
    # Group requests in batches
    predictions = []
    tasks = []
    splits = np.array_split(filenames, num_workers)

    for idx, per_split in enumerate(splits):
        cur_files = per_split.tolist()
        tasks.append((idx, cur_files))

    with Pool(processes=num_workers) as pool:
        predictions = pool.map(single_job, tasks)

    predictions = [item for sublist in predictions for item in sublist]
    if transcripts:
        cer = cal_cer(predictions, transcripts)
        print("CER is: {}".format(cer))