diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py
index 1fd1b95..779faa1 100755
--- a/tools/extract_embedding.py
+++ b/tools/extract_embedding.py
@@ -13,71 +13,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+
+import onnxruntime
 import torch
 import torchaudio
-from tqdm import tqdm
-import onnxruntime
 import torchaudio.compliance.kaldi as kaldi
-from queue import Queue, Empty
-from threading import Thread
+from tqdm import tqdm


-class ExtractEmbedding:
-    def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
-        self.model_path = model_path
-        self.queue = queue
-        self.out_queue = out_queue
-        self.is_run = True
+def extract_embedding(input_list):
+    utt, wav_file, ort_session = input_list

-    def run(self):
-        self.consumer_thread = Thread(target=self.consumer)
-        self.consumer_thread.start()
-
-    def stop(self):
-        self.is_run = False
-        self.consumer_thread.join()
-
-    def consumer(self):
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = (
-            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        )
-        option.intra_op_num_threads = 1
-        providers = ["CPUExecutionProvider"]
-        ort_session = onnxruntime.InferenceSession(
-            self.model_path, sess_options=option, providers=providers
-        )
-
-        while self.is_run:
-            try:
-                utt, wav_file = self.queue.get(timeout=1)
-
-                audio, sample_rate = torchaudio.load(wav_file)
-                if sample_rate != 16000:
-                    audio = torchaudio.transforms.Resample(
-                        orig_freq=sample_rate, new_freq=16000
-                    )(audio)
-                feat = kaldi.fbank(
-                    audio, num_mel_bins=80, dither=0, sample_frequency=16000
-                )
-                feat = feat - feat.mean(dim=0, keepdim=True)
-                embedding = (
-                    ort_session.run(
-                        None,
-                        {
-                            ort_session.get_inputs()[0]
-                            .name: feat.unsqueeze(dim=0)
-                            .cpu()
-                            .numpy()
-                        },
-                    )[0]
-                    .flatten()
-                    .tolist()
-                )
-                self.out_queue.put((utt, embedding))
-            except Empty:
-                self.is_run = False
-                break
+    audio, sample_rate = torchaudio.load(wav_file)
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(
+            orig_freq=sample_rate, new_freq=16000
+        )(audio)
+    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
+    feat = feat - feat.mean(dim=0, keepdim=True)
+    embedding = (
+        ort_session.run(
+            None,
+            {
+                ort_session.get_inputs()[0]
+                .name: feat.unsqueeze(dim=0)
+                .cpu()
+                .numpy()
+            },
+        )[0]
+        .flatten()
+        .tolist()
+    )
+    return (utt, embedding)


 def main(args):
@@ -91,32 +60,38 @@ def main(args):
         l = l.replace("\n", "").split()
         utt2spk[l[0]] = l[1]

-    input_queue = Queue()
-    output_queue = Queue()
-    consumers = [
-        ExtractEmbedding(args.onnx_path, input_queue, output_queue)
-        for _ in range(args.num_thread)
+    assert os.path.exists(args.onnx_path), "onnx_path not exists"
+
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = (
+        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    )
+    option.intra_op_num_threads = 1
+    providers = ["CPUExecutionProvider"]
+    ort_session = onnxruntime.InferenceSession(
+        args.onnx_path, sess_options=option, providers=providers
+    )
+
+    inputs = [
+        (utt, utt2wav[utt], ort_session)
+        for utt in tqdm(utt2wav.keys(), desc="Load data")
     ]
+    with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
+        results = list(
+            tqdm(
+                executor.map(extract_embedding, inputs),
+                total=len(inputs),
+                desc="Process data: ",
+            )
+        )
     utt2embedding, spk2embedding = {}, {}
-    for utt in tqdm(utt2wav.keys(), desc="Load data"):
-        input_queue.put((utt, utt2wav[utt]))
-
-    for c in consumers:
-        c.run()
-
-    with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
-        while any([c.is_run for c in consumers]):
-            try:
-                utt, embedding = output_queue.get(timeout=1)
-                utt2embedding[utt] = embedding
-                spk = utt2spk[utt]
-                if spk not in spk2embedding:
-                    spk2embedding[spk] = []
-                spk2embedding[spk].append(embedding)
-                pbar.update(1)
-            except Empty:
-                continue
+    for utt, embedding in results:
+        utt2embedding[utt] = embedding
+        spk = utt2spk[utt]
+        if spk not in spk2embedding:
+            spk2embedding[spk] = []
+        spk2embedding[spk].append(embedding)
     for k, v in spk2embedding.items():
         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
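
Reviewer note, not part of the patch: sharing one `onnxruntime.InferenceSession` across all pool workers is sound because ONNX Runtime documents concurrent `InferenceSession.run` calls on the same session as thread-safe, and the Python binding releases the GIL during native inference, so a thread pool gets real parallelism; `intra_op_num_threads = 1` keeps each call single-threaded so only the pool controls concurrency. `executor.map` also yields results in input order, which is why the later `for utt, embedding in results` loop pairs each utt with its own embedding even when workers finish out of order. A minimal runnable sketch of the same map-plus-tqdm pattern, with a hypothetical `fake_extract` worker and toy `items` standing in for `extract_embedding` and the real (utt, wav, session) tuples:

    from concurrent.futures import ThreadPoolExecutor

    from tqdm import tqdm


    def fake_extract(item):
        # Stand-in for extract_embedding: one (utt, payload) tuple in,
        # one (utt, result) tuple out.
        utt, value = item
        return (utt, value * value)


    items = [("utt1", 1), ("utt2", 2), ("utt3", 3)]  # toy inputs, not real wavs

    # executor.map preserves input order, so results line up with items;
    # wrapping the iterator in tqdm reproduces the patch's progress bar.
    with ThreadPoolExecutor(max_workers=2) as executor:
        results = list(
            tqdm(
                executor.map(fake_extract, items),
                total=len(items),
                desc="Process data: ",
            )
        )

    print(dict(results))  # {'utt1': 1, 'utt2': 4, 'utt3': 9}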