Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-04 09:29:25 +08:00
add hifigan train
@@ -350,7 +350,7 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
             logging.fatal('Unsupported batch type {}'.format(batch_type))


-def padding(data, use_spk_embedding, mode='train'):
+def padding(data, use_spk_embedding, mode='train', gan=False):
     """ Padding the data into training data

     Args:
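
Note: the only interface change above is the new gan flag, which defaults to False, so existing callers are unaffected. Below is a minimal sketch of how a HiFi-GAN data pipeline could bind the flag; the functools.partial wiring and the use_spk_embedding value are illustrative assumptions, not part of this commit.

import functools
from cosyvoice.dataset.processor import padding

# Non-GAN stages keep the old behaviour because gan defaults to False.
default_padding = functools.partial(padding, use_spk_embedding=False)
# HiFi-GAN training asks padding() to attach pitch_feat and keep the raw speech.
gan_padding = functools.partial(padding, use_spk_embedding=False, gan=True)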
@@ -379,11 +379,6 @@ def padding(data, use_spk_embedding, mode='train'):
     speech_feat = pad_sequence(speech_feat,
                                batch_first=True,
                                padding_value=0)
-    pitch_feat = [sample[i]['pitch_feat'] for i in order]
-    pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
-    pitch_feat = pad_sequence(pitch_feat,
-                              batch_first=True,
-                              padding_value=0)
     text = [sample[i]['text'] for i in order]
     text_token = [torch.tensor(sample[i]['text_token']) for i in order]
     text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
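
Note: the padding pattern used throughout this function is torch.nn.utils.rnn.pad_sequence with batch_first=True and padding_value=0. A self-contained sketch with made-up shapes (the 80-dim features and frame counts are examples only):

import torch
from torch.nn.utils.rnn import pad_sequence

# Three utterances with different frame counts, 80-dim features each.
feats = [torch.randn(120, 80), torch.randn(95, 80), torch.randn(143, 80)]
feat_len = torch.tensor([f.size(0) for f in feats], dtype=torch.int32)

# Every item is zero-padded to the longest length in the batch.
padded = pad_sequence(feats, batch_first=True, padding_value=0)
print(padded.shape)  # torch.Size([3, 143, 80])
print(feat_len)      # tensor([120,  95, 143], dtype=torch.int32)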
@@ -406,6 +401,19 @@ def padding(data, use_spk_embedding, mode='train'):
         "utt_embedding": utt_embedding,
         "spk_embedding": spk_embedding,
     }
+    if gan is True:
+        # in gan train, we need pitch_feat
+        pitch_feat = [sample[i]['pitch_feat'] for i in order]
+        pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+        pitch_feat = pad_sequence(pitch_feat,
+                                  batch_first=True,
+                                  padding_value=0)
+        batch["pitch_feat"] = pitch_feat
+        batch["pitch_feat_len"] = pitch_feat_len
+    else:
+        # only gan train needs speech, delete it to save memory
+        del batch["speech"]
+        del batch["speech_len"]
     if mode == 'inference':
         tts_text = [sample[i]['tts_text'] for i in order]
         tts_index = [sample[i]['tts_index'] for i in order]
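
Note: the added branch only changes which keys end up in the batch dict. With gan=True the padded pitch_feat and its lengths are attached (and the raw speech is kept); otherwise speech/speech_len are dropped to save memory. A stand-alone illustration of that control flow; assemble() and the dummy tensors are hypothetical, not code from the repository.

import torch
from torch.nn.utils.rnn import pad_sequence

def assemble(batch, pitch_feat, gan=False):
    # Same conditional as in the diff above, applied to a plain dict.
    if gan:
        # gan training needs pitch_feat in the batch
        batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
        batch['pitch_feat_len'] = torch.tensor([p.size(0) for p in pitch_feat], dtype=torch.int32)
    else:
        # only gan training needs raw speech, so drop it to save memory
        batch.pop('speech', None)
        batch.pop('speech_len', None)
    return batch

batch = {'speech': torch.randn(2, 16000), 'speech_len': torch.tensor([16000, 12000], dtype=torch.int32)}
pitch = [torch.randn(100), torch.randn(75)]
print(sorted(assemble(dict(batch), pitch, gan=True).keys()))   # ['pitch_feat', 'pitch_feat_len', 'speech', 'speech_len']
print(sorted(assemble(dict(batch), pitch, gan=False).keys()))  # []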