mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
add hifigan train
This commit is contained in:
@@ -126,6 +126,7 @@ class DataList(IterableDataset):
|
||||
def Dataset(data_list_file,
|
||||
data_pipeline,
|
||||
mode='train',
|
||||
gan=False,
|
||||
shuffle=True,
|
||||
partition=True,
|
||||
tts_file='',
|
||||
@@ -153,8 +154,11 @@ def Dataset(data_list_file,
|
||||
shuffle=shuffle,
|
||||
partition=partition)
|
||||
if mode == 'inference':
|
||||
# map partial arg tts_data in inference mode
|
||||
# map partial arg to parquet_opener func in inference mode
|
||||
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
||||
if gan is True:
|
||||
# map partial arg to padding func in gan mode
|
||||
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
||||
for func in data_pipeline:
|
||||
dataset = Processor(dataset, func, mode=mode)
|
||||
return dataset
|
||||
|
||||
@@ -350,7 +350,7 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
|
||||
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
||||
|
||||
|
||||
def padding(data, use_spk_embedding, mode='train'):
|
||||
def padding(data, use_spk_embedding, mode='train', gan=False):
|
||||
""" Padding the data into training data
|
||||
|
||||
Args:
|
||||
@@ -379,11 +379,6 @@ def padding(data, use_spk_embedding, mode='train'):
|
||||
speech_feat = pad_sequence(speech_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
pitch_feat = pad_sequence(pitch_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
text = [sample[i]['text'] for i in order]
|
||||
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
||||
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||
@@ -406,6 +401,19 @@ def padding(data, use_spk_embedding, mode='train'):
|
||||
"utt_embedding": utt_embedding,
|
||||
"spk_embedding": spk_embedding,
|
||||
}
|
||||
if gan is True:
|
||||
# in gan train, we need pitch_feat
|
||||
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
pitch_feat = pad_sequence(pitch_feat,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
batch["pitch_feat"] = pitch_feat
|
||||
batch["pitch_feat_len"] = pitch_feat_len
|
||||
else:
|
||||
# only gan train needs speech, delete it to save memory
|
||||
del batch["speech"]
|
||||
del batch["speech_len"]
|
||||
if mode == 'inference':
|
||||
tts_text = [sample[i]['tts_text'] for i in order]
|
||||
tts_index = [sample[i]['tts_index'] for i in order]
|
||||
|
||||
Reference in New Issue
Block a user