mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 17:39:22 +08:00
resolve torchaudio 2.9 utils
This commit is contained in:
@@ -28,6 +28,7 @@ classifiers = [
|
|||||||
"Topic :: Scientific/Engineering",
|
"Topic :: Scientific/Engineering",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"packaging",
|
||||||
"torch>=1.12.0",
|
"torch>=1.12.0",
|
||||||
"torchaudio>=0.12.0",
|
"torchaudio>=0.12.0",
|
||||||
"onnxruntime>=1.16.1",
|
"onnxruntime>=1.16.1",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import torch
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
import warnings
|
import warnings
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
languages = ['ru', 'en', 'de', 'es']
|
languages = ['ru', 'en', 'de', 'es']
|
||||||
|
|
||||||
@@ -134,40 +135,60 @@ class Validator():
|
|||||||
return outs
|
return outs
|
||||||
|
|
||||||
|
|
||||||
def read_audio(path: str,
|
def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
|
||||||
sampling_rate: int = 16000):
|
ta_ver = version.parse(torchaudio.__version__)
|
||||||
list_backends = torchaudio.list_audio_backends()
|
if ta_ver < version.parse("2.9"):
|
||||||
|
try:
|
||||||
|
effects = [['channels', '1'],['rate', str(sampling_rate)]]
|
||||||
|
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
|
||||||
|
except:
|
||||||
|
wav, sr = torchaudio.load(path)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
wav, sr = torchaudio.load(path)
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
from torchcodec.decoders import AudioDecoder
|
||||||
|
samples = AudioDecoder(path).get_all_samples()
|
||||||
|
wav = samples.data
|
||||||
|
sr = samples.sample_rate
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
|
||||||
|
+ "Install torchcodec or pin torchaudio < 2.9"
|
||||||
|
)
|
||||||
|
|
||||||
assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \
|
if wav.ndim > 1 and wav.size(0) > 1:
|
||||||
\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'
|
wav = wav.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
try:
|
if sr != sampling_rate:
|
||||||
effects = [
|
wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)
|
||||||
['channels', '1'],
|
|
||||||
['rate', str(sampling_rate)]
|
|
||||||
]
|
|
||||||
|
|
||||||
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
|
|
||||||
except:
|
|
||||||
wav, sr = torchaudio.load(path)
|
|
||||||
|
|
||||||
if wav.size(0) > 1:
|
|
||||||
wav = wav.mean(dim=0, keepdim=True)
|
|
||||||
|
|
||||||
if sr != sampling_rate:
|
|
||||||
transform = torchaudio.transforms.Resample(orig_freq=sr,
|
|
||||||
new_freq=sampling_rate)
|
|
||||||
wav = transform(wav)
|
|
||||||
sr = sampling_rate
|
|
||||||
|
|
||||||
assert sr == sampling_rate
|
|
||||||
return wav.squeeze(0)
|
return wav.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
def save_audio(path: str,
|
def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
|
||||||
tensor: torch.Tensor,
|
tensor = tensor.detach().cpu()
|
||||||
sampling_rate: int = 16000):
|
if tensor.ndim == 1:
|
||||||
torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
|
tensor = tensor.unsqueeze(0)
|
||||||
|
|
||||||
|
ta_ver = version.parse(torchaudio.__version__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
|
||||||
|
except Exception:
|
||||||
|
if ta_ver >= version.parse("2.9"):
|
||||||
|
try:
|
||||||
|
from torchcodec.encoders import AudioEncoder
|
||||||
|
encoder = AudioEncoder(tensor, sample_rate=16000)
|
||||||
|
encoder.to_file(path)
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
|
||||||
|
+ "Install torchcodec or pin torchaudio < 2.9"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def init_jit_model(model_path: str,
|
def init_jit_model(model_path: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user