diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 616054c..1fd1b95 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -18,53 +18,117 @@ import torchaudio from tqdm import tqdm import onnxruntime import torchaudio.compliance.kaldi as kaldi +from queue import Queue, Empty +from threading import Thread + + +class ExtractEmbedding: + def __init__(self, model_path: str, queue: Queue, out_queue: Queue): + self.model_path = model_path + self.queue = queue + self.out_queue = out_queue + self.is_run = True + + def run(self): + self.consumer_thread = Thread(target=self.consumer) + self.consumer_thread.start() + + def stop(self): + self.is_run = False + self.consumer_thread.join() + + def consumer(self): + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + providers = ["CPUExecutionProvider"] + ort_session = onnxruntime.InferenceSession( + self.model_path, sess_options=option, providers=providers + ) + + while self.is_run: + try: + utt, wav_file = self.queue.get(timeout=1) + + audio, sample_rate = torchaudio.load(wav_file) + if sample_rate != 16000: + audio = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=16000 + )(audio) + feat = kaldi.fbank( + audio, num_mel_bins=80, dither=0, sample_frequency=16000 + ) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ( + ort_session.run( + None, + { + ort_session.get_inputs()[0] + .name: feat.unsqueeze(dim=0) + .cpu() + .numpy() + }, + )[0] + .flatten() + .tolist() + ) + self.out_queue.put((utt, embedding)) + except Empty: + self.is_run = False + break def main(args): utt2wav, utt2spk = {}, {} - with open('{}/wav.scp'.format(args.dir)) as f: + with open("{}/wav.scp".format(args.dir)) as f: for l in f: - l = l.replace('\n', '').split() + l = l.replace("\n", "").split() utt2wav[l[0]] = l[1] - with open('{}/utt2spk'.format(args.dir)) as f: + with open("{}/utt2spk".format(args.dir)) as f: for l in f: - l = l.replace('\n', '').split() + l = l.replace("\n", "").split() utt2spk[l[0]] = l[1] - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ["CPUExecutionProvider"] - ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) + input_queue = Queue() + output_queue = Queue() + consumers = [ + ExtractEmbedding(args.onnx_path, input_queue, output_queue) + for _ in range(args.num_thread) + ] utt2embedding, spk2embedding = {}, {} - for utt in tqdm(utt2wav.keys()): - audio, sample_rate = torchaudio.load(utt2wav[utt]) - if sample_rate != 16000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) - feat = kaldi.fbank(audio, - num_mel_bins=80, - dither=0, - sample_frequency=16000) - feat = feat - feat.mean(dim=0, keepdim=True) - embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() - utt2embedding[utt] = embedding - spk = utt2spk[utt] - if spk not in spk2embedding: - spk2embedding[spk] = [] - spk2embedding[spk].append(embedding) + for utt in tqdm(utt2wav.keys(), desc="Load data"): + input_queue.put((utt, utt2wav[utt])) + + for c in consumers: + c.run() + + with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar: + while any([c.is_run for c in consumers]): + try: + utt, embedding = output_queue.get(timeout=1) + utt2embedding[utt] = embedding + spk = utt2spk[utt] + if spk not in spk2embedding: + spk2embedding[spk] = [] + spk2embedding[spk].append(embedding) + pbar.update(1) + except Empty: + continue + for k, v in spk2embedding.items(): spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() - torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir)) - torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir)) + torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) + torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--dir', - type=str) - parser.add_argument('--onnx_path', - type=str) + parser.add_argument("--dir", type=str) + parser.add_argument("--onnx_path", type=str) + parser.add_argument("--num_thread", type=int, default=8) args = parser.parse_args() main(args)