Implementing concurrent.futures

This commit is contained in:
MiXaiLL76
2024-09-06 11:08:11 +03:00
parent 7b3e285bca
commit 73271d46f9

View File

@@ -13,71 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
import onnxruntime
import torch
import torchaudio
from tqdm import tqdm
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
from queue import Queue, Empty
from threading import Thread
from tqdm import tqdm
class ExtractEmbedding:
def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
self.model_path = model_path
self.queue = queue
self.out_queue = out_queue
self.is_run = True
def extract_embedding(input_list):
utt, wav_file, ort_session = input_list
def run(self):
self.consumer_thread = Thread(target=self.consumer)
self.consumer_thread.start()
def stop(self):
self.is_run = False
self.consumer_thread.join()
def consumer(self):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(
self.model_path, sess_options=option, providers=providers
)
while self.is_run:
try:
utt, wav_file = self.queue.get(timeout=1)
audio, sample_rate = torchaudio.load(wav_file)
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(audio)
feat = kaldi.fbank(
audio, num_mel_bins=80, dither=0, sample_frequency=16000
)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = (
ort_session.run(
None,
{
ort_session.get_inputs()[0]
.name: feat.unsqueeze(dim=0)
.cpu()
.numpy()
},
)[0]
.flatten()
.tolist()
)
self.out_queue.put((utt, embedding))
except Empty:
self.is_run = False
break
audio, sample_rate = torchaudio.load(wav_file)
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(audio)
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = (
ort_session.run(
None,
{
ort_session.get_inputs()[0]
.name: feat.unsqueeze(dim=0)
.cpu()
.numpy()
},
)[0]
.flatten()
.tolist()
)
return (utt, embedding)
def main(args):
@@ -91,32 +60,38 @@ def main(args):
l = l.replace("\n", "").split()
utt2spk[l[0]] = l[1]
input_queue = Queue()
output_queue = Queue()
consumers = [
ExtractEmbedding(args.onnx_path, input_queue, output_queue)
for _ in range(args.num_thread)
assert os.path.exists(args.onnx_path), "onnx_path not exists"
option = onnxruntime.SessionOptions()
option.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(
args.onnx_path, sess_options=option, providers=providers
)
inputs = [
(utt, utt2wav[utt], ort_session)
for utt in tqdm(utt2wav.keys(), desc="Load data")
]
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
results = list(
tqdm(
executor.map(extract_embedding, inputs),
total=len(inputs),
desc="Process data: ",
)
)
utt2embedding, spk2embedding = {}, {}
for utt in tqdm(utt2wav.keys(), desc="Load data"):
input_queue.put((utt, utt2wav[utt]))
for c in consumers:
c.run()
with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
while any([c.is_run for c in consumers]):
try:
utt, embedding = output_queue.get(timeout=1)
utt2embedding[utt] = embedding
spk = utt2spk[utt]
if spk not in spk2embedding:
spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
pbar.update(1)
except Empty:
continue
for utt, embedding in results:
utt2embedding[utt] = embedding
spk = utt2spk[utt]
if spk not in spk2embedding:
spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
for k, v in spk2embedding.items():
spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()