Implementing concurrent.futures

This commit is contained in:
MiXaiLL76
2024-09-06 11:08:11 +03:00
parent 7b3e285bca
commit 73271d46f9

View File

@@ -13,71 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
import onnxruntime
import torch
import torchaudio
from tqdm import tqdm
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
from queue import Queue, Empty from tqdm import tqdm
from threading import Thread
class ExtractEmbedding: def extract_embedding(input_list):
def __init__(self, model_path: str, queue: Queue, out_queue: Queue): utt, wav_file, ort_session = input_list
self.model_path = model_path
self.queue = queue
self.out_queue = out_queue
self.is_run = True
def run(self): audio, sample_rate = torchaudio.load(wav_file)
self.consumer_thread = Thread(target=self.consumer) if sample_rate != 16000:
self.consumer_thread.start() audio = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
def stop(self): )(audio)
self.is_run = False feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
self.consumer_thread.join() feat = feat - feat.mean(dim=0, keepdim=True)
embedding = (
def consumer(self): ort_session.run(
option = onnxruntime.SessionOptions() None,
option.graph_optimization_level = ( {
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL ort_session.get_inputs()[0]
) .name: feat.unsqueeze(dim=0)
option.intra_op_num_threads = 1 .cpu()
providers = ["CPUExecutionProvider"] .numpy()
ort_session = onnxruntime.InferenceSession( },
self.model_path, sess_options=option, providers=providers )[0]
) .flatten()
.tolist()
while self.is_run: )
try: return (utt, embedding)
utt, wav_file = self.queue.get(timeout=1)
audio, sample_rate = torchaudio.load(wav_file)
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(audio)
feat = kaldi.fbank(
audio, num_mel_bins=80, dither=0, sample_frequency=16000
)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = (
ort_session.run(
None,
{
ort_session.get_inputs()[0]
.name: feat.unsqueeze(dim=0)
.cpu()
.numpy()
},
)[0]
.flatten()
.tolist()
)
self.out_queue.put((utt, embedding))
except Empty:
self.is_run = False
break
def main(args):
@@ -91,32 +60,38 @@ def main(args):
l = l.replace("\n", "").split()
utt2spk[l[0]] = l[1]
input_queue = Queue() assert os.path.exists(args.onnx_path), "onnx_path not exists"
output_queue = Queue()
consumers = [ option = onnxruntime.SessionOptions()
ExtractEmbedding(args.onnx_path, input_queue, output_queue) option.graph_optimization_level = (
for _ in range(args.num_thread) onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(
args.onnx_path, sess_options=option, providers=providers
)
inputs = [
(utt, utt2wav[utt], ort_session)
for utt in tqdm(utt2wav.keys(), desc="Load data")
] ]
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
results = list(
tqdm(
executor.map(extract_embedding, inputs),
total=len(inputs),
desc="Process data: ",
)
)
utt2embedding, spk2embedding = {}, {}
for utt in tqdm(utt2wav.keys(), desc="Load data"): for utt, embedding in results:
input_queue.put((utt, utt2wav[utt])) utt2embedding[utt] = embedding
spk = utt2spk[utt]
for c in consumers: if spk not in spk2embedding:
c.run() spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
while any([c.is_run for c in consumers]):
try:
utt, embedding = output_queue.get(timeout=1)
utt2embedding[utt] = embedding
spk = utt2spk[utt]
if spk not in spk2embedding:
spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
pbar.update(1)
except Empty:
continue
for k, v in spk2embedding.items():
spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()