add threading

MiXaiLL76
2024-09-05 14:32:37 +03:00
parent bcda6d807c
commit 7b3e285bca


@@ -18,53 +18,117 @@ import torchaudio
 from tqdm import tqdm
 import onnxruntime
 import torchaudio.compliance.kaldi as kaldi
+from queue import Queue, Empty
+from threading import Thread
+
+
+class ExtractEmbedding:
+    def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
+        self.model_path = model_path
+        self.queue = queue
+        self.out_queue = out_queue
+        self.is_run = True
+
+    def run(self):
+        self.consumer_thread = Thread(target=self.consumer)
+        self.consumer_thread.start()
+
+    def stop(self):
+        self.is_run = False
+        self.consumer_thread.join()
+
+    def consumer(self):
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = (
+            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        )
+        option.intra_op_num_threads = 1
+        providers = ["CPUExecutionProvider"]
+        ort_session = onnxruntime.InferenceSession(
+            self.model_path, sess_options=option, providers=providers
+        )
+        while self.is_run:
+            try:
+                utt, wav_file = self.queue.get(timeout=1)
+                audio, sample_rate = torchaudio.load(wav_file)
+                if sample_rate != 16000:
+                    audio = torchaudio.transforms.Resample(
+                        orig_freq=sample_rate, new_freq=16000
+                    )(audio)
+                feat = kaldi.fbank(
+                    audio, num_mel_bins=80, dither=0, sample_frequency=16000
+                )
+                feat = feat - feat.mean(dim=0, keepdim=True)
+                embedding = (
+                    ort_session.run(
+                        None,
+                        {
+                            ort_session.get_inputs()[0]
+                            .name: feat.unsqueeze(dim=0)
+                            .cpu()
+                            .numpy()
+                        },
+                    )[0]
+                    .flatten()
+                    .tolist()
+                )
+                self.out_queue.put((utt, embedding))
+            except Empty:
+                self.is_run = False
+                break
+
+
 def main(args):
     utt2wav, utt2spk = {}, {}
-    with open('{}/wav.scp'.format(args.dir)) as f:
+    with open("{}/wav.scp".format(args.dir)) as f:
         for l in f:
-            l = l.replace('\n', '').split()
+            l = l.replace("\n", "").split()
             utt2wav[l[0]] = l[1]
-    with open('{}/utt2spk'.format(args.dir)) as f:
+    with open("{}/utt2spk".format(args.dir)) as f:
         for l in f:
-            l = l.replace('\n', '').split()
+            l = l.replace("\n", "").split()
             utt2spk[l[0]] = l[1]
-    option = onnxruntime.SessionOptions()
-    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    option.intra_op_num_threads = 1
-    providers = ["CPUExecutionProvider"]
-    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+
+    input_queue = Queue()
+    output_queue = Queue()
+    consumers = [
+        ExtractEmbedding(args.onnx_path, input_queue, output_queue)
+        for _ in range(args.num_thread)
+    ]
+
     utt2embedding, spk2embedding = {}, {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        feat = kaldi.fbank(audio,
-                           num_mel_bins=80,
-                           dither=0,
-                           sample_frequency=16000)
-        feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        utt2embedding[utt] = embedding
-        spk = utt2spk[utt]
-        if spk not in spk2embedding:
-            spk2embedding[spk] = []
-        spk2embedding[spk].append(embedding)
+    for utt in tqdm(utt2wav.keys(), desc="Load data"):
+        input_queue.put((utt, utt2wav[utt]))
+
+    for c in consumers:
+        c.run()
+
+    with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
+        while any([c.is_run for c in consumers]):
+            try:
+                utt, embedding = output_queue.get(timeout=1)
+                utt2embedding[utt] = embedding
+                spk = utt2spk[utt]
+                if spk not in spk2embedding:
+                    spk2embedding[spk] = []
+                spk2embedding[spk].append(embedding)
+                pbar.update(1)
+            except Empty:
+                continue
+
     for k, v in spk2embedding.items():
         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
-    torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
-    torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
+    torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
+    torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
     args = parser.parse_args()
     main(args)
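
Note on the pattern introduced here: each ExtractEmbedding worker builds its own onnxruntime.InferenceSession with intra_op_num_threads=1, so the speed-up comes from running args.num_thread Python worker threads in parallel rather than from ONNX Runtime's internal thread pool, and each worker shuts itself down once the input queue has stayed empty for one second. The class also defines a stop() method that main() never calls; the fragment below is a minimal sketch, not part of this commit, of how the workers could be joined explicitly, assuming it is placed right after the `with tqdm(...)` block in main():

    # Hypothetical follow-up, not in this commit: explicitly join the worker
    # threads once all results have been drained from output_queue.
    for c in consumers:
        c.stop()  # sets c.is_run = False and joins c.consumer_thread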