Implementing concurrent.futures

This commit is contained in:
MiXaiLL76
2024-09-06 11:08:11 +03:00
parent 7b3e285bca
commit 73271d46f9

View File

@@ -13,53 +13,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
from concurrent.futures import ThreadPoolExecutor
import onnxruntime
import torch import torch
import torchaudio import torchaudio
from tqdm import tqdm
import onnxruntime
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
from queue import Queue, Empty from tqdm import tqdm
from threading import Thread
class ExtractEmbedding: def extract_embedding(input_list):
def __init__(self, model_path: str, queue: Queue, out_queue: Queue): utt, wav_file, ort_session = input_list
self.model_path = model_path
self.queue = queue
self.out_queue = out_queue
self.is_run = True
def run(self):
self.consumer_thread = Thread(target=self.consumer)
self.consumer_thread.start()
def stop(self):
self.is_run = False
self.consumer_thread.join()
def consumer(self):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(
self.model_path, sess_options=option, providers=providers
)
while self.is_run:
try:
utt, wav_file = self.queue.get(timeout=1)
audio, sample_rate = torchaudio.load(wav_file) audio, sample_rate = torchaudio.load(wav_file)
if sample_rate != 16000: if sample_rate != 16000:
audio = torchaudio.transforms.Resample( audio = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000 orig_freq=sample_rate, new_freq=16000
)(audio) )(audio)
feat = kaldi.fbank( feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
audio, num_mel_bins=80, dither=0, sample_frequency=16000
)
feat = feat - feat.mean(dim=0, keepdim=True) feat = feat - feat.mean(dim=0, keepdim=True)
embedding = ( embedding = (
ort_session.run( ort_session.run(
@@ -74,10 +46,7 @@ class ExtractEmbedding:
.flatten() .flatten()
.tolist() .tolist()
) )
self.out_queue.put((utt, embedding)) return (utt, embedding)
except Empty:
self.is_run = False
break
def main(args): def main(args):
@@ -91,32 +60,38 @@ def main(args):
l = l.replace("\n", "").split() l = l.replace("\n", "").split()
utt2spk[l[0]] = l[1] utt2spk[l[0]] = l[1]
input_queue = Queue() assert os.path.exists(args.onnx_path), "onnx_path not exists"
output_queue = Queue()
consumers = [ option = onnxruntime.SessionOptions()
ExtractEmbedding(args.onnx_path, input_queue, output_queue) option.graph_optimization_level = (
for _ in range(args.num_thread) onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(
args.onnx_path, sess_options=option, providers=providers
)
inputs = [
(utt, utt2wav[utt], ort_session)
for utt in tqdm(utt2wav.keys(), desc="Load data")
] ]
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
results = list(
tqdm(
executor.map(extract_embedding, inputs),
total=len(inputs),
desc="Process data: ",
)
)
utt2embedding, spk2embedding = {}, {} utt2embedding, spk2embedding = {}, {}
for utt in tqdm(utt2wav.keys(), desc="Load data"): for utt, embedding in results:
input_queue.put((utt, utt2wav[utt]))
for c in consumers:
c.run()
with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
while any([c.is_run for c in consumers]):
try:
utt, embedding = output_queue.get(timeout=1)
utt2embedding[utt] = embedding utt2embedding[utt] = embedding
spk = utt2spk[utt] spk = utt2spk[utt]
if spk not in spk2embedding: if spk not in spk2embedding:
spk2embedding[spk] = [] spk2embedding[spk] = []
spk2embedding[spk].append(embedding) spk2embedding[spk].append(embedding)
pbar.update(1)
except Empty:
continue
for k, v in spk2embedding.items(): for k, v in spk2embedding.items():
spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()