This commit is contained in:
MiXaiLL76
2024-09-06 11:40:27 +03:00
parent 73271d46f9
commit 1d05ae5fd3

View File

@@ -21,11 +21,10 @@ import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from tqdm import tqdm
from itertools import repeat
def extract_embedding(input_list):
utt, wav_file, ort_session = input_list
def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
audio, sample_rate = torchaudio.load(wav_file)
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(
@@ -33,19 +32,7 @@ def extract_embedding(input_list):
)(audio)
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = (
ort_session.run(
None,
{
ort_session.get_inputs()[0]
.name: feat.unsqueeze(dim=0)
.cpu()
.numpy()
},
)[0]
.flatten()
.tolist()
)
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
return (utt, embedding)
@@ -72,16 +59,14 @@ def main(args):
args.onnx_path, sess_options=option, providers=providers
)
inputs = [
(utt, utt2wav[utt], ort_session)
for utt in tqdm(utt2wav.keys(), desc="Load data")
]
all_utt = utt2wav.keys()
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
results = list(
tqdm(
executor.map(extract_embedding, inputs),
total=len(inputs),
desc="Process data: ",
executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
total=len(utt2wav),
desc="Process data: "
)
)