diff --git a/tools/extract_embedding.py b/tools/extract_embedding.py index 616054c..de48779 100755 --- a/tools/extract_embedding.py +++ b/tools/extract_embedding.py @@ -13,14 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed +import onnxruntime import torch import torchaudio -from tqdm import tqdm -import onnxruntime import torchaudio.compliance.kaldi as kaldi +from tqdm import tqdm + + +def single_job(utt): + audio, sample_rate = torchaudio.load(utt2wav[utt]) + if sample_rate != 16000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) + feat = kaldi.fbank(audio, + num_mel_bins=80, + dither=0, + sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() + return utt, embedding def main(args): + all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] + utt2embedding, spk2embedding = {}, {} + for future in tqdm(as_completed(all_task)): + utt, embedding = future.result() + utt2embedding[utt] = embedding + spk = utt2spk[utt] + if spk not in spk2embedding: + spk2embedding[spk] = [] + spk2embedding[spk].append(embedding) + for k, v in spk2embedding.items(): + spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() + torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) + torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dir", type=str) + parser.add_argument("--onnx_path", type=str) + parser.add_argument("--num_thread", type=int, default=8) + args = parser.parse_args() + utt2wav, utt2spk = {}, {} with open('{}/wav.scp'.format(args.dir)) as f: for l in f: @@ -36,35 +72,6 @@ def main(args): option.intra_op_num_threads = 1 providers = ["CPUExecutionProvider"] ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) + executor = ThreadPoolExecutor(max_workers=args.num_thread) - utt2embedding, spk2embedding = {}, {} - for utt in tqdm(utt2wav.keys()): - audio, sample_rate = torchaudio.load(utt2wav[utt]) - if sample_rate != 16000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) - feat = kaldi.fbank(audio, - num_mel_bins=80, - dither=0, - sample_frequency=16000) - feat = feat - feat.mean(dim=0, keepdim=True) - embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() - utt2embedding[utt] = embedding - spk = utt2spk[utt] - if spk not in spk2embedding: - spk2embedding[spk] = [] - spk2embedding[spk].append(embedding) - for k, v in spk2embedding.items(): - spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() - - torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir)) - torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--dir', - type=str) - parser.add_argument('--onnx_path', - type=str) - args = parser.parse_args() main(args) diff --git a/tools/extract_speech_token.py b/tools/extract_speech_token.py index fac0b0b..26aa296 100755 --- a/tools/extract_speech_token.py +++ b/tools/extract_speech_token.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed import logging import torch from tqdm import tqdm @@ -22,7 +23,36 @@ import torchaudio import whisper +def single_job(utt): + audio, sample_rate = torchaudio.load(utt2wav[utt]) + if sample_rate != 16000: + audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) + if audio.shape[1] / 16000 > 30: + logging.warning('do not support extract speech token for audio longer than 30s') + speech_token = [] + else: + feat = whisper.log_mel_spectrogram(audio, n_mels=128) + speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), + ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + return utt, speech_token + + def main(args): + all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] + utt2speech_token = {} + for future in tqdm(as_completed(all_task)): + utt, speech_token = future.result() + utt2speech_token[utt] = speech_token + torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dir", type=str) + parser.add_argument("--onnx_path", type=str) + parser.add_argument("--num_thread", type=int, default=8) + args = parser.parse_args() + utt2wav = {} with open('{}/wav.scp'.format(args.dir)) as f: for l in f: @@ -34,28 +64,6 @@ def main(args): option.intra_op_num_threads = 1 providers = ["CUDAExecutionProvider"] ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) + executor = ThreadPoolExecutor(max_workers=args.num_thread) - utt2speech_token = {} - for utt in tqdm(utt2wav.keys()): - audio, sample_rate = torchaudio.load(utt2wav[utt]) - if sample_rate != 16000: - audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) - if audio.shape[1] / 16000 > 30: - logging.warning('do not support extract speech token for audio longer than 30s') - speech_token = [] - else: - feat = whisper.log_mel_spectrogram(audio, n_mels=128) - speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), - ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() - utt2speech_token[utt] = speech_token - torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--dir', - type=str) - parser.add_argument('--onnx_path', - type=str) - args = parser.parse_args() main(args)