diff --git a/cosyvoice/bin/train_dpo.py b/cosyvoice/bin/train_dpo.py new file mode 100644 index 0000000..b5b282f --- /dev/null +++ b/cosyvoice/bin/train_dpo.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import os +import torch +import torch.distributed as dist +import deepspeed + +from hyperpyyaml import load_hyperpyyaml + +from torch.distributed.elastic.multiprocessing.errors import record + +from cosyvoice.utils.executor_dpo import Executor +from cosyvoice.utils.train_utils_dpo import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--use_amp', + action='store_true', + default=False, + help='Use automatic mixed precision training') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=60, + type=int, + help='timeout (in seconds) of cosyvoice_join.') + parser.add_argument('--dpo', + action='store_true', + default=False, + help='Use Direct Preference Optimization') + parser.add_argument('--beta', + default=0.01, + type=float, + help='beta of dpo training') + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + # gan train has some special 
initialization logic + gan = True if args.model == 'hifigan' else False + + override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} + if gan is True: + override_dict.pop('hift') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) + if gan is True: + configs['train_conf'] = configs['train_conf_gan'] + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs, gan) + + # Do some sanity checks and save config to args.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + ref_model = None + if args.dpo: + ref_model = deepcopy(model) + start_step, start_epoch = 0, -1 + if args.checkpoint is not None: + if os.path.exists(args.checkpoint): + state_dict = torch.load(args.checkpoint, map_location='cpu') + model.load_state_dict(state_dict, strict=False) + if args.dpo: + ref_model.load_state_dict(state_dict, strict=False) + if 'step' in state_dict: + start_step = state_dict['step'] + if 'epoch' in state_dict: + start_epoch = state_dict['epoch'] + else: + logging.warning('checkpoint {} does not exist!'.format(args.checkpoint)) + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + if args.dpo: + ref_model = wrap_cuda_model(args, ref_model) + + # Get optimizer & scheduler + model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan) + if args.dpo: + ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan) + scheduler.set_step(start_step) + if scheduler_d is not None: + scheduler_d.set_step(start_step) + + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + info_dict['step'] = start_step + info_dict['epoch'] = start_epoch + save_model(model, 'init', info_dict) + + # Get executor + executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta) + executor.step = start_step + + # Init scaler, used for pytorch amp mixed precision training + scaler = torch.cuda.amp.GradScaler() if args.use_amp else None + print('start step {} start epoch {}'.format(start_step, start_epoch)) + # Start training loop + for epoch in range(start_epoch + 1, info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + if gan is True: + executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, + writer, info_dict, scaler, group_join) + else: + executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model) + dist.destroy_process_group(group_join) + + +if __name__ == '__main__': + main() diff --git a/cosyvoice/dataset/processor_dpo.py b/cosyvoice/dataset/processor_dpo.py new file mode 100644 index 0000000..719b474 --- /dev/null +++ b/cosyvoice/dataset/processor_dpo.py @@ -0,0 +1,443 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random + +import pyarrow.parquet as pq +from io import BytesIO +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import pyworld as pw + + +AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} + + +def parquet_opener(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + url = sample['src'] + try: + for df in pq.ParquetFile(url).iter_batches(batch_size=64): + df = df.to_pandas() + for i in range(len(df)): + if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: + continue + sample.update(dict(df.loc[i])) + if mode == 'train': + # NOTE do not return sample directly, must initialize a new dict + yield {**sample} + else: + for index, text in enumerate(tts_data[df.loc[i, 'utt']]): + yield {**sample, 'tts_index': index, 'tts_text': text} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(url, ex)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + del sample['audio_data'] + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['text_token']) < token_min_length: + continue + if len(sample['text_token']) > token_max_length: + continue + if len(sample['speech_token']) == 0: + continue + if num_frames != 0: + if len(sample['text_token']) / num_frames < min_output_input_ratio: + continue + if len(sample['text_token']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): + """ Resample data. + Inplace operation. 
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + sample_rate = sample['sample_rate'] + waveform = sample['speech'] + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['speech'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + max_val = sample['speech'].abs().max() + if max_val > 1: + sample['speech'] /= max_val + yield sample + + +def truncate(data, truncate_length=24576, mode='train'): + """ Truncate data. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + truncate_length: truncate length + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + waveform = sample['speech'] + if waveform.shape[1] > truncate_length: + start = random.randint(0, waveform.shape[1] - truncate_length) + waveform = waveform[:, start: start + truncate_length] + else: + waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1) + sample['speech'] = waveform + yield sample + + +def compute_fbank(data, + feat_extractor, + mode='train'): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + sample['speech_feat'] = mat + yield sample + + +def compute_f0(data, sample_rate, hop_size, mode='train'): + """ Extract f0 + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + frame_period = hop_size * 1000 / sample_rate + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) + if sum(_f0 != 0) < 5: # this happens when the algorithm fails + _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio + f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate) + f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1) + sample['pitch_feat'] = f0 + yield sample + + +def parse_embedding(data, normalize, mode='train'): + """ Parse utt_embedding/spk_embedding + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) + sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) + if normalize: + sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) + sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) + yield sample + + +def tokenize(data, get_tokenizer, allowed_special, mode='train'): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + tokenizer = 
get_tokenizer() + for sample in data: + assert 'text' in sample + sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) + if mode == 'inference': + sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special) + yield sample + + +def shuffle(data, shuffle_size=10000, mode='train'): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500, mode='train'): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'speech_feat' in sample + assert isinstance(sample['speech_feat'], torch.Tensor) + new_sample_frames = sample['speech_feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'): + """ Wrapper for static/dynamic batch + """ + if mode == 'inference': + return static_batch(data, 1) + else: + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], + dtype=torch.int32) + order = torch.argsort(speech_feat_len, descending=True) + + utts = [sample[i]['utt'] for i in order] + speech = [sample[i]['speech'].squeeze(dim=0) for i in order] + 
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32) + speech = pad_sequence(speech, batch_first=True, padding_value=0) + speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] + speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) + speech_token = pad_sequence(speech_token, + batch_first=True, + padding_value=0) + speech_feat = [sample[i]['speech_feat'] for i in order] + speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) + speech_feat = pad_sequence(speech_feat, + batch_first=True, + padding_value=0) + text = [sample[i]['text'] for i in order] + text_token = [torch.tensor(sample[i]['text_token']) for i in order] + text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, padding_value=0) + utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) + spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) + batch = { + "utts": utts, + "speech": speech, + "speech_len": speech_len, + "speech_token": speech_token, + "speech_token_len": speech_token_len, + "speech_feat": speech_feat, + "speech_feat_len": speech_feat_len, + "text": text, + "text_token": text_token, + "text_token_len": text_token_len, + "utt_embedding": utt_embedding, + "spk_embedding": spk_embedding, + } + if dpo: + reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order] + reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32) + reject_speech_token = pad_sequence(reject_speech_token, + batch_first=True, + padding_value=0) + batch['reject_speech_token'] = reject_speech_token + batch['reject_speech_token_len'] = reject_speech_token_len + if gan is True: + # in gan train, we need pitch_feat + pitch_feat = [sample[i]['pitch_feat'] for i in order] + pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32) + pitch_feat = pad_sequence(pitch_feat, + batch_first=True, + padding_value=0) + batch["pitch_feat"] = pitch_feat + batch["pitch_feat_len"] = pitch_feat_len + else: + # only gan train needs speech, delete it to save memory + del batch["speech"] + del batch["speech_len"] + if mode == 'inference': + tts_text = [sample[i]['tts_text'] for i in order] + tts_index = [sample[i]['tts_index'] for i in order] + tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] + tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) + tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) + batch.update({'tts_text': tts_text, + 'tts_index': tts_index, + 'tts_text_token': tts_text_token, + 'tts_text_token_len': tts_text_token_len}) + if use_spk_embedding is True: + batch["embedding"] = batch["spk_embedding"] + else: + batch["embedding"] = batch["utt_embedding"] + yield batch diff --git a/cosyvoice/llm/llm_dpo.py b/cosyvoice/llm/llm_dpo.py new file mode 100644 index 0000000..6e0dc2d --- /dev/null +++ b/cosyvoice/llm/llm_dpo.py @@ -0,0 +1,556 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Callable, List, Generator +import torch +from torch import nn +import torch.nn.functional as F +from transformers import Qwen2ForCausalLM +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from cosyvoice.utils.common import IGNORE_ID +from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss +from cosyvoice.utils.common import th_accuracy +from cosyvoice.utils.file_utils import logging +from cosyvoice.utils.mask import make_pad_mask + + +class TransformerLM(torch.nn.Module): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + text_token_size: int, + speech_token_size: int, + text_encoder: torch.nn.Module, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + spk_embed_dim: int = 192, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.speech_token_size = speech_token_size + # 1. build text token inputs related modules + self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + self.text_encoder = text_encoder + self.text_encoder_affine_layer = nn.Linear( + self.text_encoder.output_size(), + llm_input_size + ) + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size) + + # 4. 
sampling method + self.sampling = sampling + + def encode( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + ): + encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + return encoder_out, encoder_out_lens + + def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len): + text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) + speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) + lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) + for i in range(len(text_token))] + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + return lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + embedding = batch['embedding'].to(device) + + # 1. prepare llm_target + lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 1. encode text_token + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + + # 2. embedding projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(1) + + # 3. eos and task_id + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 4. encode speech_token + speech_token = self.speech_embedding(speech_token) + + # 5. unpad and pad + lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, + task_id_emb, speech_token, speech_token_len) + + # 6. 
run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID) + return {'loss': loss, 'acc': acc} + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + sampling: int, + ignore_eos: bool = True, + ): + num_trials, max_trials = 0, 100 + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + if (not ignore_eos) or (self.speech_token_size not in top_ids): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + if self.fp16 is True: + embedding = embedding.half() + + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.text_embedding(text) + + # 1. encode text + text, text_len = self.encode(text, text_len) + + # 2. encode embedding + if embedding.shape[0] != 0: + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(dim=1) + else: + embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. 
step by step decode + out_tokens = [] + offset = 0 + att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device) + for i in range(max_len): + y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, + att_cache=att_cache, cnn_cache=cnn_cache, + att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), + device=lm_input.device)).to(torch.bool)) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + # force continue decode first token + if i == 0: + logp[:, self.speech_token_size] = -float('inf') + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + +class Qwen2Encoder(torch.nn.Module): + def __init__(self, pretrain_path): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + + def forward_one_step(self, xs, masks, cache=None): + input_masks = masks[:, -1, :] + outs = self.model( + inputs_embeds=xs, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + use_cache=True, + past_key_values=cache, + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache + + +class Qwen2LM(TransformerLM): + def __init__( + self, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + mix_ratio: List[int] = [5, 15], + dpo: bool = False, + ): + torch.nn.Module.__init__(self) + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 3, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size) + + # 4. sampling method + self.sampling = sampling + self.mix_ratio = mix_ratio + + # 5. [Optional] set dpo + self.dpo = dpo + + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + if self.dpo: + reject_speech_token = batch['reject_speech_token'].to(device) + reject_speech_token_len = batch['reject_speech_token_len'].to(device) + # 1. 
prepare llm_target + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + if self.dpo: + reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + target_ids.extend(reject_target_ids) + target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 2. speech token projection + speech_emb = self.speech_embedding(speech_token) + if self.dpo: + reject_speech_emb = self.speech_embedding(reject_speech_token) + + # 3. text token projection + text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True) + text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst] + + # 4. prepare llm_input + speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True) + input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0) + for i in range(len(text_emb))] + if self.dpo: + reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True) + reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0) + for i in range(len(text_emb))] + input_emb.extend(reject_input_emb) + input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device) + input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device) + + attention_mask = ~make_pad_mask(input_emb_lengths) + + result = self.llm.model( + inputs_embeds=input_emb, + attention_mask=attention_mask, + return_dict=True + ) + hidden_states = result.hidden_states + logits = self.llm_decoder(hidden_states[-1]) + loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]]) + acc = th_accuracy( + logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3), + target_ids[: speech_token.shape[0]], + ignore_label=IGNORE_ID, + ) + if not self.dpo: + return { + "loss": loss, + "acc": acc, + } + else: + all_logps_sum, all_logps_mean = self.get_batch_logps( + logits, target_ids, attention_mask, text_token_len, average_log_prob=False, ignore_id=IGNORE_ID + ) + chosen_logps = all_logps_sum[: speech_token.shape[0]] + rejected_logps = all_logps_sum[speech_token.shape[0]:] + return { + "loss": loss, + "acc": acc, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps + } + + + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_token_lens, + average_log_prob: bool = False, + ignore_id: int = -1, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. 
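+ attention_mask: Mask of valid (non-padded) positions in the concatenated prompt/speech sequence. Shape: (batch_size, sequence_length)
+ prompt_token_lens: Lengths of the text prompts; the prompt region (sos/eos embedding plus text tokens) is additionally masked so that only speech-token positions contribute to the log probability.
+ ignore_id: Label value marking padded positions; such labels are replaced by a dummy index of 0 before gathering log probabilities.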
+ + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + assert average_log_prob == False + assert logits.shape[:-1] == labels.shape + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_masks = attention_mask.clone().bool() + # mask prompts + for mask, text_token_len in zip(loss_masks, prompt_token_lens): + mask[:text_token_len + 1] = False + loss_masks = loss_masks[:, 1:] + labels[loss_masks == False] = 0 + # dummy token; we'll ignore the losses on these tokens later + ignore = labels == ignore_id + labels = labels.masked_fill(ignore, 0) # avoid -1 index + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) # (bs, time,) + logprobs_sums = (per_token_logps * loss_masks).sum(-1) + logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return logprobs_sums, logprobs_means + + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.llm.model.model.embed_tokens(text) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + out_tokens = [] + cache = None + for i in range(max_len): + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + if top_ids > self.speech_token_size: + continue + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + @torch.inference_mode() + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + + device = prompt_text.device + # 1. 
prepare input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb], dim=1) + + # 2. iterate text + out_tokens = [] + cache = None + # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5 + text_cache = self.llm.model.model.embed_tokens(prompt_text) + next_fill_index = -1 + for this_text in text: + text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) + # prompt_speech_token_emb not empty, try append to lm_input + while prompt_speech_token_emb.size(1) != 0: + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]] + logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) + lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1) + text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:] + else: + logging.info('not enough text token to decode, wait for more') + break + # no prompt_speech_token_emb remain, can decode some speech token + if prompt_speech_token_emb.size(1) == 0: + if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1): + logging.info('get fill token, need to append more text token') + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text = text_cache[:, :self.mix_ratio[0]] + logging.info('append {} text token'.format(lm_input_text.size(1))) + if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2: + lm_input = lm_input_text + else: + lm_input = torch.concat([lm_input, lm_input_text], dim=1) + text_cache = text_cache[:, self.mix_ratio[0]:] + else: + logging.info('not enough text token to decode, wait for more') + continue + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if next_fill_index != -1 and len(out_tokens) == next_fill_index: + top_ids = self.speech_token_size + 2 + next_fill_index += (self.mix_ratio[1] + 1) + else: + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item() + if top_ids == self.speech_token_size + 2: + next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1 + logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index)) + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size + 2: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + # 3. 
final decode + lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1) + logging.info('no more text token, decode until met eos') + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) diff --git a/cosyvoice/utils/executor_dpo.py b/cosyvoice/utils/executor_dpo.py new file mode 100644 index 0000000..89bb528 --- /dev/null +++ b/cosyvoice/utils/executor_dpo.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist + +from cosyvoice.utils.train_utils_dpo import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join +from cosyvoice.utils.losses_dpo import DPOLoss + + +class Executor: + + def __init__(self, gan: bool = False, dpo: bool = False, beta: float = 0.01, label_smoothing: float = 0.0, ipo: bool = False): + self.gan = gan + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + self.dpo = dpo + if self.dpo: + self.dpo_loss = DPOLoss(beta, label_smoothing, ipo) + else: + self.dpo_loss = None + + def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. 
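+ # In DPO mode, ref_model is a frozen copy of the policy model; it stays in eval mode and is
+ # only used in no-grad forward passes inside batch_forward to provide the reference
+ # log-probabilities, and the resulting preference loss is added to the regular CE loss.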
+ model.train() + if self.dpo: + assert ref_model is not None + ref_model.eval() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + with context(): + info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model, self.dpo_loss) + info_dict = batch_backward(model, scaler, info_dict) + + info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, ref_model=ref_model, dpo_loss=self.dpo_loss) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=ref_model, dpo_loss=self.dpo_loss) + + def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, + writer, info_dict, scaler, group_join): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. 
+ else: + context = nullcontext + + with context(): + batch_dict['turn'] = 'discriminator' + info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_backward(model, scaler, info_dict) + info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict) + optimizer.zero_grad() + log_per_step(writer, info_dict) + with context(): + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_backward(model, scaler, info_dict) + info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + optimizer_d.zero_grad() + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=None, dpo_loss=None): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + if self.dpo: + assert ref_model is not None + ref_model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + num_utts = len(batch_dict["utts"]) + total_num_utts += num_utts + + if self.gan is True: + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, None, info_dict, ref_model, dpo_loss) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/cosyvoice/utils/losses_dpo.py b/cosyvoice/utils/losses_dpo.py new file mode 100644 index 0000000..2429fdc --- /dev/null +++ b/cosyvoice/utils/losses_dpo.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from typing import Tuple + + +def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +def mel_loss(real_speech, generated_speech, mel_transforms): + loss = 0 + for transform in mel_transforms: + mel_r = transform(real_speech) + mel_g = transform(generated_speech) + loss += F.l1_loss(mel_g, mel_r) + return loss + + +class DPOLoss(torch.nn.Module): + """ + DPO Loss + """ + + def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: + super().__init__() + self.beta = beta + self.label_smoothing = label_smoothing + self.ipo = ipo + + 
def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + if self.ipo: + losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + else: + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + loss = losses.mean() + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return loss, chosen_rewards, rejected_rewards diff --git a/cosyvoice/utils/train_utils_dpo.py b/cosyvoice/utils/train_utils_dpo.py new file mode 100644 index 0000000..fa1529e --- /dev/null +++ b/cosyvoice/utils/train_utils_dpo.py @@ -0,0 +1,364 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Horizon Inc. (authors: Xingchen Song) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
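+# DPO support in this module is concentrated in batch_forward(): when a frozen
+# reference model and a DPOLoss instance are provided, the chosen/rejected
+# log-probabilities of the policy and reference models are combined into
+#   preference_loss = -logsigmoid(beta * ((policy_chosen_logps - policy_rejected_logps)
+#                                         - (reference_chosen_logps - reference_rejected_logps)))
+# (for the default label_smoothing=0; cDPO smoothing and an IPO variant are also
+# implemented in DPOLoss), and the optimized loss becomes sft_loss + preference_loss.
+# Chosen/rejected rewards and dpo_acc are added to loss_dict for logging.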
+ +import logging +import os +import torch +import json +import re +import datetime +import yaml + +import deepspeed +import torch.optim as optim +import torch.distributed as dist + +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ + +from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live + +from cosyvoice.dataset.dataset import Dataset +from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR + + +def init_distributed(args): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + + ', rank {}, world_size {}'.format(rank, world_size)) + if args.train_engine == 'torch_ddp': + torch.cuda.set_device(local_rank) + dist.init_process_group(args.dist_backend) + else: + deepspeed.init_distributed(dist_backend=args.dist_backend) + return world_size, local_rank, rank + + +def init_dataset_and_dataloader(args, configs, gan): + data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline'] + train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True) + cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False) + + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + return train_dataset, cv_dataset, train_data_loader, cv_data_loader + + +def check_modify_and_save_config(args, configs): + if args.train_engine == "torch_ddp": + configs['train_conf']["dtype"] = 'fp32' + else: + with open(args.deepspeed_config, 'r') as fin: + ds_configs = json.load(fin) + if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]: + configs['train_conf']["dtype"] = "fp16" + elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]: + configs['train_conf']["dtype"] = "bf16" + else: + configs['train_conf']["dtype"] = "fp32" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + # if use deepspeed, override ddp config + configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * + configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"]) + configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"] + configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"] + configs['train_conf']['log_interval'] = ds_configs["steps_per_print"] + return configs + + +def wrap_cuda_model(args, model): + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + if args.train_engine == "torch_ddp": # native pytorch ddp + assert (torch.cuda.is_available()) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + else: + if int(os.environ.get('RANK', 0)) == 0: + logging.info("Estimating model states memory needs (zero2)...") + estimate_zero2_model_states_mem_needs_all_live( + model, + num_gpus_per_node=local_world_size, + 
num_nodes=world_size // local_world_size) + return model + + +def init_optimizer_and_scheduler(args, configs, model, gan): + if gan is False: + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # use deepspeed optimizer for speedup + if args.train_engine == "deepspeed": + def scheduler(opt): + return scheduler_type(opt, **configs['train_conf']['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, + model=model, + optimizer=None, + lr_scheduler=scheduler, + model_parameters=model.parameters()) + + optimizer_d, scheduler_d = None, None + + else: + # currently we wrap generator and discriminator in one model, so we cannot use deepspeed + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + if configs['train_conf']['optim_d'] == 'adam': + optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim_d'] == 'adamw': + optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler_d'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler_d = ConstantLR(optimizer_d) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + return model, optimizer, scheduler, optimizer_d, scheduler_d + + +def init_summarywriter(args): + writer = None + if int(os.environ.get('RANK', 0)) == 
0: + os.makedirs(args.model_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + return writer + + +def save_model(model, model_name, info_dict): + rank = int(os.environ.get('RANK', 0)) + model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name)) + + if info_dict["train_engine"] == "torch_ddp": + if rank == 0: + torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path) + else: + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag=model_name, + client_state=info_dict) + if rank == 0: + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path)) + + +def cosyvoice_join(group_join, info_dict): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + + if info_dict["batch_idx"] != 0: + # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr + try: + dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + return False + except RuntimeError as e: + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". + format(world_size, rank, local_rank)) + return True + else: + return False + + +def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None): + device = int(os.environ.get('LOCAL_RANK', 0)) + + dtype = info_dict["dtype"] + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + else: # fp32 + dtype = torch.float32 + + if info_dict['train_engine'] == 'torch_ddp': + autocast = torch.cuda.amp.autocast(enabled=scaler is not None) + else: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False) + + with autocast: + info_dict['loss_dict'] = model(batch, device) + if ref_model and dpo_loss: + chosen_logps = info_dict['loss_dict']["chosen_logps"] + rejected_logps = info_dict['loss_dict']["rejected_logps"] + sft_loss = info_dict['loss_dict']['loss'] + with torch.no_grad(): + ref_model = ref_model.to(device) + ref_loss_dict = ref_model(batch, device) + reference_chosen_logps = ref_loss_dict["chosen_logps"] + reference_rejected_logps = ref_loss_dict["rejected_logps"] + preference_loss, chosen_reward, reject_reward = dpo_loss( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + dpo_acc = (chosen_reward > reject_reward).float().mean() + info_dict['loss_dict']["loss"] = preference_loss + sft_loss + info_dict['loss_dict']["sft_loss"] = sft_loss + info_dict['loss_dict']["dpo_loss"] = preference_loss + info_dict['loss_dict']["dpo_acc"] = dpo_acc + info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean() + info_dict['loss_dict']["reject_reward"] = reject_reward.mean() + return info_dict + + +def batch_backward(model, scaler, info_dict): + if info_dict["train_engine"] == "deepspeed": + scaled_loss = model.backward(info_dict['loss_dict']['loss']) + else: + scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad'] + if scaler is not None: + scaler.scale(scaled_loss).backward() + else: + 
scaled_loss.backward() + + info_dict['loss_dict']['loss'] = scaled_loss + return info_dict + + +def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): + grad_norm = 0.0 + if info_dict['train_engine'] == "deepspeed": + info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary() + model.step() + grad_norm = model.get_global_grad_norm() + elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0: + # Use mixed precision training + if scaler is not None: + scaler.unscale_(optimizer) + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + # We don't check grad here since that if the gradient + # has inf/nan values, scaler.step will skip + # optimizer.step(). + if torch.isfinite(grad_norm): + scaler.step(optimizer) + scaler.update() + else: + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["grad_norm"] = grad_norm + return info_dict + + +def log_per_step(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict.get('epoch', 0) + step = info_dict["step"] + batch_idx = info_dict["batch_idx"] + loss_dict = info_dict['loss_dict'] + rank = int(os.environ.get('RANK', 0)) + + # only rank 0 write to tensorboard to avoid multi-process write + if writer is not None: + if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \ + (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0): + for k in ['epoch', 'lr', 'grad_norm']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) + + # TRAIN & CV, Shell log (stdout) + if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0: + log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1) + for name, value in loss_dict.items(): + log_str += '{} {:.6f} '.format(name, value) + if tag == "TRAIN": + log_str += 'lr {:.8f} grad_norm {:.6f}'.format( + info_dict["lr"], info_dict['grad_norm']) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + + +def log_per_save(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict["epoch"] + step = info_dict["step"] + loss_dict = info_dict["loss_dict"] + lr = info_dict['lr'] + rank = int(os.environ.get('RANK', 0)) + logging.info( + 'Epoch {} Step {} CV info lr {} {} rank {}'.format( + epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + + if writer is not None: + for k in ['epoch', 'lr']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) diff --git a/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml b/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml new file mode 100644 index 0000000..d811026 --- /dev/null +++ b/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml @@ -0,0 +1,226 @@ +# set random seed, so that you may reproduce your result. 
+__set_seed1: !apply:random.seed [1986] +__set_seed2: !apply:numpy.random.seed [1986] +__set_seed3: !apply:torch.manual_seed [1986] +__set_seed4: !apply:torch.cuda.manual_seed_all [1986] + +# fixed params +sample_rate: 24000 # 16000 for llm, 24000 for cfm +llm_input_size: 896 +llm_output_size: 896 +spk_embed_dim: 192 +qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN' + +# model params +# for all class/function included in this repo, we use !<name> or !<new> for initialization, so that users may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM +    llm_input_size: !ref <llm_input_size> +    llm_output_size: !ref <llm_output_size> +    speech_token_size: 6561 +    length_normalized_loss: True +    lsm_weight: 0 +    dpo: True +    llm: !new:cosyvoice.llm.llm.Qwen2Encoder +        pretrain_path: !ref <qwen_pretrain_path> +    sampling: !name:cosyvoice.utils.common.ras_sampling +        top_p: 0.8 +        top_k: 25 +        win_size: 10 +        tau_r: 0.1 +flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec +    input_size: 512 +    output_size: 80 +    spk_embed_dim: !ref <spk_embed_dim> +    output_type: 'mel' +    vocab_size: 6561 +    input_frame_rate: 25 +    only_mask_loss: True +    token_mel_ratio: 2 +    pre_lookahead_len: 3 +    encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder +        output_size: 512 +        attention_heads: 8 +        linear_units: 2048 +        num_blocks: 6 +        dropout_rate: 0.1 +        positional_dropout_rate: 0.1 +        attention_dropout_rate: 0.1 +        normalize_before: True +        input_layer: 'linear' +        pos_enc_layer_type: 'rel_pos_espnet' +        selfattention_layer_type: 'rel_selfattn' +        input_size: 512 +        use_cnn_module: False +        macaron_style: False +    decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM +        in_channels: 240 +        n_spks: 1 +        spk_emb_dim: 80 +        cfm_params: !new:omegaconf.DictConfig +            content: +                sigma_min: 1e-06 +                solver: 'euler' +                t_scheduler: 'cosine' +                training_cfg_rate: 0.2 +                inference_cfg_rate: 0.7 +                reg_loss_type: 'l1' +        estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder +            in_channels: 320 +            out_channels: 80 +            causal: True +            channels: [256] +            dropout: 0.0 +            attention_head_dim: 64 +            n_blocks: 4 +            num_mid_blocks: 12 +            num_heads: 8 +            act_fn: 'gelu' + +hift: !new:cosyvoice.hifigan.generator.HiFTGenerator +    in_channels: 80 +    base_channels: 512 +    nb_harmonics: 8 +    sampling_rate: !ref <sample_rate> +    nsf_alpha: 0.1 +    nsf_sigma: 0.003 +    nsf_voiced_threshold: 10 +    upsample_rates: [8, 5, 3] +    upsample_kernel_sizes: [16, 11, 7] +    istft_params: +        n_fft: 16 +        hop_len: 4 +    resblock_kernel_sizes: [3, 7, 11] +    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] +    source_resblock_kernel_sizes: [7, 7, 11] +    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] +    lrelu_slope: 0.1 +    audio_limit: 0.99 +    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor +        num_class: 1 +        in_channels: 80 +        cond_channels: 512 + +# gan related module +mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram +    n_fft: 1024 +    num_mels: 80 +    sampling_rate: !ref <sample_rate> +    hop_size: 256 +    win_size: 1024 +    fmin: 0 +    fmax: null +    center: False +hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan +    generator: !ref <hift> +    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator +        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator +        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator +    mel_spec_transform: [ +        !ref <mel_spec_transform1> +    ] + +# processor functions +parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener +get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to 
!name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe +    multilingual: True +    num_languages: 100 +    language: 'en' +    task: 'transcribe' +allowed_special: 'all' +tokenize: !name:cosyvoice.dataset.processor.tokenize +    get_tokenizer: !ref <get_tokenizer> +    allowed_special: !ref <allowed_special> +filter: !name:cosyvoice.dataset.processor.filter +    max_length: 40960 +    min_length: 0 +    token_max_length: 200 +    token_min_length: 1 +resample: !name:cosyvoice.dataset.processor.resample +    resample_rate: !ref <sample_rate> +truncate: !name:cosyvoice.dataset.processor.truncate +    truncate_length: 24576 # must be a multiple of hop_size +feat_extractor: !name:matcha.utils.audio.mel_spectrogram +    n_fft: 1024 +    num_mels: 80 +    sampling_rate: !ref <sample_rate> +    hop_size: 256 +    win_size: 1024 +    fmin: 0 +    fmax: 8000 +    center: False +compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank +    feat_extractor: !ref <feat_extractor> +compute_f0: !name:cosyvoice.dataset.processor.compute_f0 +    sample_rate: !ref <sample_rate> +    hop_size: 256 +parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding +    normalize: True +shuffle: !name:cosyvoice.dataset.processor.shuffle +    shuffle_size: 1000 +sort: !name:cosyvoice.dataset.processor.sort +    sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:cosyvoice.dataset.processor.batch +    batch_type: 'dynamic' +    max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g +padding: !name:cosyvoice.dataset.processor.padding +    use_spk_embedding: True # change to True during sft +    dpo: True + +# dataset processor pipeline +data_pipeline: [ +    !ref <parquet_opener>, +    !ref <tokenize>, +    !ref <filter>, +    !ref <resample>, +    !ref <compute_fbank>, +    !ref <parse_embedding>, +    !ref <shuffle>, +    !ref <sort>, +    !ref <batch>, +    !ref <padding>, +] +data_pipeline_gan: [ +    !ref <parquet_opener>, +    !ref <tokenize>, +    !ref <filter>, +    !ref <resample>, +    !ref <truncate>, +    !ref <compute_fbank>, +    !ref <compute_f0>, +    !ref <parse_embedding>, +    !ref <shuffle>, +    !ref <sort>, +    !ref <batch>, +    !ref <padding>, +] + +# llm flow train conf +train_conf: +    optim: adam +    optim_conf: +        lr: 0.00001 # change to 1e-5 during sft +    scheduler: warmuplr # change to constantlr during sft +    scheduler_conf: +        warmup_steps: 25000 +    max_epoch: 200 +    grad_clip: 5 +    accum_grad: 2 +    log_interval: 100 +    save_per_step: -1 + +# gan train conf +train_conf_gan: +    optim: adam +    optim_conf: +        lr: 0.0002 # use small lr for gan training +    scheduler: constantlr +    optim_d: adam +    optim_conf_d: +        lr: 0.0002 # use small lr for gan training +    scheduler_d: constantlr +    max_epoch: 200 +    grad_clip: 5 +    accum_grad: 1 # in gan training, accum_grad must be 1 +    log_interval: 100 +    save_per_step: -1 \ No newline at end of file diff --git a/tools/make_parquet_list_dpo.py b/tools/make_parquet_list_dpo.py new file mode 100755 index 0000000..c6ee6f5 --- /dev/null +++ b/tools/make_parquet_list_dpo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
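The make_parquet_list_dpo.py script below extends the standard parquet preparation with a --dpo flag: it additionally loads utt2reject_speech_token.pt from --src_dir and writes a reject_speech_token column next to speech_token in every parquet shard. How the rejected (lower-preference) token sequences are generated is not covered by this patch; the sketch below only illustrates the on-disk format the script assumes, and the directory path and rejected_tokens_for helper are hypothetical placeholders.

    import torch

    # utt2speech_token.pt (produced by the standard CosyVoice data prep) maps
    # utterance id -> accepted speech-token list; the DPO data prep additionally
    # needs a parallel utt2reject_speech_token.pt keyed by the same utterance ids.
    src_dir = 'data/train'  # hypothetical source directory
    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(src_dir))

    def rejected_tokens_for(utt):
        # placeholder: in practice the rejected sequence would come from a
        # lower-preference synthesis of the same utterance, not be copied
        # from the accepted tokens
        return list(utt2speech_token[utt])

    utt2reject_speech_token = {utt: rejected_tokens_for(utt) for utt in utt2speech_token}
    torch.save(utt2reject_speech_token, '{}/utt2reject_speech_token.pt'.format(src_dir))
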
+import argparse +import logging +import os +import json +from tqdm import tqdm +import pandas as pd +import multiprocessing +import time +import torch + + +def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file): +    start_time = time.time() +    data_list = [] +    for utt in tqdm(utt_list): +        data = open(utt2wav[utt], 'rb').read() +        data_list.append(data) +    wav_list = [utt2wav[utt] for utt in utt_list] +    text_list = [utt2text[utt] for utt in utt_list] +    spk_list = [utt2spk[utt] for utt in utt_list] +    uttembedding_list = [utt2embedding[utt] for utt in utt_list] +    spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list] +    speech_token_list = [utt2speech_token[utt] for utt in utt_list] +    if utt2reject_speech_token: +        reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list] + +    # save the batch to parquet_file and write the utt2parquet_file / spk2parquet_file mappings +    df = pd.DataFrame() +    df['utt'] = utt_list +    df['wav'] = wav_list +    df['audio_data'] = data_list +    df['text'] = text_list +    df['spk'] = spk_list +    df['utt_embedding'] = uttembedding_list +    df['spk_embedding'] = spkembedding_list +    df['speech_token'] = speech_token_list +    if utt2reject_speech_token: +        df['reject_speech_token'] = reject_speech_token_list +    df.to_parquet(parquet_file) +    with open(utt2parquet_file, 'w') as f: +        json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2) +    with open(spk2parquet_file, 'w') as f: +        json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2) +    logging.info('spend time {}'.format(time.time() - start_time)) + + +if __name__ == "__main__": +    parser = argparse.ArgumentParser() +    parser.add_argument('--num_utts_per_parquet', +                        type=int, +                        default=1000, +                        help='num utts per parquet') +    parser.add_argument('--num_processes', +                        type=int, +                        default=1, +                        help='num processes for make parquets') +    parser.add_argument('--src_dir', +                        type=str) +    parser.add_argument('--des_dir', +                        type=str) +    parser.add_argument('--dpo', +                        action='store_true', +                        default=False, +                        help='Use Direct Preference Optimization') +    args = parser.parse_args() + +    utt2wav, utt2text, utt2spk = {}, {}, {} +    with open('{}/wav.scp'.format(args.src_dir)) as f: +        for l in f: +            l = l.replace('\n', '').split() +            utt2wav[l[0]] = l[1] +    with open('{}/text'.format(args.src_dir)) as f: +        for l in f: +            l = l.replace('\n', '').split() +            utt2text[l[0]] = ' '.join(l[1:]) +    with open('{}/utt2spk'.format(args.src_dir)) as f: +        for l in f: +            l = l.replace('\n', '').split() +            utt2spk[l[0]] = l[1] +    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) +    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) +    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) +    if args.dpo: +        utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir)) +    else: +        utt2reject_speech_token = None +    utts = list(utt2wav.keys()) + +    # Using process pool to speedup +    pool = multiprocessing.Pool(processes=args.num_processes) +    parquet_list, utt2parquet_list, spk2parquet_list = [], [], [] +    for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)): +        parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i)) +        utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i)) +        spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) +        parquet_list.append(parquet_file) +        utt2parquet_list.append(utt2parquet_file) +        spk2parquet_list.append(spk2parquet_file) + 
pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file)) + pool.close() + pool.join() + + with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ + open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ + open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: + for name in parquet_list: + f1.write(name + '\n') + for name in utt2parquet_list: + f2.write(name + '\n') + for name in spk2parquet_list: + f3.write(name + '\n')
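The dpo_loss callable consumed by batch_forward in cosyvoice/utils/train_utils_dpo.py is not defined in this patch. From the call site it is expected to take the policy and reference log-probabilities of the chosen and rejected sequences and return a preference loss plus per-sample chosen and rejected rewards. A minimal sketch with that signature, assuming the standard DPO objective and the beta value passed via --beta; the actual loss module used by the recipe may differ:

    import torch
    import torch.nn.functional as F

    def dpo_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.01):
        # implicit rewards: beta-scaled log-ratio between the policy and the frozen reference model
        chosen_reward = beta * (chosen_logps - ref_chosen_logps)
        reject_reward = beta * (rejected_logps - ref_rejected_logps)
        # DPO objective: push the chosen reward above the rejected reward
        loss = -F.logsigmoid(chosen_reward - reject_reward).mean()
        return loss, chosen_reward.detach(), reject_reward.detach()

batch_forward adds the returned preference loss to the SFT loss and logs dpo_acc as the fraction of pairs whose chosen reward exceeds the rejected reward, which is exactly what the rewards returned by such a callable support.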