add files

commit edd008441b (parent a21dd4555c)
Author: 烨玮
Date: 2025-02-20 12:17:03 +08:00

667 changed files with 473123 additions and 0 deletions



@@ -0,0 +1,191 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from multiprocessing import Pool
import argparse
import os
import tritonclient.grpc as grpcclient
from utils import cal_cer
from speech_client import *
import numpy as np
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-v",
"--verbose",
action="store_true",
required=False,
default=False,
help="Enable verbose output",
)
parser.add_argument(
"-u",
"--url",
type=str,
required=False,
default="localhost:10086",
help="Inference server URL. Default is " "localhost:8001.",
)
parser.add_argument(
"--model_name",
required=False,
default="attention_rescoring",
choices=["attention_rescoring", "streaming_wenet", "infer_pipeline"],
help="the model to send request to",
)
parser.add_argument(
"--wavscp",
type=str,
required=False,
default=None,
help="audio_id \t wav_path",
)
parser.add_argument(
"--trans",
type=str,
required=False,
default=None,
help="audio_id \t text",
)
parser.add_argument(
"--data_dir",
type=str,
required=False,
default=None,
help="path prefix for wav_path in wavscp/audio_file",
)
parser.add_argument(
"--audio_file",
type=str,
required=False,
default=None,
help="single wav file path",
)
# below arguments are for streaming
# Please check onnx_config.yaml and train.yaml
parser.add_argument("--streaming", action="store_true", required=False)
parser.add_argument(
"--sample_rate",
type=int,
required=False,
default=16000,
help="sample rate used in training",
)
parser.add_argument(
"--frame_length_ms",
type=int,
required=False,
default=25,
help="frame length",
)
parser.add_argument(
"--frame_shift_ms",
type=int,
required=False,
default=10,
help="frame shift length",
)
parser.add_argument(
"--chunk_size",
type=int,
required=False,
default=16,
help="chunk size default is 16",
)
parser.add_argument(
"--context",
type=int,
required=False,
default=7,
help="subsampling context",
)
parser.add_argument(
"--subsampling",
type=int,
required=False,
default=4,
help="subsampling rate",
)
FLAGS = parser.parse_args()
print(FLAGS)
# load data
filenames = []
transcripts = []
if FLAGS.audio_file is not None:
path = FLAGS.audio_file
if FLAGS.data_dir:
path = os.path.join(FLAGS.data_dir, path)
if os.path.exists(path):
filenames = [path]
elif FLAGS.wavscp is not None:
audio_data = {}
with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
for line in f:
aid, path = line.strip().split("\t")
if FLAGS.data_dir:
path = os.path.join(FLAGS.data_dir, path)
audio_data[aid] = {"path": path}
with open(FLAGS.trans, "r", encoding="utf-8") as f:
for line in f:
aid, text = line.strip().split("\t")
audio_data[aid]["text"] = text
for key, value in audio_data.items():
filenames.append(value["path"])
transcripts.append(value["text"])
num_workers = multiprocessing.cpu_count() // 2
if FLAGS.streaming:
speech_client_cls = StreamingSpeechClient
else:
speech_client_cls = OfflineSpeechClient
def single_job(client_files):
with grpcclient.InferenceServerClient(
url=FLAGS.url, verbose=FLAGS.verbose
) as triton_client:
protocol_client = grpcclient
speech_client = speech_client_cls(
triton_client, FLAGS.model_name, protocol_client, FLAGS
)
idx, audio_files = client_files
predictions = []
for li in audio_files:
result = speech_client.recognize(li, idx)
print("Recognized {}:{}".format(li, result[0]))
predictions += result
return predictions
    # start inference: split the file list evenly across the worker
    # processes, one split per job
predictions = []
tasks = []
splits = np.array_split(filenames, num_workers)
for idx, per_split in enumerate(splits):
cur_files = per_split.tolist()
tasks.append((idx, cur_files))
with Pool(processes=num_workers) as pool:
predictions = pool.map(single_job, tasks)
predictions = [item for sublist in predictions for item in sublist]
if transcripts:
        cer = cal_cer(transcripts, predictions)  # cal_cer expects (references, predictions)
print("CER is: {}".format(cer))


@@ -0,0 +1,541 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a manifest in lhotse format and sends it to the server
for decoding, in parallel.
Usage:
# For offline wenet server
./decode_manifest_triton.py \
--server-addr localhost \
--compute-cer \
--model-name attention_rescoring \
--num-tasks 300 \
--manifest-filename ./aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz # noqa
# For streaming wenet server
./decode_manifest_triton.py \
--server-addr localhost \
--streaming \
--compute-cer \
--context 7 \
--model-name streaming_wenet \
--num-tasks 300 \
--manifest-filename ./aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz # noqa
# For simulated-streaming wenet server
./decode_manifest_triton.py \
--server-addr localhost \
--simulate-streaming \
--compute-cer \
--context 7 \
--model-name streaming_wenet \
--num-tasks 300 \
--manifest-filename ./aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz # noqa
# For test container:
docker run -it --rm --name "wenet_client_test" --net host --gpus all soar97/triton-k2:22.12.1 # noqa
# For aishell manifests:
apt-get install git-lfs
git-lfs install
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
sudo mkdir -p /root/fangjun/open-source/icefall-aishell/egs/aishell/ASR/download/aishell
tar xf ./aishell-test-dev-manifests/data_aishell.tar.gz -C /root/fangjun/open-source/icefall-aishell/egs/aishell/ASR/download/aishell/ # noqa
"""
import argparse
import asyncio
import math
import time
import types
from pathlib import Path
import json
import numpy as np
import tritonclient
import tritonclient.grpc.aio as grpcclient
from lhotse import CutSet, load_manifest
from tritonclient.utils import np_to_triton_dtype
from icefall.utils import store_transcripts, write_error_stats
DEFAULT_MANIFEST_FILENAME = "/mnt/samsung-t7/yuekai/aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz" # noqa
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Port of the server",
)
parser.add_argument(
"--manifest-filename",
type=str,
default=DEFAULT_MANIFEST_FILENAME,
help="Path to the manifest for decoding",
)
parser.add_argument(
"--model-name",
type=str,
default="transducer",
help="triton model_repo module name to request",
)
parser.add_argument(
"--num-tasks",
type=int,
default=50,
help="Number of tasks to use for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-cer",
action="store_true",
default=False,
help="""True to compute CER, e.g., for Chinese.
False to compute WER, e.g., for English words.
""",
)
parser.add_argument(
"--streaming",
action="store_true",
default=False,
help="""True for streaming ASR.
""",
)
parser.add_argument(
"--simulate-streaming",
action="store_true",
default=False,
help="""True for strictly simulate streaming ASR.
Threads will sleep to simulate the real speaking scene.
""",
)
parser.add_argument(
"--chunk_size",
type=int,
required=False,
default=16,
help="chunk size default is 16",
)
parser.add_argument(
"--context",
type=int,
required=False,
default=-1,
help="subsampling context for wenet",
)
parser.add_argument(
"--encoder_right_context",
type=int,
required=False,
default=2,
help="encoder right context",
)
parser.add_argument(
"--subsampling",
type=int,
required=False,
default=4,
help="subsampling rate",
)
parser.add_argument(
"--stats_file",
type=str,
required=False,
default="./stats.json",
help="output of stats anaylasis",
)
return parser.parse_args()
async def send(
cuts: CutSet,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
compute_cer: bool,
model_name: str,
):
total_duration = 0.0
results = []
for i, c in enumerate(cuts):
if i % log_interval == 0:
print(f"{name}: {i}/{len(cuts)}")
waveform = c.load_audio().reshape(-1).astype(np.float32)
sample_rate = 16000
        # pad to the next multiple of 10 seconds (e.g., a 23 s clip becomes 30 s)
samples = np.zeros(
(
1,
10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
lengths = np.array([[len(waveform)]], dtype=np.int32)
inputs = [
protocol_client.InferInput(
"WAV", samples.shape, np_to_triton_dtype(samples.dtype)
),
protocol_client.InferInput(
"WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
sequence_id = 10086 + i
response = await triton_client.infer(
model_name, inputs, request_id=str(sequence_id), outputs=outputs
)
decoding_results = response.as_numpy("TRANSCRIPTS")[0]
        if isinstance(decoding_results, np.ndarray):
decoding_results = b" ".join(decoding_results).decode("utf-8")
else:
# For wenet
decoding_results = decoding_results.decode("utf-8")
total_duration += c.duration
if compute_cer:
ref = c.supervisions[0].text.split()
hyp = decoding_results.split()
ref = list("".join(ref))
hyp = list("".join(hyp))
results.append((c.id, ref, hyp))
else:
results.append(
(
c.id,
c.supervisions[0].text.split(),
decoding_results.split(),
)
) # noqa
return total_duration, results
async def send_streaming(
cuts: CutSet,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
compute_cer: bool,
model_name: str,
first_chunk_in_secs: float,
other_chunk_in_secs: float,
task_index: int,
simulate_mode: bool = False,
):
total_duration = 0.0
results = []
latency_data = []
for i, c in enumerate(cuts):
if i % log_interval == 0:
print(f"{name}: {i}/{len(cuts)}")
waveform = c.load_audio().reshape(-1).astype(np.float32)
sample_rate = 16000
wav_segs = []
j = 0
while j < len(waveform):
if j == 0:
stride = int(first_chunk_in_secs * sample_rate)
wav_segs.append(waveform[j : j + stride])
else:
stride = int(other_chunk_in_secs * sample_rate)
wav_segs.append(waveform[j : j + stride])
j += len(wav_segs[-1])
sequence_id = task_index + 10086
for idx, seg in enumerate(wav_segs):
chunk_len = len(seg)
if simulate_mode:
await asyncio.sleep(chunk_len / sample_rate)
chunk_start = time.time()
if idx == 0:
chunk_samples = int(first_chunk_in_secs * sample_rate)
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
else:
chunk_samples = int(other_chunk_in_secs * sample_rate)
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
expect_input[0][0:chunk_len] = seg
input0_data = expect_input
input1_data = np.array([[chunk_len]], dtype=np.int32)
inputs = [
protocol_client.InferInput(
"WAV",
input0_data.shape,
np_to_triton_dtype(input0_data.dtype),
),
protocol_client.InferInput(
"WAV_LENS",
input1_data.shape,
np_to_triton_dtype(input1_data.dtype),
),
]
inputs[0].set_data_from_numpy(input0_data)
inputs[1].set_data_from_numpy(input1_data)
outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
end = False
if idx == len(wav_segs) - 1:
end = True
response = await triton_client.infer(
model_name,
inputs,
outputs=outputs,
sequence_id=sequence_id,
sequence_start=idx == 0,
sequence_end=end,
)
decoding_results = response.as_numpy("TRANSCRIPTS")
            if isinstance(decoding_results, np.ndarray):
decoding_results = b" ".join(decoding_results).decode("utf-8")
else:
# For wenet
decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
"utf-8"
)
chunk_end = time.time() - chunk_start
latency_data.append((chunk_end, chunk_len / sample_rate))
total_duration += c.duration
if compute_cer:
ref = c.supervisions[0].text.split()
hyp = decoding_results.split()
ref = list("".join(ref))
hyp = list("".join(hyp))
results.append((c.id, ref, hyp))
else:
results.append(
(
c.id,
c.supervisions[0].text.split(),
decoding_results.split(),
)
) # noqa
return total_duration, results, latency_data
async def main():
args = get_args()
filename = args.manifest_filename
server_addr = args.server_addr
server_port = args.server_port
url = f"{server_addr}:{server_port}"
num_tasks = args.num_tasks
log_interval = args.log_interval
compute_cer = args.compute_cer
cuts = load_manifest(filename)
cuts_list = cuts.split(num_tasks)
tasks = []
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient
if args.streaming or args.simulate_streaming:
frame_shift_ms = 10
frame_length_ms = 25
add_frames = math.ceil(
(frame_length_ms - frame_shift_ms) / frame_shift_ms
)
# decode_window_length: input sequence length of streaming encoder
if args.context > 0:
# decode window length calculation for wenet
decode_window_length = (
args.chunk_size - 1
) * args.subsampling + args.context
else:
# decode window length calculation for icefall
decode_window_length = (
args.chunk_size + 2 + args.encoder_right_context
) * args.subsampling + 3
first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms
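        # Worked example (hypothetical wenet settings: chunk_size=16,
        # subsampling=4, context=7): decode_window_length = 15 * 4 + 7 = 67,
        # add_frames = ceil((25 - 10) / 10) = 2, so
        # first_chunk_ms = (67 + 2) * 10 = 690 ms of audio.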
start_time = time.time()
for i in range(num_tasks):
if args.streaming:
assert not args.simulate_streaming
task = asyncio.create_task(
send_streaming(
cuts=cuts_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=log_interval,
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
other_chunk_in_secs=args.chunk_size
* args.subsampling
* frame_shift_ms
/ 1000,
task_index=i,
)
)
elif args.simulate_streaming:
task = asyncio.create_task(
send_streaming(
cuts=cuts_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=log_interval,
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
other_chunk_in_secs=args.chunk_size
* args.subsampling
* frame_shift_ms
/ 1000,
task_index=i,
simulate_mode=True,
)
)
else:
task = asyncio.create_task(
send(
cuts=cuts_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=log_interval,
compute_cer=compute_cer,
model_name=args.model_name,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
results = []
total_duration = 0.0
latency_data = []
for ans in ans_list:
total_duration += ans[0]
results += ans[1]
if args.streaming or args.simulate_streaming:
latency_data += ans[2]
rtf = elapsed / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration/3600:.2f} hours)\n"
s += (
f"processing time: {elapsed:.3f} seconds "
f"({elapsed/3600:.2f} hours)\n"
)
if args.streaming or args.simulate_streaming:
latency_list = [
chunk_end for (chunk_end, chunk_duration) in latency_data
]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_99_percentile: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
print(s)
with open("rtf.txt", "w") as f:
f.write(s)
name = Path(filename).stem.split(".")[0]
results = sorted(results)
store_transcripts(filename=f"recogs-{name}.txt", texts=results)
with open(f"errs-{name}.txt", "w") as f:
write_error_stats(f, "test-set", results, enable_log=True)
with open(f"errs-{name}.txt", "r") as f:
print(f.readline()) # WER
print(f.readline()) # Detailed errors
if args.stats_file:
stats = await triton_client.get_inference_statistics(
model_name="", as_json=True
)
with open(args.stats_file, "w") as f:
json.dump(stats, f)
if __name__ == "__main__":
asyncio.run(main())
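
For clarity, the RTF reported above is wall-clock decoding time divided by the total duration of the decoded audio; a sketch with made-up numbers:

# hypothetical figures, for illustration only
elapsed = 120.0          # seconds of wall-clock time
total_duration = 3600.0  # seconds of audio decoded
print(f"RTF: {elapsed / total_duration:.4f}")  # 0.0333 -> ~30x real time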


@@ -0,0 +1,561 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a manifest in nemo format and sends it to the server
for decoding, in parallel.
{"audio_filepath": "/path/to/1.wav", "text": "transcript one", "duration": 1.23}
{"audio_filepath": "/path/to/2.wav", "text": "transcript two", "duration": 4.56}
Usage:
# For aishell manifests:
apt-get install git-lfs
git-lfs install
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
sudo mkdir -p ./aishell-test-dev-manifests/aishell
tar xf ./aishell-test-dev-manifests/data_aishell.tar.gz -C ./aishell-test-dev-manifests/aishell # noqa
# cmd run
manifest_path='./client/aishell_test.txt'
serveraddr=localhost
num_task=60
python3 client/decode_manifest_triton_wo_cuts.py \
--server-addr $serveraddr \
--compute-cer \
--model-name infer_pipeline \
--num-tasks $num_task \
--manifest-filename $manifest_path \
"""
from pydub import AudioSegment
import argparse
import asyncio
import math
import time
import types
from pathlib import Path
import json
import os
import numpy as np
import tritonclient
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
from icefall.utils import store_transcripts, write_error_stats
DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt" # noqa
DEFAULT_ROOT = './'  # path prefix for audio_filepath entries; adjust to the data location
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Port of the server",
)
parser.add_argument(
"--manifest-filename",
type=str,
default=DEFAULT_MANIFEST_FILENAME,
help="Path to the manifest for decoding",
)
parser.add_argument(
"--model-name",
type=str,
default="transducer",
help="triton model_repo module name to request",
)
parser.add_argument(
"--num-tasks",
type=int,
default=50,
help="Number of tasks to use for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-cer",
action="store_true",
default=False,
help="""True to compute CER, e.g., for Chinese.
False to compute WER, e.g., for English words.
""",
)
parser.add_argument(
"--streaming",
action="store_true",
default=False,
help="""True for streaming ASR.
""",
)
parser.add_argument(
"--simulate-streaming",
action="store_true",
default=False,
help="""True for strictly simulate streaming ASR.
Threads will sleep to simulate the real speaking scene.
""",
)
parser.add_argument(
"--chunk_size",
type=int,
required=False,
default=16,
help="chunk size default is 16",
)
parser.add_argument(
"--context",
type=int,
required=False,
default=-1,
help="subsampling context for wenet",
)
parser.add_argument(
"--encoder_right_context",
type=int,
required=False,
default=2,
help="encoder right context",
)
parser.add_argument(
"--subsampling",
type=int,
required=False,
default=4,
help="subsampling rate",
)
parser.add_argument(
"--stats_file",
type=str,
required=False,
default="./stats.json",
help="output of stats anaylasis",
)
return parser.parse_args()
def load_manifest(fp):
data = []
with open(fp) as f:
for i, dp in enumerate(f.readlines()):
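            # eval() tolerates Python-dict style lines (single quotes);
            # for strict JSON manifests, json.loads(dp) would be safer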
dp = eval(dp)
dp['id'] = i
data.append(dp)
return data
def split_dps(dps, num_tasks):
dps_splited = []
    # split into num_tasks roughly equal chunks; leftover items are
    # distributed one each to the earliest chunks
assert len(dps) > num_tasks
one_task_num = len(dps)//num_tasks
for i in range(0, len(dps), one_task_num):
if i+one_task_num >= len(dps):
for k, j in enumerate(range(i, len(dps))):
dps_splited[k].append(dps[j])
else:
dps_splited.append(dps[i:i+one_task_num])
return dps_splited
def load_audio(path):
    # resample to 16 kHz mono; pydub returns 16-bit ints, so scale to [-1, 1)
    audio = AudioSegment.from_wav(path).set_frame_rate(16000).set_channels(1)
    audio_np = np.array(audio.get_array_of_samples()) / 32768.0
    return audio_np.astype(np.float32), audio.duration_seconds
async def send(
dps: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
compute_cer: bool,
model_name: str,
):
total_duration = 0.0
results = []
for i, dp in enumerate(dps):
if i % log_interval == 0:
print(f"{name}: {i}/{len(dps)}")
waveform, duration = load_audio(
os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
sample_rate = 16000
        # pad to the next multiple of 10 seconds (e.g., a 23 s clip becomes 30 s)
samples = np.zeros(
(
1,
10 * sample_rate *
(int(len(waveform) / sample_rate // 10) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
lengths = np.array([[len(waveform)]], dtype=np.int32)
inputs = [
protocol_client.InferInput(
"WAV", samples.shape, np_to_triton_dtype(samples.dtype)
),
protocol_client.InferInput(
"WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
sequence_id = 10086 + i
response = await triton_client.infer(
model_name, inputs, request_id=str(sequence_id), outputs=outputs
)
decoding_results = response.as_numpy("TRANSCRIPTS")[0]
        if isinstance(decoding_results, np.ndarray):
decoding_results = b" ".join(decoding_results).decode("utf-8")
else:
# For wenet
decoding_results = decoding_results.decode("utf-8")
total_duration += duration
if compute_cer:
ref = dp['text'].split()
hyp = decoding_results.split()
ref = list("".join(ref))
hyp = list("".join(hyp))
results.append((dp['id'], ref, hyp))
else:
results.append(
(
dp['id'],
dp['text'].split(),
decoding_results.split(),
)
) # noqa
return total_duration, results
async def send_streaming(
dps: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
compute_cer: bool,
model_name: str,
first_chunk_in_secs: float,
other_chunk_in_secs: float,
task_index: int,
simulate_mode: bool = False,
):
total_duration = 0.0
results = []
latency_data = []
for i, dp in enumerate(dps):
if i % log_interval == 0:
print(f"{name}: {i}/{len(dps)}")
        waveform, duration = load_audio(
            os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
sample_rate = 16000
wav_segs = []
j = 0
while j < len(waveform):
if j == 0:
stride = int(first_chunk_in_secs * sample_rate)
wav_segs.append(waveform[j: j + stride])
else:
stride = int(other_chunk_in_secs * sample_rate)
wav_segs.append(waveform[j: j + stride])
j += len(wav_segs[-1])
sequence_id = task_index + 10086
for idx, seg in enumerate(wav_segs):
chunk_len = len(seg)
if simulate_mode:
await asyncio.sleep(chunk_len / sample_rate)
chunk_start = time.time()
if idx == 0:
chunk_samples = int(first_chunk_in_secs * sample_rate)
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
else:
chunk_samples = int(other_chunk_in_secs * sample_rate)
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
expect_input[0][0:chunk_len] = seg
input0_data = expect_input
input1_data = np.array([[chunk_len]], dtype=np.int32)
inputs = [
protocol_client.InferInput(
"WAV",
input0_data.shape,
np_to_triton_dtype(input0_data.dtype),
),
protocol_client.InferInput(
"WAV_LENS",
input1_data.shape,
np_to_triton_dtype(input1_data.dtype),
),
]
inputs[0].set_data_from_numpy(input0_data)
inputs[1].set_data_from_numpy(input1_data)
outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
end = False
if idx == len(wav_segs) - 1:
end = True
response = await triton_client.infer(
model_name,
inputs,
outputs=outputs,
sequence_id=sequence_id,
sequence_start=idx == 0,
sequence_end=end,
)
decoding_results = response.as_numpy("TRANSCRIPTS")
            if isinstance(decoding_results, np.ndarray):
decoding_results = b" ".join(decoding_results).decode("utf-8")
else:
# For wenet
decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
"utf-8"
)
chunk_end = time.time() - chunk_start
latency_data.append((chunk_end, chunk_len / sample_rate))
total_duration += duration
if compute_cer:
ref = dp['text'].split()
hyp = decoding_results.split()
ref = list("".join(ref))
hyp = list("".join(hyp))
results.append((dp['id'], ref, hyp))
else:
results.append(
(
dp['id'],
dp['text'].split(),
decoding_results.split(),
)
) # noqa
return total_duration, results, latency_data
async def main():
args = get_args()
filename = args.manifest_filename
server_addr = args.server_addr
server_port = args.server_port
url = f"{server_addr}:{server_port}"
num_tasks = args.num_tasks
log_interval = args.log_interval
compute_cer = args.compute_cer
dps = load_manifest(filename)
dps_list = split_dps(dps, num_tasks)
tasks = []
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient
if args.streaming or args.simulate_streaming:
frame_shift_ms = 10
frame_length_ms = 25
add_frames = math.ceil(
(frame_length_ms - frame_shift_ms) / frame_shift_ms
)
# decode_window_length: input sequence length of streaming encoder
if args.context > 0:
# decode window length calculation for wenet
decode_window_length = (
args.chunk_size - 1
) * args.subsampling + args.context
else:
# decode window length calculation for icefall
decode_window_length = (
args.chunk_size + 2 + args.encoder_right_context
) * args.subsampling + 3
first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms
start_time = time.time()
for i in range(num_tasks):
if args.streaming:
assert not args.simulate_streaming
task = asyncio.create_task(
send_streaming(
dps=dps_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=log_interval,
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
other_chunk_in_secs=args.chunk_size
* args.subsampling
* frame_shift_ms
/ 1000,
task_index=i,
)
)
elif args.simulate_streaming:
task = asyncio.create_task(
send_streaming(
dps=dps_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=log_interval,
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
other_chunk_in_secs=args.chunk_size
* args.subsampling
* frame_shift_ms
/ 1000,
task_index=i,
simulate_mode=True,
)
)
else:
task = asyncio.create_task(
send(
dps=dps_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=log_interval,
compute_cer=compute_cer,
model_name=args.model_name,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
results = []
total_duration = 0.0
latency_data = []
for ans in ans_list:
total_duration += ans[0]
results += ans[1]
if args.streaming or args.simulate_streaming:
latency_data += ans[2]
rtf = elapsed / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration/3600:.2f} hours)\n"
s += (
f"processing time: {elapsed:.3f} seconds "
f"({elapsed/3600:.2f} hours)\n"
)
if args.streaming or args.simulate_streaming:
latency_list = [
chunk_end for (chunk_end, chunk_duration) in latency_data
]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_99_percentile: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
print(s)
with open("rtf.txt", "w") as f:
f.write(s)
name = Path(filename).stem.split(".")[0]
results = sorted(results)
store_transcripts(filename=f"recogs-{name}.txt", texts=results)
with open(f"errs-{name}.txt", "w") as f:
write_error_stats(f, "test-set", results, enable_log=True)
with open(f"errs-{name}.txt", "r") as f:
print(f.readline()) # WER
print(f.readline()) # Detailed errors
if args.stats_file:
stats = await triton_client.get_inference_statistics(
model_name="", as_json=True
)
with open(args.stats_file, "w") as f:
json.dump(stats, f)
if __name__ == "__main__":
asyncio.run(main())
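
For reference, a minimal sketch of producing one line of the nemo-style manifest consumed above; the path, text, and duration are hypothetical:

import json

entry = {"audio_filepath": "aishell/wav/test/S0764/BAC009S0764W0121.wav",
         "text": "甚至出现交易几乎停滞的情况",
         "duration": 4.2}
with open("aishell_test.txt", "a", encoding="utf-8") as f:
    f.write(json.dumps(entry, ensure_ascii=False) + "\n")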


@@ -0,0 +1,142 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tritonclient.utils import np_to_triton_dtype
import numpy as np
import math
import soundfile as sf
class OfflineSpeechClient(object):
def __init__(self, triton_client, model_name, protocol_client, args):
self.triton_client = triton_client
self.protocol_client = protocol_client
self.model_name = model_name
def recognize(self, wav_file, idx=0):
waveform, sample_rate = sf.read(wav_file)
samples = np.array([waveform], dtype=np.float32)
lengths = np.array([[len(waveform)]], dtype=np.int32)
# better pad waveform to nearest length here
        # target_seconds = math.ceil(len(waveform) / sample_rate)
# target_samples = np.zeros([1, target_seconds * sample_rate])
# target_samples[0][0: len(waveform)] = waveform
# samples = target_samples
sequence_id = 10086 + idx
result = ""
inputs = [
self.protocol_client.InferInput(
"WAV", samples.shape, np_to_triton_dtype(samples.dtype)
),
self.protocol_client.InferInput(
"WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]
response = self.triton_client.infer(
self.model_name,
inputs,
request_id=str(sequence_id),
outputs=outputs,
)
result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
return [result]
class StreamingSpeechClient(object):
def __init__(self, triton_client, model_name, protocol_client, args):
self.triton_client = triton_client
self.protocol_client = protocol_client
self.model_name = model_name
chunk_size = args.chunk_size
subsampling = args.subsampling
context = args.context
frame_shift_ms = args.frame_shift_ms
frame_length_ms = args.frame_length_ms
# for the first chunk
# we need additional frames to generate
# the exact first chunk length frames
# since the subsampling will look ahead several frames
first_chunk_length = (chunk_size - 1) * subsampling + context
add_frames = math.ceil(
(frame_length_ms - frame_shift_ms) / frame_shift_ms
)
first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms
other_chunk_ms = chunk_size * subsampling * frame_shift_ms
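        # Worked example with the client.py defaults (chunk_size=16,
        # subsampling=4, context=7, 25 ms frames, 10 ms shift):
        # first_chunk_length = 15 * 4 + 7 = 67, add_frames = 2, so
        # first_chunk_ms = 69 * 10 = 690 ms and other_chunk_ms = 640 ms.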
self.first_chunk_in_secs = first_chunk_ms / 1000
self.other_chunk_in_secs = other_chunk_ms / 1000
def recognize(self, wav_file, idx=0):
waveform, sample_rate = sf.read(wav_file)
wav_segs = []
i = 0
while i < len(waveform):
if i == 0:
stride = int(self.first_chunk_in_secs * sample_rate)
wav_segs.append(waveform[i : i + stride])
else:
stride = int(self.other_chunk_in_secs * sample_rate)
wav_segs.append(waveform[i : i + stride])
i += len(wav_segs[-1])
sequence_id = idx + 10086
# simulate streaming
for idx, seg in enumerate(wav_segs):
chunk_len = len(seg)
if idx == 0:
chunk_samples = int(self.first_chunk_in_secs * sample_rate)
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
else:
chunk_samples = int(self.other_chunk_in_secs * sample_rate)
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
expect_input[0][0:chunk_len] = seg
input0_data = expect_input
input1_data = np.array([[chunk_len]], dtype=np.int32)
inputs = [
self.protocol_client.InferInput(
"WAV",
input0_data.shape,
np_to_triton_dtype(input0_data.dtype),
),
self.protocol_client.InferInput(
"WAV_LENS",
input1_data.shape,
np_to_triton_dtype(input1_data.dtype),
),
]
inputs[0].set_data_from_numpy(input0_data)
inputs[1].set_data_from_numpy(input1_data)
outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]
end = False
if idx == len(wav_segs) - 1:
end = True
response = self.triton_client.infer(
self.model_name,
inputs,
outputs=outputs,
sequence_id=sequence_id,
sequence_start=idx == 0,
sequence_end=end,
)
idx += 1
result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
print("Get response from {}th chunk: {}".format(idx, result))
return [result]
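
A minimal usage sketch for the streaming client; the server URL, model name, and wav path are hypothetical, and the parameters mirror the client.py defaults:

import argparse
import tritonclient.grpc as grpcclient
from speech_client import StreamingSpeechClient

flags = argparse.Namespace(chunk_size=16, subsampling=4, context=7,
                           frame_shift_ms=10, frame_length_ms=25)
triton_client = grpcclient.InferenceServerClient(url="localhost:10086")
client = StreamingSpeechClient(triton_client, "streaming_wenet",
                               grpcclient, flags)
print(client.recognize("test.wav")[0])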


@@ -0,0 +1,60 @@
import numpy as np
def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference
between two sequences. Informally, the levenshtein disctance is defined as
the minimum number of single-character edits (substitutions, insertions or
deletions) required to change one word into the other. We can naturally
extend the edits to word level when calculate levenshtein disctance for
two sentences.
"""
m = len(ref)
n = len(hyp)
# special case
if ref == hyp:
return 0
if m == 0:
return n
if n == 0:
return m
if m < n:
ref, hyp = hyp, ref
m, n = n, m
# use O(min(m, n)) space
distance = np.zeros((2, n + 1), dtype=np.int32)
# initialize distance matrix
for j in range(n + 1):
distance[0][j] = j
# calculate levenshtein distance
for i in range(1, m + 1):
prev_row_idx = (i - 1) % 2
cur_row_idx = i % 2
distance[cur_row_idx][0] = i
for j in range(1, n + 1):
if ref[i - 1] == hyp[j - 1]:
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
else:
s_num = distance[prev_row_idx][j - 1] + 1
i_num = distance[cur_row_idx][j - 1] + 1
d_num = distance[prev_row_idx][j] + 1
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
return distance[m % 2][n]
def cal_cer(references, predictions):
errors = 0
lengths = 0
for ref, pred in zip(references, predictions):
cur_ref = list(ref)
cur_hyp = list(pred)
cur_error = _levenshtein_distance(cur_ref, cur_hyp)
errors += cur_error
lengths += len(cur_ref)
return float(errors) / lengths
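# A quick sanity check, with the expected value worked out by hand
# (one substitution over seven reference characters):
if __name__ == "__main__":
    refs = ["今天天气怎么样"]  # 7 characters
    hyps = ["今天天汽怎么样"]  # one substitution
    print(cal_cer(refs, hyps))  # 1/7 = 0.142857...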