mirror of
https://github.com/snakers4/silero-vad.git
synced 2026-02-04 09:29:22 +08:00
Merge pull request #704 from snakers4/adamnsandle
resolve torchaudio 2.9 utils
This commit is contained in:
@@ -28,6 +28,7 @@ classifiers = [
|
||||
"Topic :: Scientific/Engineering",
|
||||
]
|
||||
dependencies = [
|
||||
"packaging",
|
||||
"torch>=1.12.0",
|
||||
"torchaudio>=0.12.0",
|
||||
"onnxruntime>=1.16.1",
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import torchaudio
|
||||
from typing import Callable, List
|
||||
import warnings
|
||||
from packaging import version
|
||||
|
||||
languages = ['ru', 'en', 'de', 'es']
|
||||
|
||||
@@ -134,40 +135,60 @@ class Validator():
|
||||
return outs
|
||||
|
||||
|
||||
def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
    """Load an audio file as a single-channel waveform at ``sampling_rate``.

    The loading strategy depends on the installed torchaudio version:

    * torchaudio < 2.9: read through sox effects (requesting one channel and
      the target rate), falling back to ``torchaudio.load`` if that fails.
    * torchaudio >= 2.9: try ``torchaudio.load`` first, then fall back to
      decoding with ``torchcodec`` directly.

    Whatever path produced the waveform, multi-channel audio is averaged
    down to one channel and resampled if its rate differs from
    ``sampling_rate``.

    Args:
        path: Path of the audio file to read.
        sampling_rate: Sample rate of the returned waveform.

    Returns:
        The waveform with its leading (channel) dimension squeezed out.

    Raises:
        RuntimeError: On torchaudio >= 2.9 when ``torchaudio.load`` fails and
            ``torchcodec`` is not installed. The original load error is
            attached as ``__cause__``.
    """
    ta_ver = version.parse(torchaudio.__version__)
    if ta_ver < version.parse("2.9"):
        try:
            effects = [['channels', '1'], ['rate', str(sampling_rate)]]
            wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
        except Exception:
            # Sox may be unavailable; plain load + the manual downmix/resample
            # below gives the same result.
            wav, sr = torchaudio.load(path)
    else:
        try:
            wav, sr = torchaudio.load(path)
        except Exception as load_err:
            # Only the import is guarded: a decoding failure inside torchcodec
            # should surface as-is rather than be reported as a missing package.
            try:
                from torchcodec.decoders import AudioDecoder
            except ImportError:
                raise RuntimeError(
                    f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
                    "Install torchcodec or pin torchaudio < 2.9"
                ) from load_err
            samples = AudioDecoder(path).get_all_samples()
            wav = samples.data
            sr = samples.sample_rate

    # Downmix to mono by averaging channels.
    if wav.ndim > 1 and wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != sampling_rate:
        wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)

    return wav.squeeze(0)
|
||||
|
||||
|
||||
def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
    """Write a waveform tensor to an audio file.

    The tensor is detached and moved to CPU, and a 1-D tensor is given a
    leading channel dimension. ``torchaudio.save`` is tried first (16 bits
    per sample); on torchaudio >= 2.9 a failure falls back to encoding with
    ``torchcodec``.

    Args:
        path: Destination file path.
        tensor: Waveform to save; 1-D input is treated as a single channel.
        sampling_rate: Sample rate to write the audio with.

    Raises:
        RuntimeError: On torchaudio >= 2.9 when ``torchaudio.save`` fails and
            ``torchcodec`` is not installed. The original save error is
            attached as ``__cause__``.
        Exception: On torchaudio < 2.9 the ``torchaudio.save`` error is
            re-raised unchanged.
    """
    tensor = tensor.detach().cpu()
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)

    ta_ver = version.parse(torchaudio.__version__)

    try:
        torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
    except Exception as save_err:
        if ta_ver < version.parse("2.9"):
            # No torchcodec fallback for older torchaudio.
            raise
        try:
            from torchcodec.encoders import AudioEncoder
        except ImportError:
            raise RuntimeError(
                f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
                "Install torchcodec or pin torchaudio < 2.9"
            ) from save_err
        # Pass the caller's rate; it was previously hard-coded to 16000.
        encoder = AudioEncoder(tensor, sample_rate=sampling_rate)
        encoder.to_file(path)
|
||||
|
||||
|
||||
def init_jit_model(model_path: str,
|
||||
|
||||
Reference in New Issue
Block a user