Merge pull request #404 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
This commit is contained in:
Xiang Lyu
2024-09-18 17:43:45 +08:00
committed by GitHub
2 changed files with 70 additions and 55 deletions
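
This merge parallelizes two data-preparation scripts (from the imports they appear to be tools/extract_embedding.py and tools/extract_speech_token.py in the CosyVoice repo; the file names are not shown in this view). Both files get the same treatment: the sequential per-utterance loop is replaced by a ThreadPoolExecutor whose jobs are collected with as_completed, and a new --num_thread flag (default 8) sets the pool size. A minimal, self-contained sketch of that pattern; process_one and items are illustrative stand-ins, not names from this diff:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def process_one(item):
        # stand-in for the per-utterance work done by single_job below
        return item, item * 2

    items = [1, 2, 3]
    executor = ThreadPoolExecutor(max_workers=8)
    all_task = [executor.submit(process_one, item) for item in items]
    results = {}
    for future in as_completed(all_task):
        key, value = future.result()  # results arrive in completion order
        results[key] = value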

View File

@@ -13,14 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import onnxruntime
 import torch
 import torchaudio
-from tqdm import tqdm
-import onnxruntime
 import torchaudio.compliance.kaldi as kaldi
+from tqdm import tqdm
+
+
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    feat = kaldi.fbank(audio,
+                       num_mel_bins=80,
+                       dither=0,
+                       sample_frequency=16000)
+    feat = feat - feat.mean(dim=0, keepdim=True)
+    embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+    return utt, embedding
+
+
 def main(args):
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+    utt2embedding, spk2embedding = {}, {}
+    for future in tqdm(as_completed(all_task)):
+        utt, embedding = future.result()
+        utt2embedding[utt] = embedding
+        spk = utt2spk[utt]
+        if spk not in spk2embedding:
+            spk2embedding[spk] = []
+        spk2embedding[spk].append(embedding)
+    for k, v in spk2embedding.items():
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
+    torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
+    torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
+    args = parser.parse_args()
     utt2wav, utt2spk = {}, {}
     with open('{}/wav.scp'.format(args.dir)) as f:
         for l in f:
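
On the averaging step above: each speaker embedding is the element-wise mean of that speaker's utterance embeddings. A toy check with made-up values:

    import torch

    spk2embedding = {"spk1": [[1.0, 2.0], [3.0, 4.0]]}  # two utterance embeddings
    for k, v in spk2embedding.items():
        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
    print(spk2embedding)  # {'spk1': [2.0, 3.0]}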
@@ -36,35 +72,6 @@ def main(args):
     option.intra_op_num_threads = 1
     providers = ["CPUExecutionProvider"]
     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
-    utt2embedding, spk2embedding = {}, {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        feat = kaldi.fbank(audio,
-                           num_mel_bins=80,
-                           dither=0,
-                           sample_frequency=16000)
-        feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        utt2embedding[utt] = embedding
-        spk = utt2spk[utt]
-        if spk not in spk2embedding:
-            spk2embedding[spk] = []
-        spk2embedding[spk].append(embedding)
-    for k, v in spk2embedding.items():
-        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
-    torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
-    torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
-    args = parser.parse_args()
     main(args)
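
Two details worth noting. single_job reads the module-level utt2wav and ort_session, which is safe because the __main__ block builds both before main(args) submits any jobs. And worker threads genuinely overlap here because onnxruntime releases the Python GIL while a run call executes; with intra_op_num_threads pinned to 1, CPU use scales with --num_thread. A hypothetical sanity check after the script finishes; the directory is an assumed example, the file names are the ones the script writes:

    import torch

    utt2embedding = torch.load("data/train/utt2embedding.pt")  # assumed dir
    spk2embedding = torch.load("data/train/spk2embedding.pt")
    print(len(utt2embedding), "utterances,", len(spk2embedding), "speakers")
    print(len(next(iter(utt2embedding.values()))), "embedding dims")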

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import logging
 import torch
 from tqdm import tqdm
@@ -22,7 +23,36 @@ import torchaudio
 import whisper
+
+
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    if audio.shape[1] / 16000 > 30:
+        logging.warning('do not support extract speech token for audio longer than 30s')
+        speech_token = []
+    else:
+        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
+        speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                              ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+    return utt, speech_token
+
+
 def main(args):
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+    utt2speech_token = {}
+    for future in tqdm(as_completed(all_task)):
+        utt, speech_token = future.result()
+        utt2speech_token[utt] = speech_token
+    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
+    args = parser.parse_args()
     utt2wav = {}
     with open('{}/wav.scp'.format(args.dir)) as f:
         for l in f:
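
The 30-second cap in single_job means any utterance longer than 30 s is stored with an empty token list instead of tokens. A hypothetical pre-filter to spot such utterances before running the extractor; the helper name and paths are illustrative, not part of this change:

    import torchaudio

    def too_long(wav_path, limit_s=30.0):
        # reads duration from file metadata without decoding the audio
        info = torchaudio.info(wav_path)
        return info.num_frames / info.sample_rate > limit_s

    with open("data/train/wav.scp") as f:  # assumed example path
        for line in f:
            utt, wav = line.strip().split()
            if too_long(wav):
                print("will be skipped (>30s):", utt)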
@@ -34,28 +64,6 @@ def main(args):
     option.intra_op_num_threads = 1
     providers = ["CUDAExecutionProvider"]
     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
-    utt2speech_token = {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        if audio.shape[1] / 16000 > 30:
-            logging.warning('do not support extract speech token for audio longer than 30s')
-            speech_token = []
-        else:
-            feat = whisper.log_mel_spectrogram(audio, n_mels=128)
-            speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-                                                  ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
-        utt2speech_token[utt] = speech_token
-    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
-    args = parser.parse_args()
     main(args)
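
As in the first file, results arrive in completion order via as_completed, and the single shared ort_session is reused across all worker threads. A hypothetical downstream check; the directory is an assumed example, the file name is the one the script writes:

    import torch

    utt2speech_token = torch.load("data/train/utt2speech_token.pt")  # assumed dir
    empty = [utt for utt, tok in utt2speech_token.items() if len(tok) == 0]
    print(len(utt2speech_token), "utterances,", len(empty), "skipped as longer than 30s")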