resolve torchaudio 2.9 utils

This commit is contained in:
adamnsandle
2025-10-17 12:35:40 +00:00
parent 33093c6f1b
commit 77c91a91fa
2 changed files with 50 additions and 28 deletions

View File

@@ -28,6 +28,7 @@ classifiers = [
"Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering",
] ]
dependencies = [ dependencies = [
"packaging",
"torch>=1.12.0", "torch>=1.12.0",
"torchaudio>=0.12.0", "torchaudio>=0.12.0",
"onnxruntime>=1.16.1", "onnxruntime>=1.16.1",

View File

@@ -2,6 +2,7 @@ import torch
import torchaudio import torchaudio
from typing import Callable, List from typing import Callable, List
import warnings import warnings
from packaging import version
languages = ['ru', 'en', 'de', 'es'] languages = ['ru', 'en', 'de', 'es']
@@ -134,40 +135,60 @@ class Validator():
return outs return outs
def read_audio(path: str, def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
sampling_rate: int = 16000): ta_ver = version.parse(torchaudio.__version__)
list_backends = torchaudio.list_audio_backends() if ta_ver < version.parse("2.9"):
try:
effects = [['channels', '1'],['rate', str(sampling_rate)]]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
except:
wav, sr = torchaudio.load(path)
else:
try:
wav, sr = torchaudio.load(path)
except:
try:
from torchcodec.decoders import AudioDecoder
samples = AudioDecoder(path).get_all_samples()
wav = samples.data
sr = samples.sample_rate
except ImportError:
raise RuntimeError(
f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
+ "Install torchcodec or pin torchaudio < 2.9"
)
assert len(list_backends) > 0, 'The list of available backends is empty, please install backend manually. \ if wav.ndim > 1 and wav.size(0) > 1:
\n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)' wav = wav.mean(dim=0, keepdim=True)
try: if sr != sampling_rate:
effects = [ wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)
['channels', '1'],
['rate', str(sampling_rate)]
]
wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
except:
wav, sr = torchaudio.load(path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
if sr != sampling_rate:
transform = torchaudio.transforms.Resample(orig_freq=sr,
new_freq=sampling_rate)
wav = transform(wav)
sr = sampling_rate
assert sr == sampling_rate
return wav.squeeze(0) return wav.squeeze(0)
def save_audio(path: str, def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
tensor: torch.Tensor, tensor = tensor.detach().cpu()
sampling_rate: int = 16000): if tensor.ndim == 1:
torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16) tensor = tensor.unsqueeze(0)
ta_ver = version.parse(torchaudio.__version__)
try:
torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
except Exception:
if ta_ver >= version.parse("2.9"):
try:
from torchcodec.encoders import AudioEncoder
encoder = AudioEncoder(tensor, sample_rate=16000)
encoder.to_file(path)
except ImportError:
raise RuntimeError(
f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
+ "Install torchcodec or pin torchaudio < 2.9"
)
else:
raise
def init_jit_model(model_path: str, def init_jit_model(model_path: str,