mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update dpo
This commit is contained in:
@@ -14,14 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import json
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset
|
||||
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
||||
from cosyvoice.utils.file_utils import read_lists
|
||||
|
||||
|
||||
class Processor(IterableDataset):
|
||||
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
|
||||
data_pipeline,
|
||||
mode='train',
|
||||
gan=False,
|
||||
dpo=False,
|
||||
shuffle=True,
|
||||
partition=True,
|
||||
tts_file='',
|
||||
prompt_utt2data=''):
|
||||
partition=True):
|
||||
""" Construct dataset from arguments
|
||||
|
||||
We have two shuffle stage in the Dataset. The first is global
|
||||
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
|
||||
tokenizer (BaseTokenizer): tokenizer to tokenize
|
||||
partition(bool): whether to do data partition in terms of rank
|
||||
"""
|
||||
assert mode in ['train', 'inference']
|
||||
lists = read_lists(data_list_file)
|
||||
if mode == 'inference':
|
||||
with open(tts_file) as f:
|
||||
tts_data = json.load(f)
|
||||
utt2lists = read_json_lists(prompt_utt2data)
|
||||
# filter unnecessary file in inference mode
|
||||
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
||||
dataset = DataList(lists,
|
||||
shuffle=shuffle,
|
||||
partition=partition)
|
||||
if mode == 'inference':
|
||||
# map partial arg to parquet_opener func in inference mode
|
||||
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
||||
if gan is True:
|
||||
# map partial arg to padding func in gan mode
|
||||
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
||||
# map partial arg to padding func
|
||||
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
|
||||
for func in data_pipeline:
|
||||
dataset = Processor(dataset, func, mode=mode)
|
||||
return dataset
|
||||
|
||||
Reference in New Issue
Block a user