fix pitch computation

This commit is contained in:
lyuxiang.lx
2025-01-23 15:44:03 +08:00
parent 49761d2474
commit 190840b8dc
5 changed files with 19 additions and 14 deletions

View File

@@ -20,6 +20,7 @@ import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
@@ -178,7 +179,7 @@ def compute_fbank(data,
yield sample
def compute_f0(data, pitch_extractor, mode='train'):
def compute_f0(data, sample_rate, hop_size, mode='train'):
""" Extract f0
Args:
@@ -187,15 +188,19 @@ def compute_f0(data, pitch_extractor, mode='train'):
Returns:
Iterable[{key, feat, label}]
"""
frame_period = hop_size * 1000 / sample_rate
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = pitch_extractor(waveform).transpose(1, 2)
mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
sample['pitch_feat'] = mat[0, 0]
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
sample['pitch_feat'] = f0
yield sample