diff --git a/pyproject.toml b/pyproject.toml
index fb835f0..19855ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ classifiers = [
     "Topic :: Scientific/Engineering",
 ]
 dependencies = [
+    "packaging",
     "torch>=1.12.0",
     "torchaudio>=0.12.0",
     "onnxruntime>=1.16.1",
diff --git a/src/silero_vad/utils_vad.py b/src/silero_vad/utils_vad.py
index 1666244..91e1b76 100644
--- a/src/silero_vad/utils_vad.py
+++ b/src/silero_vad/utils_vad.py
@@ -2,6 +2,7 @@ import torch
 import torchaudio
 from typing import Callable, List
 import warnings
+from packaging import version
 
 languages = ['ru', 'en', 'de', 'es']
 
@@ -134,40 +135,89 @@ class Validator():
         return outs
 
 
-def read_audio(path: str,
-               sampling_rate: int = 16000):
-    list_backends = torchaudio.list_audio_backends()
+def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
+    """Load an audio file, downmix it to mono and resample it.
+
+    torchaudio < 2.9 tries the sox effects chain first and falls back to
+    ``torchaudio.load``; torchaudio >= 2.9 tries ``torchaudio.load`` first
+    and falls back to decoding with torchcodec directly.
+
+    Args:
+        path: Path of the audio file to read.
+        sampling_rate: Target sampling rate of the returned waveform.
+
+    Returns:
+        The waveform with its leading dimension squeezed out.
+
+    Raises:
+        RuntimeError: torchaudio >= 2.9 could not load the file and the
+            torchcodec fallback raised ``ImportError``.
+    """
+    ta_ver = version.parse(torchaudio.__version__)
+    if ta_ver < version.parse("2.9"):
+        try:
+            effects = [['channels', '1'], ['rate', str(sampling_rate)]]
+            wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
+        except Exception:
+            wav, sr = torchaudio.load(path)
+    else:
+        try:
+            wav, sr = torchaudio.load(path)
+        except Exception:
+            try:
+                from torchcodec.decoders import AudioDecoder
+                samples = AudioDecoder(path).get_all_samples()
+                wav = samples.data
+                sr = samples.sample_rate
+            except ImportError as err:
+                raise RuntimeError(
+                    f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
+                    + "Install torchcodec or pin torchaudio < 2.9"
+                ) from err
 
-    assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
-                                    \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
+    if wav.ndim > 1 and wav.size(0) > 1:
+        wav = wav.mean(dim=0, keepdim=True)
 
-    try:
-        effects = [
-            ['channels', '1'],
-            ['rate', str(sampling_rate)]
-        ]
+    if sr != sampling_rate:
+        wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)
 
-        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
-    except:
-        wav, sr = torchaudio.load(path)
-
-    if wav.size(0) > 1:
-        wav = wav.mean(dim=0, keepdim=True)
-
-    if sr != sampling_rate:
-        transform = torchaudio.transforms.Resample(orig_freq=sr,
-                                                   new_freq=sampling_rate)
-        wav = transform(wav)
-        sr = sampling_rate
-
-    assert sr == sampling_rate
     return wav.squeeze(0)
 
 
-def save_audio(path: str,
-               tensor: torch.Tensor,
-               sampling_rate: int = 16000):
-    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
+def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
+    """Write a waveform to ``path`` as audio.
+
+    Args:
+        path: Destination file path.
+        tensor: Waveform to save; a 1-D tensor gets a leading dimension added.
+        sampling_rate: Sampling rate to write the audio with.
+
+    Raises:
+        RuntimeError: torchaudio >= 2.9 failed to save and the torchcodec
+            fallback raised ``ImportError``.
+    """
+    tensor = tensor.detach().cpu()
+    if tensor.ndim == 1:
+        tensor = tensor.unsqueeze(0)
+
+    ta_ver = version.parse(torchaudio.__version__)
+
+    try:
+        torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
+    except Exception:
+        if ta_ver >= version.parse("2.9"):
+            try:
+                from torchcodec.encoders import AudioEncoder
+                # Encode at the caller's rate, not a fixed 16 kHz.
+                encoder = AudioEncoder(tensor, sample_rate=sampling_rate)
+                encoder.to_file(path)
+            except ImportError as err:
+                raise RuntimeError(
+                    f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
+                    + "Install torchcodec or pin torchaudio < 2.9"
+                ) from err
+        else:
+            raise
 
 
 def init_jit_model(model_path: str,