mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix
This commit is contained in:
@@ -21,11 +21,10 @@ import torch
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
from tqdm import tqdm
|
||||
from itertools import repeat
|
||||
|
||||
|
||||
def extract_embedding(input_list):
|
||||
utt, wav_file, ort_session = input_list
|
||||
|
||||
def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
|
||||
audio, sample_rate = torchaudio.load(wav_file)
|
||||
if sample_rate != 16000:
|
||||
audio = torchaudio.transforms.Resample(
|
||||
@@ -33,19 +32,7 @@ def extract_embedding(input_list):
|
||||
)(audio)
|
||||
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
embedding = (
|
||||
ort_session.run(
|
||||
None,
|
||||
{
|
||||
ort_session.get_inputs()[0]
|
||||
.name: feat.unsqueeze(dim=0)
|
||||
.cpu()
|
||||
.numpy()
|
||||
},
|
||||
)[0]
|
||||
.flatten()
|
||||
.tolist()
|
||||
)
|
||||
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||
return (utt, embedding)
|
||||
|
||||
|
||||
@@ -72,16 +59,14 @@ def main(args):
|
||||
args.onnx_path, sess_options=option, providers=providers
|
||||
)
|
||||
|
||||
inputs = [
|
||||
(utt, utt2wav[utt], ort_session)
|
||||
for utt in tqdm(utt2wav.keys(), desc="Load data")
|
||||
]
|
||||
all_utt = utt2wav.keys()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
|
||||
results = list(
|
||||
tqdm(
|
||||
executor.map(extract_embedding, inputs),
|
||||
total=len(inputs),
|
||||
desc="Process data: ",
|
||||
executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
|
||||
total=len(utt2wav),
|
||||
desc="Process data: "
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user