From 1d05ae5fd3f6263a040e4cdcd75b470ab4cf40d2 Mon Sep 17 00:00:00 2001 From: MiXaiLL76 Date: Fri, 6 Sep 2024 11:40:27 +0300 Subject: [PATCH] fix --- tools/extract_embedding.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 779faa1..982a841 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -21,11 +21,10 @@ import torch import torchaudio import torchaudio.compliance.kaldi as kaldi from tqdm import tqdm +from itertools import repeat -def extract_embedding(input_list): - utt, wav_file, ort_session = input_list - +def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession): audio, sample_rate = torchaudio.load(wav_file) if sample_rate != 16000: audio = torchaudio.transforms.Resample( @@ -33,19 +32,7 @@ def extract_embedding(input_list): )(audio) feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) - embedding = ( - ort_session.run( - None, - { - ort_session.get_inputs()[0] - .name: feat.unsqueeze(dim=0) - .cpu() - .numpy() - }, - )[0] - .flatten() - .tolist() - ) + embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() return (utt, embedding) @@ -72,16 +59,14 @@ def main(args): args.onnx_path, sess_options=option, providers=providers ) - inputs = [ - (utt, utt2wav[utt], ort_session) - for utt in tqdm(utt2wav.keys(), desc="Load data") - ] + all_utt = utt2wav.keys() + with ThreadPoolExecutor(max_workers=args.num_thread) as executor: results = list( tqdm( - executor.map(extract_embedding, inputs), - total=len(inputs), - desc="Process data: ", + executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)), + total=len(utt2wav), + desc="Process data: " ) )