Merge pull request #356 from MiXaiLL76/main

Implemented fast processing of extract_embedding
Author: Xiang Lyu
Date: 2024-09-18 16:16:46 +08:00
Committed by: GitHub
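
The change replaces the old serial per-utterance loop with a ThreadPoolExecutor, so embeddings for --num_thread utterances are extracted concurrently over one shared ONNX InferenceSession. A typical invocation would look roughly like the line below; the script path and model file name are illustrative assumptions, not taken from this diff:

python tools/extract_embedding.py --dir data/train --onnx_path pretrained_models/campplus.onnx --num_thread 8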


@@ -13,58 +13,82 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from itertools import repeat

import onnxruntime
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from tqdm import tqdm


def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
    # Load one utterance, resample to 16 kHz if needed, compute mean-normalized
    # 80-bin fbank features, and run them through the ONNX speaker-embedding model.
    audio, sample_rate = torchaudio.load(wav_file)
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
    feat = feat - feat.mean(dim=0, keepdim=True)
    embedding = ort_session.run(
        None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
    )[0].flatten().tolist()
    return (utt, embedding)


def main(args):
    # Read Kaldi-style wav.scp and utt2spk files into dictionaries.
    utt2wav, utt2spk = {}, {}
    with open("{}/wav.scp".format(args.dir)) as f:
        for l in f:
            l = l.replace("\n", "").split()
            utt2wav[l[0]] = l[1]
    with open("{}/utt2spk".format(args.dir)) as f:
        for l in f:
            l = l.replace("\n", "").split()
            utt2spk[l[0]] = l[1]
    assert os.path.exists(args.onnx_path), "onnx_path does not exist"
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    # Keep ONNX Runtime single-threaded per call; parallelism comes from the thread pool below.
    option.intra_op_num_threads = 1
    providers = ["CPUExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(
        args.onnx_path, sess_options=option, providers=providers
    )
    # One shared InferenceSession is passed to every worker via itertools.repeat;
    # executor.map preserves input order, so results line up with all_utt.
    all_utt = utt2wav.keys()
    with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
        results = list(
            tqdm(
                executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
                total=len(utt2wav),
                desc="Process data: ",
            )
        )
    # Collect per-utterance embeddings and average them per speaker.
    utt2embedding, spk2embedding = {}, {}
    for utt, embedding in results:
        utt2embedding[utt] = embedding
        spk = utt2spk[utt]
        if spk not in spk2embedding:
            spk2embedding[spk] = []
        spk2embedding[spk].append(embedding)
    for k, v in spk2embedding.items():
        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
    torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
    torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str)
    parser.add_argument("--onnx_path", type=str)
    parser.add_argument("--num_thread", type=int, default=8)
    args = parser.parse_args()
    main(args)
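
For reference, a minimal standalone sketch of calling the new extract_embedding helper on a single utterance; the import location, model file, and wav file names are placeholder assumptions, not taken from this PR. A single session can be reused because ONNX Runtime supports concurrent run() calls on one InferenceSession, which is what lets the thread pool above share it.

# Illustrative sketch only; paths and the import location are assumptions.
import onnxruntime
from extract_embedding import extract_embedding  # assumes this script is importable under this name

option = onnxruntime.SessionOptions()
option.intra_op_num_threads = 1  # parallelism, if any, comes from the caller's thread pool
sess = onnxruntime.InferenceSession("campplus.onnx", sess_options=option,
                                    providers=["CPUExecutionProvider"])
utt, embedding = extract_embedding("utt_0001", "example.wav", sess)
print(utt, len(embedding))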