use thread pool in tools

lyuxiang.lx
2024-09-18 17:41:15 +08:00
parent 2665b06e95
commit ff8e63567a
2 changed files with 61 additions and 70 deletions

View File

@@ -13,74 +13,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-import os
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed

 import onnxruntime
 import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
 from tqdm import tqdm
-from itertools import repeat


-def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
-    audio, sample_rate = torchaudio.load(wav_file)
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
     if sample_rate != 16000:
-        audio = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )(audio)
-    feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    feat = kaldi.fbank(audio,
+                       num_mel_bins=80,
+                       dither=0,
+                       sample_frequency=16000)
     feat = feat - feat.mean(dim=0, keepdim=True)
     embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-    return (utt, embedding)
+    return utt, embedding


 def main(args):
-    utt2wav, utt2spk = {}, {}
-    with open("{}/wav.scp".format(args.dir)) as f:
-        for l in f:
-            l = l.replace("\n", "").split()
-            utt2wav[l[0]] = l[1]
-    with open("{}/utt2spk".format(args.dir)) as f:
-        for l in f:
-            l = l.replace("\n", "").split()
-            utt2spk[l[0]] = l[1]
-    assert os.path.exists(args.onnx_path), "onnx_path not exists"
-    option = onnxruntime.SessionOptions()
-    option.graph_optimization_level = (
-        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    )
-    option.intra_op_num_threads = 1
-    providers = ["CPUExecutionProvider"]
-    ort_session = onnxruntime.InferenceSession(
-        args.onnx_path, sess_options=option, providers=providers
-    )
-    all_utt = utt2wav.keys()
-    with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
-        results = list(
-            tqdm(
-                executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
-                total=len(utt2wav),
-                desc="Process data: "
-            )
-        )
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
     utt2embedding, spk2embedding = {}, {}
-    for utt, embedding in results:
+    for future in tqdm(as_completed(all_task)):
+        utt, embedding = future.result()
         utt2embedding[utt] = embedding
         spk = utt2spk[utt]
         if spk not in spk2embedding:
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
     for k, v in spk2embedding.items():
         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
     torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
     torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))

@@ -91,4 +56,22 @@ if __name__ == "__main__":
     parser.add_argument("--onnx_path", type=str)
     parser.add_argument("--num_thread", type=int, default=8)
     args = parser.parse_args()
+
+    utt2wav, utt2spk = {}, {}
+    with open('{}/wav.scp'.format(args.dir)) as f:
+        for l in f:
+            l = l.replace('\n', '').split()
+            utt2wav[l[0]] = l[1]
+    with open('{}/utt2spk'.format(args.dir)) as f:
+        for l in f:
+            l = l.replace('\n', '').split()
+            utt2spk[l[0]] = l[1]
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ["CPUExecutionProvider"]
+    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
+
     main(args)
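
Note on the refactor above: executor.map over (utt, wav_file, ort_session) tuples is replaced by executor.submit plus as_completed, so results are consumed in completion order and the progress bar advances as soon as any utterance finishes. single_job now reads utt2wav and ort_session as module globals, which is why the scp parsing and session setup move from main into the __main__ block, where they run before any job is submitted. A minimal self-contained sketch of the pattern (slow_job and the integer inputs are toy stand-ins, not from this commit):

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm


def slow_job(key):
    # stand-in for per-utterance work such as single_job(utt)
    return key, key * key


executor = ThreadPoolExecutor(max_workers=8)
all_task = [executor.submit(slow_job, k) for k in range(100)]
results = {}
# as_completed yields each future as it finishes, so the bar tracks real
# completions rather than submission order
for future in tqdm(as_completed(all_task), total=len(all_task)):
    key, value = future.result()
    results[key] = value

One caveat: as_completed returns an iterator with no len, so tqdm(as_completed(all_task)) as written in the diff shows a running count but no percentage; passing total=len(all_task), as in the sketch, restores the full bar.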

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import logging
 import torch
 from tqdm import tqdm

@@ -22,7 +23,36 @@ import torchaudio
 import whisper


+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    if audio.shape[1] / 16000 > 30:
+        logging.warning('do not support extract speech token for audio longer than 30s')
+        speech_token = []
+    else:
+        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
+        speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                              ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+    return utt, speech_token
+
+
 def main(args):
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+    utt2speech_token = {}
+    for future in tqdm(as_completed(all_task)):
+        utt, speech_token = future.result()
+        utt2speech_token[utt] = speech_token
+    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
+    args = parser.parse_args()
+
     utt2wav = {}
     with open('{}/wav.scp'.format(args.dir)) as f:
         for l in f:

@@ -34,28 +64,6 @@ def main(args):
     option.intra_op_num_threads = 1
     providers = ["CUDAExecutionProvider"]
     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
-    utt2speech_token = {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        if audio.shape[1] / 16000 > 30:
-            logging.warning('do not support extract speech token for audio longer than 30s')
-            speech_token = []
-        else:
-            feat = whisper.log_mel_spectrogram(audio, n_mels=128)
-            speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-                                                  ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
-        utt2speech_token[utt] = speech_token
-    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
-    args = parser.parse_args()
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)

     main(args)
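
A note on why plain threads pay off here despite the GIL: ONNX Runtime documents that concurrent Run() calls on a single InferenceSession are thread-safe, and the Python binding releases the GIL while run() executes, so one shared session plus a ThreadPoolExecutor lets audio loading, mel extraction, and inference overlap across workers. A rough sketch of that shared-session pattern (model.onnx and the random input shape are placeholders, not this repo's tokenizer):

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import onnxruntime


# one session shared by every worker thread, as in the scripts above
ort_session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name


def infer(x):
    # concurrent run() calls on a shared session are safe; the compute
    # happens with the GIL released, so worker threads genuinely overlap
    return ort_session.run(None, {input_name: x})[0]


with ThreadPoolExecutor(max_workers=8) as executor:
    batches = [np.random.rand(1, 80).astype(np.float32) for _ in range(16)]
    outputs = list(executor.map(infer, batches))

With the CUDAExecutionProvider used in this script, the gain likely comes mostly from overlapping torchaudio loading and log-mel extraction with GPU inference, since runs on a single device largely serialize on its compute queue.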