From 7b3e285bca37d4396388814037ede9de9f0b62f1 Mon Sep 17 00:00:00 2001 From: MiXaiLL76 Date: Thu, 5 Sep 2024 14:32:37 +0300 Subject: [PATCH 1/5] add threading --- tools/extract_embedding.py | 124 ++++++++++++++++++++++++++++--------- 1 file changed, 94 insertions(+), 30 deletions(-) diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 616054c..1fd1b95 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -18,53 +18,117 @@ import torchaudio from tqdm import tqdm import onnxruntime import torchaudio.compliance.kaldi as kaldi +from queue import Queue, Empty +from threading import Thread + + +class ExtractEmbedding: + def __init__(self, model_path: str, queue: Queue, out_queue: Queue): + self.model_path = model_path + self.queue = queue + self.out_queue = out_queue + self.is_run = True + + def run(self): + self.consumer_thread = Thread(target=self.consumer) + self.consumer_thread.start() + + def stop(self): + self.is_run = False + self.consumer_thread.join() + + def consumer(self): + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + providers = ["CPUExecutionProvider"] + ort_session = onnxruntime.InferenceSession( + self.model_path, sess_options=option, providers=providers + ) + + while self.is_run: + try: + utt, wav_file = self.queue.get(timeout=1) + + audio, sample_rate = torchaudio.load(wav_file) + if sample_rate != 16000: + audio = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=16000 + )(audio) + feat = kaldi.fbank( + audio, num_mel_bins=80, dither=0, sample_frequency=16000 + ) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ( + ort_session.run( + None, + { + ort_session.get_inputs()[0] + .name: feat.unsqueeze(dim=0) + .cpu() + .numpy() + }, + )[0] + .flatten() + .tolist() + ) + self.out_queue.put((utt, embedding)) + except Empty: + self.is_run = False + break def main(args): utt2wav, utt2spk = {}, {} - with open('{}/wav.scp'.format(args.dir)) as f: + with open("{}/wav.scp".format(args.dir)) as f: for l in f: - l = l.replace('\n', '').split() + l = l.replace("\n", "").split() utt2wav[l[0]] = l[1] - with open('{}/utt2spk'.format(args.dir)) as f: + with open("{}/utt2spk".format(args.dir)) as f: for l in f: - l = l.replace('\n', '').split() + l = l.replace("\n", "").split() utt2spk[l[0]] = l[1] - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ["CPUExecutionProvider"] - ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) + input_queue = Queue() + output_queue = Queue() + consumers = [ + ExtractEmbedding(args.onnx_path, input_queue, output_queue) + for _ in range(args.num_thread) + ] utt2embedding, spk2embedding = {}, {} - for utt in tqdm(utt2wav.keys()): - audio, sample_rate = torchaudio.load(utt2wav[utt]) - if sample_rate != 16000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) - feat = kaldi.fbank(audio, - num_mel_bins=80, - dither=0, - sample_frequency=16000) - feat = feat - feat.mean(dim=0, keepdim=True) - embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() - utt2embedding[utt] = embedding - spk = utt2spk[utt] - if spk not in spk2embedding: - spk2embedding[spk] = [] - spk2embedding[spk].append(embedding) + for utt 
in tqdm(utt2wav.keys(), desc="Load data"): + input_queue.put((utt, utt2wav[utt])) + + for c in consumers: + c.run() + + with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar: + while any([c.is_run for c in consumers]): + try: + utt, embedding = output_queue.get(timeout=1) + utt2embedding[utt] = embedding + spk = utt2spk[utt] + if spk not in spk2embedding: + spk2embedding[spk] = [] + spk2embedding[spk].append(embedding) + pbar.update(1) + except Empty: + continue + for k, v in spk2embedding.items(): spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() - torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir)) - torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir)) + torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) + torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--dir', - type=str) - parser.add_argument('--onnx_path', - type=str) + parser.add_argument("--dir", type=str) + parser.add_argument("--onnx_path", type=str) + parser.add_argument("--num_thread", type=int, default=8) args = parser.parse_args() main(args) From 73271d46f9da51a3f7253473ea35a645e0c63ff8 Mon Sep 17 00:00:00 2001 From: MiXaiLL76 Date: Fri, 6 Sep 2024 11:08:11 +0300 Subject: [PATCH 2/5] Implementing concurrent.futures --- tools/extract_embedding.py | 139 +++++++++++++++---------------------- 1 file changed, 57 insertions(+), 82 deletions(-) diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 1fd1b95..779faa1 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -13,71 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os +from concurrent.futures import ThreadPoolExecutor + +import onnxruntime import torch import torchaudio -from tqdm import tqdm -import onnxruntime import torchaudio.compliance.kaldi as kaldi -from queue import Queue, Empty -from threading import Thread +from tqdm import tqdm -class ExtractEmbedding: - def __init__(self, model_path: str, queue: Queue, out_queue: Queue): - self.model_path = model_path - self.queue = queue - self.out_queue = out_queue - self.is_run = True +def extract_embedding(input_list): + utt, wav_file, ort_session = input_list - def run(self): - self.consumer_thread = Thread(target=self.consumer) - self.consumer_thread.start() - - def stop(self): - self.is_run = False - self.consumer_thread.join() - - def consumer(self): - option = onnxruntime.SessionOptions() - option.graph_optimization_level = ( - onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - ) - option.intra_op_num_threads = 1 - providers = ["CPUExecutionProvider"] - ort_session = onnxruntime.InferenceSession( - self.model_path, sess_options=option, providers=providers - ) - - while self.is_run: - try: - utt, wav_file = self.queue.get(timeout=1) - - audio, sample_rate = torchaudio.load(wav_file) - if sample_rate != 16000: - audio = torchaudio.transforms.Resample( - orig_freq=sample_rate, new_freq=16000 - )(audio) - feat = kaldi.fbank( - audio, num_mel_bins=80, dither=0, sample_frequency=16000 - ) - feat = feat - feat.mean(dim=0, keepdim=True) - embedding = ( - ort_session.run( - None, - { - ort_session.get_inputs()[0] - .name: feat.unsqueeze(dim=0) - .cpu() - .numpy() - }, - )[0] - .flatten() - .tolist() - ) - self.out_queue.put((utt, embedding)) - except Empty: - self.is_run = False - break + audio, sample_rate = torchaudio.load(wav_file) + if 
sample_rate != 16000: + audio = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=16000 + )(audio) + feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ( + ort_session.run( + None, + { + ort_session.get_inputs()[0] + .name: feat.unsqueeze(dim=0) + .cpu() + .numpy() + }, + )[0] + .flatten() + .tolist() + ) + return (utt, embedding) def main(args): @@ -91,32 +60,38 @@ def main(args): l = l.replace("\n", "").split() utt2spk[l[0]] = l[1] - input_queue = Queue() - output_queue = Queue() - consumers = [ - ExtractEmbedding(args.onnx_path, input_queue, output_queue) - for _ in range(args.num_thread) + assert os.path.exists(args.onnx_path), "onnx_path not exists" + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + providers = ["CPUExecutionProvider"] + ort_session = onnxruntime.InferenceSession( + args.onnx_path, sess_options=option, providers=providers + ) + + inputs = [ + (utt, utt2wav[utt], ort_session) + for utt in tqdm(utt2wav.keys(), desc="Load data") ] + with ThreadPoolExecutor(max_workers=args.num_thread) as executor: + results = list( + tqdm( + executor.map(extract_embedding, inputs), + total=len(inputs), + desc="Process data: ", + ) + ) utt2embedding, spk2embedding = {}, {} - for utt in tqdm(utt2wav.keys(), desc="Load data"): - input_queue.put((utt, utt2wav[utt])) - - for c in consumers: - c.run() - - with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar: - while any([c.is_run for c in consumers]): - try: - utt, embedding = output_queue.get(timeout=1) - utt2embedding[utt] = embedding - spk = utt2spk[utt] - if spk not in spk2embedding: - spk2embedding[spk] = [] - spk2embedding[spk].append(embedding) - pbar.update(1) - except Empty: - continue + for utt, embedding in results: + utt2embedding[utt] = embedding + spk = utt2spk[utt] + if spk not in spk2embedding: + spk2embedding[spk] = [] + spk2embedding[spk].append(embedding) for k, v in spk2embedding.items(): spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() From 1d05ae5fd3f6263a040e4cdcd75b470ab4cf40d2 Mon Sep 17 00:00:00 2001 From: MiXaiLL76 Date: Fri, 6 Sep 2024 11:40:27 +0300 Subject: [PATCH 3/5] fix --- tools/extract_embedding.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 779faa1..982a841 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -21,11 +21,10 @@ import torch import torchaudio import torchaudio.compliance.kaldi as kaldi from tqdm import tqdm +from itertools import repeat -def extract_embedding(input_list): - utt, wav_file, ort_session = input_list - +def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession): audio, sample_rate = torchaudio.load(wav_file) if sample_rate != 16000: audio = torchaudio.transforms.Resample( @@ -33,19 +32,7 @@ def extract_embedding(input_list): )(audio) feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) - embedding = ( - ort_session.run( - None, - { - ort_session.get_inputs()[0] - .name: feat.unsqueeze(dim=0) - .cpu() - .numpy() - }, - )[0] - .flatten() - .tolist() - ) + embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() return (utt, embedding) @@ 
-72,16 +59,14 @@ def main(args): args.onnx_path, sess_options=option, providers=providers ) - inputs = [ - (utt, utt2wav[utt], ort_session) - for utt in tqdm(utt2wav.keys(), desc="Load data") - ] + all_utt = utt2wav.keys() + with ThreadPoolExecutor(max_workers=args.num_thread) as executor: results = list( tqdm( - executor.map(extract_embedding, inputs), - total=len(inputs), - desc="Process data: ", + executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)), + total=len(utt2wav), + desc="Process data: " ) ) From ff8e63567af8f1620550caddf18e78ebf4a7d671 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Wed, 18 Sep 2024 17:41:15 +0800 Subject: [PATCH 4/5] use thread pool in tools --- tools/extract_embedding.py | 77 ++++++++++++++--------------------- tools/extract_speech_token.py | 54 +++++++++++++----------- 2 files changed, 61 insertions(+), 70 deletions(-) diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 982a841..cb198cb 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -13,74 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -import os -from concurrent.futures import ThreadPoolExecutor - +from concurrent.futures import ThreadPoolExecutor, as_completed import onnxruntime import torch import torchaudio import torchaudio.compliance.kaldi as kaldi from tqdm import tqdm -from itertools import repeat -def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession): - audio, sample_rate = torchaudio.load(wav_file) +def single_job(utt): + audio, sample_rate = torchaudio.load(utt2wav[utt]) if sample_rate != 16000: - audio = torchaudio.transforms.Resample( - orig_freq=sample_rate, new_freq=16000 - )(audio) - feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000) + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) + feat = kaldi.fbank(audio, + num_mel_bins=80, + dither=0, + sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() - return (utt, embedding) + return utt, embedding def main(args): - utt2wav, utt2spk = {}, {} - with open("{}/wav.scp".format(args.dir)) as f: - for l in f: - l = l.replace("\n", "").split() - utt2wav[l[0]] = l[1] - with open("{}/utt2spk".format(args.dir)) as f: - for l in f: - l = l.replace("\n", "").split() - utt2spk[l[0]] = l[1] - - assert os.path.exists(args.onnx_path), "onnx_path not exists" - - option = onnxruntime.SessionOptions() - option.graph_optimization_level = ( - onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - ) - option.intra_op_num_threads = 1 - providers = ["CPUExecutionProvider"] - ort_session = onnxruntime.InferenceSession( - args.onnx_path, sess_options=option, providers=providers - ) - - all_utt = utt2wav.keys() - - with ThreadPoolExecutor(max_workers=args.num_thread) as executor: - results = list( - tqdm( - executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)), - total=len(utt2wav), - desc="Process data: " - ) - ) - + all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] utt2embedding, spk2embedding = {}, {} - for utt, embedding in results: + for future in tqdm(as_completed(all_task)): + utt, embedding = future.result() utt2embedding[utt] = embedding spk = utt2spk[utt] if spk not in spk2embedding: 
spk2embedding[spk] = [] spk2embedding[spk].append(embedding) - for k, v in spk2embedding.items(): spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() - torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) @@ -91,4 +56,22 @@ if __name__ == "__main__": parser.add_argument("--onnx_path", type=str) parser.add_argument("--num_thread", type=int, default=8) args = parser.parse_args() + + utt2wav, utt2spk = {}, {} + with open('{}/wav.scp'.format(args.dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2wav[l[0]] = l[1] + with open('{}/utt2spk'.format(args.dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2spk[l[0]] = l[1] + + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ["CPUExecutionProvider"] + ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) + executor = ThreadPoolExecutor(max_workers=args.num_thread) + main(args) diff --git a/tools/extract_speech_token.py b/tools/extract_speech_token.py index fac0b0b..2829624 100755 --- a/tools/extract_speech_token.py +++ b/tools/extract_speech_token.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed import logging import torch from tqdm import tqdm @@ -22,7 +23,36 @@ import torchaudio import whisper +def single_job(utt): + audio, sample_rate = torchaudio.load(utt2wav[utt]) + if sample_rate != 16000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) + if audio.shape[1] / 16000 > 30: + logging.warning('do not support extract speech token for audio longer than 30s') + speech_token = [] + else: + feat = whisper.log_mel_spectrogram(audio, n_mels=128) + speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), + ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + return utt, speech_token + + def main(args): + all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] + utt2speech_token = {} + for future in tqdm(as_completed(all_task)): + utt, speech_token = future.result() + utt2speech_token[utt] = speech_token + torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dir", type=str) + parser.add_argument("--onnx_path", type=str) + parser.add_argument("--num_thread", type=int, default=8) + args = parser.parse_args() + utt2wav = {} with open('{}/wav.scp'.format(args.dir)) as f: for l in f: @@ -34,28 +64,6 @@ def main(args): option.intra_op_num_threads = 1 providers = ["CUDAExecutionProvider"] ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) + executor = ThreadPoolExecutor(max_workers=args.num_thread) - utt2speech_token = {} - for utt in tqdm(utt2wav.keys()): - audio, sample_rate = torchaudio.load(utt2wav[utt]) - if sample_rate != 16000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) - if audio.shape[1] / 16000 > 30: - logging.warning('do not support extract speech token for audio longer than 30s') - speech_token = [] - else: - feat = whisper.log_mel_spectrogram(audio, n_mels=128) - 
speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), - ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() - utt2speech_token[utt] = speech_token - torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--dir', - type=str) - parser.add_argument('--onnx_path', - type=str) - args = parser.parse_args() main(args) From f6b5c428237ecd0f485cf2cf4e367262d01908f4 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Wed, 18 Sep 2024 17:42:54 +0800 Subject: [PATCH 5/5] fix flake --- tools/extract_embedding.py | 6 +++--- tools/extract_speech_token.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index cb198cb..de48779 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -26,9 +26,9 @@ def single_job(utt): if sample_rate != 16000: audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) feat = kaldi.fbank(audio, - num_mel_bins=80, - dither=0, - sample_frequency=16000) + num_mel_bins=80, + dither=0, + sample_frequency=16000) feat = feat - feat.mean(dim=0, keepdim=True) embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() return utt, embedding diff --git a/tools/extract_speech_token.py b/tools/extract_speech_token.py index 2829624..26aa296 100755 --- a/tools/extract_speech_token.py +++ b/tools/extract_speech_token.py @@ -33,7 +33,7 @@ def single_job(utt): else: feat = whisper.log_mel_spectrogram(audio, n_mels=128) speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), - ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() return utt, speech_token
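
A note on the thread-pool pattern the series converges on. Every version after patch 1 shares a single onnxruntime.InferenceSession across all workers; this works because Run() on one session is thread-safe in onnxruntime, and the Python binding releases the GIL inside Run(), so the threads genuinely parallelize the CPU inference even with intra_op_num_threads=1 (patch 1's per-thread sessions are unnecessary). The remaining difference between patch 3 and patches 4-5 is how results are collected: executor.map() yields results in input order, so the progress bar stalls whenever the next item in submission order happens to be slow, while executor.submit() plus as_completed() yields futures in completion order, so the bar tracks real throughput. Below is a minimal sketch of the two collection styles; single_job here is a hypothetical stub standing in for the real fbank + ONNX inference, and the 192-dim embedding is a placeholder, not the repository's actual model output.

    # Toy comparison of the two result-collection styles used in this series.
    # single_job and the 192-dim embedding are placeholders (assumptions for
    # illustration), not the repository's actual extraction code.
    import random
    import time
    from concurrent.futures import ThreadPoolExecutor, as_completed

    from tqdm import tqdm


    def single_job(utt):
        time.sleep(random.random())  # simulate uneven per-file decode/inference cost
        return utt, [0.0] * 192      # (utt, embedding) placeholder


    utts = ["utt{:03d}".format(i) for i in range(32)]

    # Patch 3 style: executor.map() preserves input order, so one slow early
    # item delays every later result from reaching the loop body, and the
    # progress bar advances in bursts.
    with ThreadPoolExecutor(max_workers=8) as executor:
        ordered = list(tqdm(executor.map(single_job, utts), total=len(utts)))

    # Patch 4/5 style: as_completed() yields futures as they finish, so the
    # bar advances at the true completion rate; a dict keyed by utt restores
    # the utt-to-embedding association afterwards.
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(single_job, utt) for utt in utts]
        utt2embedding = {}
        for future in tqdm(as_completed(futures), total=len(futures)):
            utt, embedding = future.result()
            utt2embedding[utt] = embedding

One trade-off visible in the final patches: single_job reads utt2wav, ort_session, and executor as module-level globals created under the __main__ guard, which keeps the submit() call trivial (only utt is passed) at the cost of making the functions depend on script-level state, whereas patch 3 threaded the session through explicitly via itertools.repeat.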