mirror of https://github.com/HumanAIGC/lite-avatar.git (synced 2026-02-05 18:09:20 +08:00)
add files
funasr_local/runtime/triton_gpu/client/speech_client.py (new file, 142 lines added)
@@ -0,0 +1,142 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import soundfile as sf
from tritonclient.utils import np_to_triton_dtype


class OfflineSpeechClient(object):
    def __init__(self, triton_client, model_name, protocol_client, args):
        self.triton_client = triton_client
        self.protocol_client = protocol_client
        self.model_name = model_name

    def recognize(self, wav_file, idx=0):
        # Read the whole utterance and send it to the server in one request.
        waveform, sample_rate = sf.read(wav_file)
        samples = np.array([waveform], dtype=np.float32)
        lengths = np.array([[len(waveform)]], dtype=np.int32)
        # Better: pad the waveform up to the nearest whole second here, e.g.
        # target_seconds = math.ceil(len(waveform) / sample_rate)
        # target_samples = np.zeros([1, target_seconds * sample_rate])
        # target_samples[0][0 : len(waveform)] = waveform
        # samples = target_samples
        sequence_id = 10086 + idx
        inputs = [
            self.protocol_client.InferInput(
                "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
            ),
            self.protocol_client.InferInput(
                "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
            ),
        ]
        inputs[0].set_data_from_numpy(samples)
        inputs[1].set_data_from_numpy(lengths)
        outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]
        response = self.triton_client.infer(
            self.model_name,
            inputs,
            request_id=str(sequence_id),
            outputs=outputs,
        )
        result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
        return [result]

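# Example usage (a sketch, not part of this commit): driving the offline
# client over gRPC. The server URL, the model name "attention_rescoring",
# and the CLI flags are assumptions; substitute whatever your Triton
# deployment actually serves.
#
#     import argparse
#     import tritonclient.grpc as grpcclient
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--url", default="localhost:8001")
#     parser.add_argument("--model-name", default="attention_rescoring")
#     parser.add_argument("--wav-file", required=True)
#     args = parser.parse_args()
#
#     triton_client = grpcclient.InferenceServerClient(url=args.url)
#     client = OfflineSpeechClient(triton_client, args.model_name,
#                                  grpcclient, args)
#     print(client.recognize(args.wav_file)[0])
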
class StreamingSpeechClient(object):
    def __init__(self, triton_client, model_name, protocol_client, args):
        self.triton_client = triton_client
        self.protocol_client = protocol_client
        self.model_name = model_name
        chunk_size = args.chunk_size
        subsampling = args.subsampling
        context = args.context
        frame_shift_ms = args.frame_shift_ms
        frame_length_ms = args.frame_length_ms
        # For the first chunk we need additional input frames to produce
        # exactly `chunk_size` output frames, since the subsampling layer
        # looks ahead several frames.
        first_chunk_length = (chunk_size - 1) * subsampling + context
        add_frames = math.ceil(
            (frame_length_ms - frame_shift_ms) / frame_shift_ms
        )
        first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms
        other_chunk_ms = chunk_size * subsampling * frame_shift_ms
        self.first_chunk_in_secs = first_chunk_ms / 1000
        self.other_chunk_in_secs = other_chunk_ms / 1000

    def recognize(self, wav_file, idx=0):
        waveform, sample_rate = sf.read(wav_file)
        # Split the waveform into chunks; the first chunk is longer than
        # the rest (see __init__).
        wav_segs = []
        i = 0
        while i < len(waveform):
            if i == 0:
                stride = int(self.first_chunk_in_secs * sample_rate)
            else:
                stride = int(self.other_chunk_in_secs * sample_rate)
            wav_segs.append(waveform[i : i + stride])
            i += len(wav_segs[-1])

        sequence_id = idx + 10086
        result = ""
        # Simulate streaming: send the chunks one by one within a single
        # Triton sequence. (The loop variable is named chunk_idx so it does
        # not shadow the `idx` argument used for the sequence id.)
        for chunk_idx, seg in enumerate(wav_segs):
            chunk_len = len(seg)
            if chunk_idx == 0:
                chunk_samples = int(self.first_chunk_in_secs * sample_rate)
            else:
                chunk_samples = int(self.other_chunk_in_secs * sample_rate)
            # Zero-pad the (possibly shorter) last segment to the full
            # chunk length; WAV_LENS carries the number of valid samples.
            expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
            expect_input[0][0:chunk_len] = seg
            input0_data = expect_input
            input1_data = np.array([[chunk_len]], dtype=np.int32)

            inputs = [
                self.protocol_client.InferInput(
                    "WAV",
                    input0_data.shape,
                    np_to_triton_dtype(input0_data.dtype),
                ),
                self.protocol_client.InferInput(
                    "WAV_LENS",
                    input1_data.shape,
                    np_to_triton_dtype(input1_data.dtype),
                ),
            ]
            inputs[0].set_data_from_numpy(input0_data)
            inputs[1].set_data_from_numpy(input1_data)

            outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]
            end = chunk_idx == len(wav_segs) - 1

            response = self.triton_client.infer(
                self.model_name,
                inputs,
                outputs=outputs,
                sequence_id=sequence_id,
                sequence_start=chunk_idx == 0,
                sequence_end=end,
            )
            result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
            print("Got response for chunk {}: {}".format(chunk_idx + 1, result))
        return [result]
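To make the chunk arithmetic in StreamingSpeechClient.__init__ concrete, here is a worked example. The parameter values (chunk_size=16, subsampling=4, context=7, frame_shift_ms=10, frame_length_ms=25) are typical for WeNet-style streaming conformers, but they are assumptions, not values taken from this commit:

    import math

    chunk_size, subsampling, context = 16, 4, 7  # assumed decoding config
    frame_shift_ms, frame_length_ms = 10, 25     # assumed fbank settings

    # Input frames needed so the subsampled output has exactly chunk_size frames.
    first_chunk_length = (chunk_size - 1) * subsampling + context        # 67
    # Extra frames because each frame spans frame_length_ms but advances
    # by only frame_shift_ms.
    add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)  # 2
    first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms  # 690 ms
    other_chunk_ms = chunk_size * subsampling * frame_shift_ms           # 640 ms
    print(first_chunk_ms / 1000, other_chunk_ms / 1000)                  # 0.69 0.64

So with this configuration the client sends a 0.69 s first chunk and 0.64 s chunks thereafter.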
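And a minimal sketch of driving the streaming client, again with assumed parameter values; "streaming_wenet" and the server URL are placeholders, not names from this commit:

    from types import SimpleNamespace

    import tritonclient.grpc as grpcclient

    args = SimpleNamespace(
        chunk_size=16, subsampling=4, context=7,
        frame_shift_ms=10, frame_length_ms=25,
    )
    triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
    client = StreamingSpeechClient(triton_client, "streaming_wenet",
                                   grpcclient, args)
    print(client.recognize("test.wav")[0])  # prints the final transcript

Note that because the client passes sequence_id, sequence_start, and sequence_end, the server-side model must be deployed with Triton's sequence batcher so that chunks of one utterance are routed to the same model instance.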