update dpo

This commit is contained in:
lyuxiang.lx
2025-06-13 16:14:05 +08:00
parent cc234bd322
commit 63856565f3
23 changed files with 345 additions and 2024 deletions

View File

@@ -14,14 +14,13 @@
# limitations under the License.
import random
import json
import math
from functools import partial
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists, read_json_lists
from cosyvoice.utils.file_utils import read_lists
class Processor(IterableDataset):
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
data_pipeline,
mode='train',
gan=False,
dpo=False,
shuffle=True,
partition=True,
tts_file='',
prompt_utt2data=''):
partition=True):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
tokenizer (BaseTokenizer): tokenizer to tokenize
partition(bool): whether to do data partition in terms of rank
"""
assert mode in ['train', 'inference']
lists = read_lists(data_list_file)
if mode == 'inference':
with open(tts_file) as f:
tts_data = json.load(f)
utt2lists = read_json_lists(prompt_utt2data)
# filter unnecessary file in inference mode
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
dataset = DataList(lists,
shuffle=shuffle,
partition=partition)
if mode == 'inference':
# map partial arg to parquet_opener func in inference mode
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
if gan is True:
# map partial arg to padding func in gan mode
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
# map partial arg to padding func
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset