mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
add hifigan train code
This commit is contained in:
@@ -85,6 +85,7 @@ def filter(data,
|
||||
"""
|
||||
for sample in data:
|
||||
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
||||
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
||||
del sample['audio_data']
|
||||
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
||||
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
||||
@@ -134,6 +135,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
||||
yield sample
|
||||
|
||||
|
||||
def truncate(data, truncate_length=24576, mode='train'):
|
||||
""" Truncate data.
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
truncate_length: truncate length
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, label, sample_rate}]
|
||||
"""
|
||||
for sample in data:
|
||||
waveform = sample['speech']
|
||||
if waveform.shape[1] > truncate_length:
|
||||
start = random.randint(0, waveform.shape[1] - truncate_length)
|
||||
waveform = waveform[:, start: start + truncate_length]
|
||||
else:
|
||||
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
||||
sample['speech'] = waveform
|
||||
yield sample
|
||||
|
||||
|
||||
def compute_fbank(data,
|
||||
feat_extractor,
|
||||
mode='train'):
|
||||
@@ -153,7 +175,26 @@ def compute_fbank(data,
|
||||
waveform = sample['speech']
|
||||
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||
sample['speech_feat'] = mat
|
||||
del sample['speech']
|
||||
yield sample
|
||||
|
||||
def compute_f0(data, pitch_extractor, mode='train'):
|
||||
""" Extract f0
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'speech' in sample
|
||||
assert 'utt' in sample
|
||||
assert 'text_token' in sample
|
||||
waveform = sample['speech']
|
||||
mat = pitch_extractor(waveform).transpose(1, 2)
|
||||
mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
|
||||
sample['pitch_feat'] = mat[0, 0]
|
||||
yield sample
|
||||
|
||||
|
||||
@@ -325,6 +366,9 @@ def padding(data, use_spk_embedding, mode='train'):
|
||||
order = torch.argsort(speech_feat_len, descending=True)
|
||||
|
||||
utts = [sample[i]['utt'] for i in order]
|
||||
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
||||
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
||||
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
||||
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
||||
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
||||
speech_token = pad_sequence(speech_token,
|
||||
@@ -335,6 +379,11 @@ def padding(data, use_spk_embedding, mode='train'):
|
||||
speech_feat = pad_sequence(speech_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
pitch_feat = pad_sequence(pitch_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
text = [sample[i]['text'] for i in order]
|
||||
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
||||
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||
@@ -343,10 +392,14 @@ def padding(data, use_spk_embedding, mode='train'):
|
||||
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
||||
batch = {
|
||||
"utts": utts,
|
||||
"speech": speech,
|
||||
"speech_len": speech_len,
|
||||
"speech_token": speech_token,
|
||||
"speech_token_len": speech_token_len,
|
||||
"speech_feat": speech_feat,
|
||||
"speech_feat_len": speech_feat_len,
|
||||
"pitch_feat": pitch_feat,
|
||||
"pitch_feat_len": pitch_feat_len,
|
||||
"text": text,
|
||||
"text_token": text_token,
|
||||
"text_token_len": text_token_len,
|
||||
|
||||
Reference in New Issue
Block a user