diff --git a/tools/extract_speech_token.py b/tools/extract_speech_token.py index 776b6cf..976a23b 100755 --- a/tools/extract_speech_token.py +++ b/tools/extract_speech_token.py @@ -27,6 +27,9 @@ def single_job(utt): audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile') if sample_rate != 16000: audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) + # Convert audio to mono + if audio.shape[0] > 1: + audio = audio.mean(dim=0, keepdim=True) if audio.shape[1] / 16000 > 30: logging.warning('do not support extract speech token for audio longer than 30s') speech_token = []