def read_audio(path: str,
               sampling_rate: int = 16000):
    """Load an audio file as a single-channel waveform at ``sampling_rate``.

    When torchaudio reports a SoX backend, down-mixing to one channel and
    resampling are delegated to SoX effects applied while reading the file.
    Otherwise the file is read with ``torchaudio.load`` and converted by
    hand: multi-channel audio is averaged across channels and the signal is
    resampled with ``torchaudio.transforms.Resample`` if its rate differs
    from the requested one.

    Args:
        path: Location of the audio file to read.
        sampling_rate: Target sampling rate, passed to the SoX ``rate``
            effect or used as ``new_freq`` for the resampler.

    Returns:
        The loaded waveform with its leading (channel) dimension squeezed
        out via ``squeeze(0)``.

    Raises:
        AssertionError: If the final sampling rate does not equal
            ``sampling_rate``.
    """
    # The public torchaudio query is ``list_audio_backends``. Older releases
    # report the SoX backend as 'sox_io', newer ones as 'sox'; accept both.
    available_backends = set(torchaudio.list_audio_backends())

    if available_backends & {'sox', 'sox_io'}:
        # Let SoX down-mix to mono and resample in one pass.
        effects = [
            ['channels', '1'],
            ['rate', str(sampling_rate)],
        ]

        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
    else:
        wav, sr = torchaudio.load(path)

        # Average all channels into one while keeping the channel dimension.
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    # Sanity check: both paths must end at the requested rate.
    assert sr == sampling_rate
    return wav.squeeze(0)