Mirror of https://github.com/FunAudioLLM/CosyVoice.git

Commit: update dpo

@@ -43,8 +43,6 @@ def parquet_opener(data, mode='train', tts_data={}):
             for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                 df = df.to_pandas()
                 for i in range(len(df)):
                     if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                         continue
                     sample.update(dict(df.loc[i]))
                     if mode == 'train':
                         # NOTE do not return sample directly, must initialize a new dict
@@ -100,6 +98,8 @@ def filter(data,
             continue
         if len(sample['speech_token']) == 0:
             continue
+        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+            continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
                 continue
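
For context: the added guard extends the existing empty-speech_token check to DPO data, dropping any pair whose rejected continuation has no tokens. A minimal, self-contained sketch of the same generator-filter pattern, with made-up sample dicts rather than the repo's real data:

    def drop_empty_dpo_samples(data):
        # Keep only samples whose chosen and (if present) rejected
        # token sequences are non-empty, as in the new filter() guard.
        for sample in data:
            if len(sample['speech_token']) == 0:
                continue
            if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
                continue
            yield sample

    # The second sample is skipped because its rejected sequence is empty.
    samples = [
        {'speech_token': [1, 2, 3], 'reject_speech_token': [4, 5]},
        {'speech_token': [1, 2, 3], 'reject_speech_token': []},
    ]
    print(len(list(drop_empty_dpo_samples(samples))))  # 1
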
@@ -242,8 +242,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     for sample in data:
         assert 'text' in sample
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
         if mode == 'inference':
             sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
         yield sample
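
For context: tokenize(), like the other functions in this file, is a generator stage that consumes an iterable of sample dicts, mutates each sample in flight, and yields it onward, so stages compose lazily into a data pipeline. A hypothetical sketch of that composition with trivial stand-in stages (not the repo's real tokenizer or data):

    def source():
        # Stand-in for parquet_opener(): yields raw sample dicts.
        yield {'text': 'hello world', 'speech_token': [1, 2, 3]}

    def filter_stage(data):
        # Drop samples with no speech tokens, then pass the rest on.
        for sample in data:
            if len(sample['speech_token']) == 0:
                continue
            yield sample

    def tokenize_stage(data):
        # Stand-in tokenizer: one integer id per character.
        for sample in data:
            sample['text_token'] = [ord(c) for c in sample['text']]
            yield sample

    # Chained generators: nothing runs until the pipeline is iterated.
    for sample in tokenize_stage(filter_stage(source())):
        print(sample['text_token'][:3])
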
@@ -351,18 +349,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
 def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
     """ Wrapper for static/dynamic batch
     """
     if mode == 'inference':
         return static_batch(data, 1)
+    if batch_type == 'static':
+        return static_batch(data, batch_size)
+    elif batch_type == 'dynamic':
+        return dynamic_batch(data, max_frames_in_batch)
     else:
-        if batch_type == 'static':
-            return static_batch(data, batch_size)
-        elif batch_type == 'dynamic':
-            return dynamic_batch(data, max_frames_in_batch)
-        else:
-            logging.fatal('Unsupported batch type {}'.format(batch_type))
+        logging.fatal('Unsupported batch type {}'.format(batch_type))


-def padding(data, use_spk_embedding, mode='train', gan=False):
+def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
     """ Padding the data into training data

     Args:
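
For context: static_batch() groups a fixed number of samples per batch, dynamic_batch() packs samples under a frame budget (max_frames_in_batch), and inference always runs with batch size 1. A simplified sketch of the two policies, assuming each sample dict carries a hypothetical num_frames field (the repo derives lengths from the extracted speech features instead):

    def static_batch_sketch(data, batch_size=16):
        # Emit fixed-size groups; the trailing partial group is kept.
        buf = []
        for sample in data:
            buf.append(sample)
            if len(buf) >= batch_size:
                yield buf
                buf = []
        if buf:
            yield buf

    def dynamic_batch_sketch(data, max_frames_in_batch=12000):
        # Flush the buffer whenever adding the next sample would
        # exceed the frame budget.
        buf, frames = [], 0
        for sample in data:
            if buf and frames + sample['num_frames'] > max_frames_in_batch:
                yield buf
                buf, frames = [], 0
            buf.append(sample)
            frames += sample['num_frames']
        if buf:
            yield buf
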
@@ -424,16 +419,14 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
             # only gan train needs speech, delete it to save memory
             del batch["speech"]
             del batch["speech_len"]
-        if mode == 'inference':
-            tts_text = [sample[i]['tts_text'] for i in order]
-            tts_index = [sample[i]['tts_index'] for i in order]
-            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
-            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
-            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
-            batch.update({'tts_text': tts_text,
-                          'tts_index': tts_index,
-                          'tts_text_token': tts_text_token,
-                          'tts_text_token_len': tts_text_token_len})
+        if dpo is True:
+            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
+            reject_speech_token = pad_sequence(reject_speech_token,
+                                               batch_first=True,
+                                               padding_value=0)
+            batch['reject_speech_token'] = reject_speech_token
+            batch['reject_speech_token_len'] = reject_speech_token_len
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else:
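
For context: the added dpo branch collates the variable-length rejected token sequences into a single right-padded tensor plus a length tensor, so the padding can be masked out later. A runnable illustration of torch's pad_sequence with made-up token lists:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Made-up rejected sequences of lengths 3, 2 and 4.
    reject_speech_token = [torch.tensor(t) for t in ([5, 9, 2], [7, 1], [3, 3, 3, 8])]
    reject_speech_token_len = torch.tensor([t.size(0) for t in reject_speech_token],
                                           dtype=torch.int32)

    # batch_first=True gives shape (batch, max_len); shorter rows are
    # right-padded with padding_value=0, matching the diff above.
    padded = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
    print(padded.shape)             # torch.Size([3, 4])
    print(reject_speech_token_len)  # tensor([3, 2, 4], dtype=torch.int32)
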