mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix
This commit is contained in:
@@ -21,11 +21,10 @@ import torch
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from itertools import repeat
|
||||||
|
|
||||||
|
|
||||||
def extract_embedding(input_list):
|
def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
|
||||||
utt, wav_file, ort_session = input_list
|
|
||||||
|
|
||||||
audio, sample_rate = torchaudio.load(wav_file)
|
audio, sample_rate = torchaudio.load(wav_file)
|
||||||
if sample_rate != 16000:
|
if sample_rate != 16000:
|
||||||
audio = torchaudio.transforms.Resample(
|
audio = torchaudio.transforms.Resample(
|
||||||
@@ -33,19 +32,7 @@ def extract_embedding(input_list):
|
|||||||
)(audio)
|
)(audio)
|
||||||
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
|
feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
|
||||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||||
embedding = (
|
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||||
ort_session.run(
|
|
||||||
None,
|
|
||||||
{
|
|
||||||
ort_session.get_inputs()[0]
|
|
||||||
.name: feat.unsqueeze(dim=0)
|
|
||||||
.cpu()
|
|
||||||
.numpy()
|
|
||||||
},
|
|
||||||
)[0]
|
|
||||||
.flatten()
|
|
||||||
.tolist()
|
|
||||||
)
|
|
||||||
return (utt, embedding)
|
return (utt, embedding)
|
||||||
|
|
||||||
|
|
||||||
@@ -72,16 +59,14 @@ def main(args):
|
|||||||
args.onnx_path, sess_options=option, providers=providers
|
args.onnx_path, sess_options=option, providers=providers
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs = [
|
all_utt = utt2wav.keys()
|
||||||
(utt, utt2wav[utt], ort_session)
|
|
||||||
for utt in tqdm(utt2wav.keys(), desc="Load data")
|
|
||||||
]
|
|
||||||
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
|
with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
|
||||||
results = list(
|
results = list(
|
||||||
tqdm(
|
tqdm(
|
||||||
executor.map(extract_embedding, inputs),
|
executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
|
||||||
total=len(inputs),
|
total=len(utt2wav),
|
||||||
desc="Process data: ",
|
desc="Process data: "
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user