add hifigan train code

This commit is contained in:
lyuxiang.lx
2024-10-09 17:36:42 +08:00
parent 67f298d94a
commit cb200b21c5
10 changed files with 768 additions and 40 deletions

View File

@@ -85,6 +85,7 @@ def filter(data,
"""
for sample in data:
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
del sample['audio_data']
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
@@ -134,6 +135,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
waveform = sample['speech']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
sample['speech'] = waveform
yield sample
def compute_fbank(data,
feat_extractor,
mode='train'):
@@ -153,7 +175,26 @@ def compute_fbank(data,
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def compute_f0(data, pitch_extractor, mode='train'):
""" Extract f0
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = pitch_extractor(waveform).transpose(1, 2)
mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
sample['pitch_feat'] = mat[0, 0]
yield sample
@@ -325,6 +366,9 @@ def padding(data, use_spk_embedding, mode='train'):
order = torch.argsort(speech_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
speech = pad_sequence(speech, batch_first=True, padding_value=0)
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
@@ -335,6 +379,11 @@ def padding(data, use_spk_embedding, mode='train'):
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
pitch_feat = [sample[i]['pitch_feat'] for i in order]
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
pitch_feat = pad_sequence(pitch_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
@@ -343,10 +392,14 @@ def padding(data, use_spk_embedding, mode='train'):
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"utts": utts,
"speech": speech,
"speech_len": speech_len,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"pitch_feat": pitch_feat,
"pitch_feat_len": pitch_feat_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,