From a22873e3609c048ae99c77843f2ac5d71b58774f Mon Sep 17 00:00:00 2001 From: ShengqiangLi Date: Mon, 17 Feb 2025 12:04:48 +0800 Subject: [PATCH 1/6] feat: Support DPO --- cosyvoice/bin/train_dpo.py | 187 ++++++ cosyvoice/dataset/processor_dpo.py | 443 ++++++++++++++ cosyvoice/llm/llm_dpo.py | 556 ++++++++++++++++++ cosyvoice/utils/executor_dpo.py | 184 ++++++ cosyvoice/utils/losses_dpo.py | 57 ++ cosyvoice/utils/train_utils_dpo.py | 364 ++++++++++++ .../cosyvoice/conf/cosyvoice_dpo.yaml | 226 +++++++ tools/make_parquet_list_dpo.py | 125 ++++ 8 files changed, 2142 insertions(+) create mode 100644 cosyvoice/bin/train_dpo.py create mode 100644 cosyvoice/dataset/processor_dpo.py create mode 100644 cosyvoice/llm/llm_dpo.py create mode 100644 cosyvoice/utils/executor_dpo.py create mode 100644 cosyvoice/utils/losses_dpo.py create mode 100644 cosyvoice/utils/train_utils_dpo.py create mode 100644 examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml create mode 100755 tools/make_parquet_list_dpo.py diff --git a/cosyvoice/bin/train_dpo.py b/cosyvoice/bin/train_dpo.py new file mode 100644 index 0000000..b5b282f --- /dev/null +++ b/cosyvoice/bin/train_dpo.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import os +import torch +import torch.distributed as dist +import deepspeed + +from hyperpyyaml import load_hyperpyyaml + +from torch.distributed.elastic.multiprocessing.errors import record + +from cosyvoice.utils.executor_dpo import Executor +from cosyvoice.utils.train_utils_dpo import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--use_amp', + action='store_true', + default=False, + help='Use automatic mixed precision training') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=60, + type=int, + help='timeout (in seconds) of cosyvoice_join.') + parser.add_argument('--dpo', + action='store_true', + default=False, + help='Use Direct Preference Optimization') + parser.add_argument('--beta', + default=0.01, + type=float, + help='beta of dpo training') + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + # gan train has some special initialization logic + gan = True if args.model == 'hifigan' else False + + override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} + if gan is True: + override_dict.pop('hift') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) + if gan is True: + configs['train_conf'] = configs['train_conf_gan'] + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs, gan) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + ref_model = None + if args.dpo: + ref_model = deepcopy(model) + start_step, start_epoch = 0, -1 + if args.checkpoint is not None: + if os.path.exists(args.checkpoint): + state_dict = torch.load(args.checkpoint, map_location='cpu') + model.load_state_dict(state_dict, strict=False) + if args.dpo: + ref_model.load_state_dict(state_dict, strict=False) + if 'step' in state_dict: + start_step = state_dict['step'] + if 'epoch' in state_dict: + start_epoch = state_dict['epoch'] + else: + logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint)) + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + if args.dpo: + ref_model = wrap_cuda_model(args, ref_model) + + # Get optimizer & scheduler + model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan) + if args.dpo: + ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan) + scheduler.set_step(start_step) + if scheduler_d is not None: + scheduler_d.set_step(start_step) + + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + info_dict['step'] = start_step + info_dict['epoch'] = start_epoch + save_model(model, 'init', info_dict) + + # Get executor + executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta) + executor.step = start_step + + # Init scaler, used for pytorch amp mixed precision training + scaler = torch.cuda.amp.GradScaler() if args.use_amp else None + print('start step {} start epoch {}'.format(start_step, start_epoch)) + # Start training loop + for epoch in range(start_epoch + 1, info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + if gan is True: + executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, + writer, info_dict, scaler, group_join) + else: + executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model) + dist.destroy_process_group(group_join) + + +if __name__ == '__main__': + main() diff --git a/cosyvoice/dataset/processor_dpo.py b/cosyvoice/dataset/processor_dpo.py new file mode 100644 index 0000000..719b474 --- /dev/null +++ b/cosyvoice/dataset/processor_dpo.py @@ -0,0 +1,443 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random + +import pyarrow.parquet as pq +from io import BytesIO +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import pyworld as pw + + +AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} + + +def parquet_opener(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + url = sample['src'] + try: + for df in pq.ParquetFile(url).iter_batches(batch_size=64): + df = df.to_pandas() + for i in range(len(df)): + if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: + continue + sample.update(dict(df.loc[i])) + if mode == 'train': + # NOTE do not return sample directly, must initialize a new dict + yield {**sample} + else: + for index, text in enumerate(tts_data[df.loc[i, 'utt']]): + yield {**sample, 'tts_index': index, 'tts_text': text} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(url, ex)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + del sample['audio_data'] + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['text_token']) < token_min_length: + continue + if len(sample['text_token']) > token_max_length: + continue + if len(sample['speech_token']) == 0: + continue + if num_frames != 0: + if len(sample['text_token']) / num_frames < min_output_input_ratio: + continue + if len(sample['text_token']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + sample_rate = sample['sample_rate'] + waveform = sample['speech'] + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['speech'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + max_val = sample['speech'].abs().max() + if max_val > 1: + sample['speech'] /= max_val + yield sample + + +def truncate(data, truncate_length=24576, mode='train'): + """ Truncate data. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + truncate_length: truncate length + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + waveform = sample['speech'] + if waveform.shape[1] > truncate_length: + start = random.randint(0, waveform.shape[1] - truncate_length) + waveform = waveform[:, start: start + truncate_length] + else: + waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1) + sample['speech'] = waveform + yield sample + + +def compute_fbank(data, + feat_extractor, + mode='train'): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + sample['speech_feat'] = mat + yield sample + + +def compute_f0(data, sample_rate, hop_size, mode='train'): + """ Extract f0 + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + frame_period = hop_size * 1000 / sample_rate + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) + if sum(_f0 != 0) < 5: # this happens when the algorithm fails + _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio + f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate) + f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1) + sample['pitch_feat'] = f0 + yield sample + + +def parse_embedding(data, normalize, mode='train'): + """ Parse utt_embedding/spk_embedding + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) + sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) + if normalize: + sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) + sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) + yield sample + + +def tokenize(data, get_tokenizer, allowed_special, mode='train'): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + tokenizer = get_tokenizer() + for sample in data: + assert 'text' in sample + sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) + if mode == 'inference': + sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special) + yield sample + + +def shuffle(data, shuffle_size=10000, mode='train'): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500, mode='train'): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'speech_feat' in sample + assert isinstance(sample['speech_feat'], torch.Tensor) + new_sample_frames = sample['speech_feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'): + """ Wrapper for static/dynamic batch + """ + if mode == 'inference': + return static_batch(data, 1) + else: + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], + dtype=torch.int32) + order = torch.argsort(speech_feat_len, descending=True) + + utts = [sample[i]['utt'] for i in order] + speech = [sample[i]['speech'].squeeze(dim=0) for i in order] + speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32) + speech = pad_sequence(speech, batch_first=True, padding_value=0) + speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] + speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) + speech_token = pad_sequence(speech_token, + batch_first=True, + padding_value=0) + speech_feat = [sample[i]['speech_feat'] for i in order] + speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) + speech_feat = pad_sequence(speech_feat, + batch_first=True, + padding_value=0) + text = [sample[i]['text'] for i in order] + text_token = [torch.tensor(sample[i]['text_token']) for i in order] + text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, padding_value=0) + utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) + spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) + batch = { + "utts": utts, + "speech": speech, + "speech_len": speech_len, + "speech_token": speech_token, + "speech_token_len": speech_token_len, + "speech_feat": speech_feat, + "speech_feat_len": speech_feat_len, + "text": text, + "text_token": text_token, + "text_token_len": text_token_len, + "utt_embedding": utt_embedding, + "spk_embedding": spk_embedding, + } + if dpo: + reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order] + reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32) + reject_speech_token = pad_sequence(reject_speech_token, + batch_first=True, + padding_value=0) + batch['reject_speech_token'] = reject_speech_token + batch['reject_speech_token_len'] = reject_speech_token_len + if gan is True: + # in gan train, we need pitch_feat + pitch_feat = [sample[i]['pitch_feat'] for i in order] + pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32) + pitch_feat = pad_sequence(pitch_feat, + batch_first=True, + padding_value=0) + batch["pitch_feat"] = pitch_feat + batch["pitch_feat_len"] = pitch_feat_len + else: + # only gan train needs speech, delete it to save memory + del batch["speech"] + del batch["speech_len"] + if mode == 'inference': + tts_text = [sample[i]['tts_text'] for i in order] + tts_index = [sample[i]['tts_index'] for i in order] + tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] + tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) + tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) + batch.update({'tts_text': tts_text, + 'tts_index': tts_index, + 'tts_text_token': tts_text_token, + 'tts_text_token_len': tts_text_token_len}) + if use_spk_embedding is True: + batch["embedding"] = batch["spk_embedding"] + else: + batch["embedding"] = batch["utt_embedding"] + yield batch diff --git a/cosyvoice/llm/llm_dpo.py b/cosyvoice/llm/llm_dpo.py new file mode 100644 index 0000000..6e0dc2d --- /dev/null +++ b/cosyvoice/llm/llm_dpo.py @@ -0,0 +1,556 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Callable, List, Generator +import torch +from torch import nn +import torch.nn.functional as F +from transformers import Qwen2ForCausalLM +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from cosyvoice.utils.common import IGNORE_ID +from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss +from cosyvoice.utils.common import th_accuracy +from cosyvoice.utils.file_utils import logging +from cosyvoice.utils.mask import make_pad_mask + + +class TransformerLM(torch.nn.Module): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + text_token_size: int, + speech_token_size: int, + text_encoder: torch.nn.Module, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + spk_embed_dim: int = 192, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.speech_token_size = speech_token_size + # 1. build text token inputs related modules + self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + self.text_encoder = text_encoder + self.text_encoder_affine_layer = nn.Linear( + self.text_encoder.output_size(), + llm_input_size + ) + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size) + + # 4. sampling method + self.sampling = sampling + + def encode( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + ): + encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + return encoder_out, encoder_out_lens + + def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len): + text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) + speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) + lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) + for i in range(len(text_token))] + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + return lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + embedding = batch['embedding'].to(device) + + # 1. prepare llm_target + lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 1. encode text_token + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + + # 2. embedding projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(1) + + # 3. eos and task_id + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 4. encode speech_token + speech_token = self.speech_embedding(speech_token) + + # 5. unpad and pad + lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, + task_id_emb, speech_token, speech_token_len) + + # 6. run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID) + return {'loss': loss, 'acc': acc} + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + sampling: int, + ignore_eos: bool = True, + ): + num_trials, max_trials = 0, 100 + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + if (not ignore_eos) or (self.speech_token_size not in top_ids): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + if self.fp16 is True: + embedding = embedding.half() + + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.text_embedding(text) + + # 1. encode text + text, text_len = self.encode(text, text_len) + + # 2. encode embedding + if embedding.shape[0] != 0: + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(dim=1) + else: + embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + out_tokens = [] + offset = 0 + att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device) + for i in range(max_len): + y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, + att_cache=att_cache, cnn_cache=cnn_cache, + att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), + device=lm_input.device)).to(torch.bool)) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + # force continue decode first token + if i == 0: + logp[:, self.speech_token_size] = -float('inf') + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + +class Qwen2Encoder(torch.nn.Module): + def __init__(self, pretrain_path): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + + def forward_one_step(self, xs, masks, cache=None): + input_masks = masks[:, -1, :] + outs = self.model( + inputs_embeds=xs, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + use_cache=True, + past_key_values=cache, + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache + + +class Qwen2LM(TransformerLM): + def __init__( + self, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + mix_ratio: List[int] = [5, 15], + dpo: bool = False, + ): + torch.nn.Module.__init__(self) + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 3, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size) + + # 4. sampling method + self.sampling = sampling + self.mix_ratio = mix_ratio + + # 5. [Optional] set dpo + self.dpo = dpo + + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + if self.dpo: + reject_speech_token = batch['reject_speech_token'].to(device) + reject_speech_token_len = batch['reject_speech_token_len'].to(device) + # 1. prepare llm_target + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + if self.dpo: + reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + target_ids.extend(reject_target_ids) + target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 2. speech token projection + speech_emb = self.speech_embedding(speech_token) + if self.dpo: + reject_speech_emb = self.speech_embedding(reject_speech_token) + + # 3. text token projection + text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True) + text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst] + + # 4. prepare llm_input + speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True) + input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0) + for i in range(len(text_emb))] + if self.dpo: + reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True) + reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0) + for i in range(len(text_emb))] + input_emb.extend(reject_input_emb) + input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device) + input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device) + + attention_mask = ~make_pad_mask(input_emb_lengths) + + result = self.llm.model( + inputs_embeds=input_emb, + attention_mask=attention_mask, + return_dict=True + ) + hidden_states = result.hidden_states + logits = self.llm_decoder(hidden_states[-1]) + loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]]) + acc = th_accuracy( + logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3), + target_ids[: speech_token.shape[0]], + ignore_label=IGNORE_ID, + ) + if not self.dpo: + return { + "loss": loss, + "acc": acc, + } + else: + all_logps_sum, all_logps_mean = self.get_batch_logps( + logits, target_ids, attention_mask, text_token_len, average_log_prob=False, ignore_id=IGNORE_ID + ) + chosen_logps = all_logps_sum[: speech_token.shape[0]] + rejected_logps = all_logps_sum[speech_token.shape[0]:] + return { + "loss": loss, + "acc": acc, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps + } + + + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_token_lens, + average_log_prob: bool = False, + ignore_id: int = -1, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + assert average_log_prob == False + assert logits.shape[:-1] == labels.shape + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_masks = attention_mask.clone().bool() + # mask prompts + for mask, text_token_len in zip(loss_masks, prompt_token_lens): + mask[:text_token_len + 1] = False + loss_masks = loss_masks[:, 1:] + labels[loss_masks == False] = 0 + # dummy token; we'll ignore the losses on these tokens later + ignore = labels == ignore_id + labels = labels.masked_fill(ignore, 0) # avoid -1 index + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) # (bs, time,) + logprobs_sums = (per_token_logps * loss_masks).sum(-1) + logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return logprobs_sums, logprobs_means + + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.llm.model.model.embed_tokens(text) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + out_tokens = [] + cache = None + for i in range(max_len): + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + if top_ids > self.speech_token_size: + continue + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + @torch.inference_mode() + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + + device = prompt_text.device + # 1. prepare input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb], dim=1) + + # 2. iterate text + out_tokens = [] + cache = None + # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5 + text_cache = self.llm.model.model.embed_tokens(prompt_text) + next_fill_index = -1 + for this_text in text: + text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) + # prompt_speech_token_emb not empty, try append to lm_input + while prompt_speech_token_emb.size(1) != 0: + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]] + logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) + lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1) + text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:] + else: + logging.info('not enough text token to decode, wait for more') + break + # no prompt_speech_token_emb remain, can decode some speech token + if prompt_speech_token_emb.size(1) == 0: + if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1): + logging.info('get fill token, need to append more text token') + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text = text_cache[:, :self.mix_ratio[0]] + logging.info('append {} text token'.format(lm_input_text.size(1))) + if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2: + lm_input = lm_input_text + else: + lm_input = torch.concat([lm_input, lm_input_text], dim=1) + text_cache = text_cache[:, self.mix_ratio[0]:] + else: + logging.info('not enough text token to decode, wait for more') + continue + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if next_fill_index != -1 and len(out_tokens) == next_fill_index: + top_ids = self.speech_token_size + 2 + next_fill_index += (self.mix_ratio[1] + 1) + else: + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item() + if top_ids == self.speech_token_size + 2: + next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1 + logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index)) + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size + 2: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + # 3. final decode + lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1) + logging.info('no more text token, decode until met eos') + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) diff --git a/cosyvoice/utils/executor_dpo.py b/cosyvoice/utils/executor_dpo.py new file mode 100644 index 0000000..89bb528 --- /dev/null +++ b/cosyvoice/utils/executor_dpo.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist + +from cosyvoice.utils.train_utils_dpo import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join +from cosyvoice.utils.losses_dpo import DPOLoss + + +class Executor: + + def __init__(self, gan: bool = False, dpo: bool = False, beta: float = 0.01, label_smoothing: float = 0.0, ipo: bool = False): + self.gan = gan + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + self.dpo = dpo + if self.dpo: + self.dpo_loss = DPOLoss(beta, label_smoothing, ipo) + else: + self.dpo_loss = None + + def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + if self.dpo: + assert ref_model is not None + ref_model.eval() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + with context(): + info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model, self.dpo_loss) + info_dict = batch_backward(model, scaler, info_dict) + + info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, ref_model=ref_model, dpo_loss=self.dpo_loss) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=ref_model, dpo_loss=self.dpo_loss) + + def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, + writer, info_dict, scaler, group_join): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + with context(): + batch_dict['turn'] = 'discriminator' + info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_backward(model, scaler, info_dict) + info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict) + optimizer.zero_grad() + log_per_step(writer, info_dict) + with context(): + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_backward(model, scaler, info_dict) + info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + optimizer_d.zero_grad() + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=None, dpo_loss=None): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + if self.dpo: + assert ref_model is not None + ref_model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + num_utts = len(batch_dict["utts"]) + total_num_utts += num_utts + + if self.gan is True: + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, None, info_dict, ref_model, dpo_loss) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/cosyvoice/utils/losses_dpo.py b/cosyvoice/utils/losses_dpo.py new file mode 100644 index 0000000..2429fdc --- /dev/null +++ b/cosyvoice/utils/losses_dpo.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from typing import Tuple + + +def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +def mel_loss(real_speech, generated_speech, mel_transforms): + loss = 0 + for transform in mel_transforms: + mel_r = transform(real_speech) + mel_g = transform(generated_speech) + loss += F.l1_loss(mel_g, mel_r) + return loss + + +class DPOLoss(torch.nn.Module): + """ + DPO Loss + """ + + def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: + super().__init__() + self.beta = beta + self.label_smoothing = label_smoothing + self.ipo = ipo + + def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + if self.ipo: + losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + else: + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + loss = losses.mean() + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return loss, chosen_rewards, rejected_rewards diff --git a/cosyvoice/utils/train_utils_dpo.py b/cosyvoice/utils/train_utils_dpo.py new file mode 100644 index 0000000..fa1529e --- /dev/null +++ b/cosyvoice/utils/train_utils_dpo.py @@ -0,0 +1,364 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Horizon Inc. (authors: Xingchen Song) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import torch +import json +import re +import datetime +import yaml + +import deepspeed +import torch.optim as optim +import torch.distributed as dist + +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ + +from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live + +from cosyvoice.dataset.dataset import Dataset +from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR + + +def init_distributed(args): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + + ', rank {}, world_size {}'.format(rank, world_size)) + if args.train_engine == 'torch_ddp': + torch.cuda.set_device(local_rank) + dist.init_process_group(args.dist_backend) + else: + deepspeed.init_distributed(dist_backend=args.dist_backend) + return world_size, local_rank, rank + + +def init_dataset_and_dataloader(args, configs, gan): + data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline'] + train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True) + cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False) + + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + return train_dataset, cv_dataset, train_data_loader, cv_data_loader + + +def check_modify_and_save_config(args, configs): + if args.train_engine == "torch_ddp": + configs['train_conf']["dtype"] = 'fp32' + else: + with open(args.deepspeed_config, 'r') as fin: + ds_configs = json.load(fin) + if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]: + configs['train_conf']["dtype"] = "fp16" + elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]: + configs['train_conf']["dtype"] = "bf16" + else: + configs['train_conf']["dtype"] = "fp32" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + # if use deepspeed, override ddp config + configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * + configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"]) + configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"] + configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"] + configs['train_conf']['log_interval'] = ds_configs["steps_per_print"] + return configs + + +def wrap_cuda_model(args, model): + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + if args.train_engine == "torch_ddp": # native pytorch ddp + assert (torch.cuda.is_available()) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + else: + if int(os.environ.get('RANK', 0)) == 0: + logging.info("Estimating model states memory needs (zero2)...") + estimate_zero2_model_states_mem_needs_all_live( + model, + num_gpus_per_node=local_world_size, + num_nodes=world_size // local_world_size) + return model + + +def init_optimizer_and_scheduler(args, configs, model, gan): + if gan is False: + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # use deepspeed optimizer for speedup + if args.train_engine == "deepspeed": + def scheduler(opt): + return scheduler_type(opt, **configs['train_conf']['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, + model=model, + optimizer=None, + lr_scheduler=scheduler, + model_parameters=model.parameters()) + + optimizer_d, scheduler_d = None, None + + else: + # currently we wrap generator and discriminator in one model, so we cannot use deepspeed + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + if configs['train_conf']['optim_d'] == 'adam': + optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim_d'] == 'adamw': + optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler_d'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler_d = ConstantLR(optimizer_d) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + return model, optimizer, scheduler, optimizer_d, scheduler_d + + +def init_summarywriter(args): + writer = None + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.model_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + return writer + + +def save_model(model, model_name, info_dict): + rank = int(os.environ.get('RANK', 0)) + model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name)) + + if info_dict["train_engine"] == "torch_ddp": + if rank == 0: + torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path) + else: + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag=model_name, + client_state=info_dict) + if rank == 0: + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path)) + + +def cosyvoice_join(group_join, info_dict): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + + if info_dict["batch_idx"] != 0: + # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr + try: + dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + return False + except RuntimeError as e: + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". + format(world_size, rank, local_rank)) + return True + else: + return False + + +def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None): + device = int(os.environ.get('LOCAL_RANK', 0)) + + dtype = info_dict["dtype"] + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + else: # fp32 + dtype = torch.float32 + + if info_dict['train_engine'] == 'torch_ddp': + autocast = torch.cuda.amp.autocast(enabled=scaler is not None) + else: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False) + + with autocast: + info_dict['loss_dict'] = model(batch, device) + if ref_model and dpo_loss: + chosen_logps = info_dict['loss_dict']["chosen_logps"] + rejected_logps = info_dict['loss_dict']["rejected_logps"] + sft_loss = info_dict['loss_dict']['loss'] + with torch.no_grad(): + ref_model = ref_model.to(device) + ref_loss_dict = ref_model(batch, device) + reference_chosen_logps = ref_loss_dict["chosen_logps"] + reference_rejected_logps = ref_loss_dict["rejected_logps"] + preference_loss, chosen_reward, reject_reward = dpo_loss( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + dpo_acc = (chosen_reward > reject_reward).float().mean() + info_dict['loss_dict']["loss"] = preference_loss + sft_loss + info_dict['loss_dict']["sft_loss"] = sft_loss + info_dict['loss_dict']["dpo_loss"] = preference_loss + info_dict['loss_dict']["dpo_acc"] = dpo_acc + info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean() + info_dict['loss_dict']["reject_reward"] = reject_reward.mean() + return info_dict + + +def batch_backward(model, scaler, info_dict): + if info_dict["train_engine"] == "deepspeed": + scaled_loss = model.backward(info_dict['loss_dict']['loss']) + else: + scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad'] + if scaler is not None: + scaler.scale(scaled_loss).backward() + else: + scaled_loss.backward() + + info_dict['loss_dict']['loss'] = scaled_loss + return info_dict + + +def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): + grad_norm = 0.0 + if info_dict['train_engine'] == "deepspeed": + info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary() + model.step() + grad_norm = model.get_global_grad_norm() + elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0: + # Use mixed precision training + if scaler is not None: + scaler.unscale_(optimizer) + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + # We don't check grad here since that if the gradient + # has inf/nan values, scaler.step will skip + # optimizer.step(). + if torch.isfinite(grad_norm): + scaler.step(optimizer) + scaler.update() + else: + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["grad_norm"] = grad_norm + return info_dict + + +def log_per_step(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict.get('epoch', 0) + step = info_dict["step"] + batch_idx = info_dict["batch_idx"] + loss_dict = info_dict['loss_dict'] + rank = int(os.environ.get('RANK', 0)) + + # only rank 0 write to tensorboard to avoid multi-process write + if writer is not None: + if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \ + (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0): + for k in ['epoch', 'lr', 'grad_norm']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) + + # TRAIN & CV, Shell log (stdout) + if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0: + log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1) + for name, value in loss_dict.items(): + log_str += '{} {:.6f} '.format(name, value) + if tag == "TRAIN": + log_str += 'lr {:.8f} grad_norm {:.6f}'.format( + info_dict["lr"], info_dict['grad_norm']) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + + +def log_per_save(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict["epoch"] + step = info_dict["step"] + loss_dict = info_dict["loss_dict"] + lr = info_dict['lr'] + rank = int(os.environ.get('RANK', 0)) + logging.info( + 'Epoch {} Step {} CV info lr {} {} rank {}'.format( + epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + + if writer is not None: + for k in ['epoch', 'lr']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) diff --git a/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml b/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml new file mode 100644 index 0000000..d811026 --- /dev/null +++ b/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml @@ -0,0 +1,226 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1986] +__set_seed2: !apply:numpy.random.seed [1986] +__set_seed3: !apply:torch.manual_seed [1986] +__set_seed4: !apply:torch.cuda.manual_seed_all [1986] + +# fixed params +sample_rate: 24000 # 16000 for llm, 24000 for cfm +llm_input_size: 896 +llm_output_size: 896 +spk_embed_dim: 192 +qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM + llm_input_size: !ref + llm_output_size: !ref + speech_token_size: 6561 + length_normalized_loss: True + lsm_weight: 0 + dpo: True + llm: !new:cosyvoice.llm.llm.Qwen2Encoder + pretrain_path: !ref + sampling: !name:cosyvoice.utils.common.ras_sampling + top_p: 0.8 + top_k: 25 + win_size: 10 + tau_r: 0.1 +flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec + input_size: 512 + output_size: 80 + spk_embed_dim: !ref + output_type: 'mel' + vocab_size: 6561 + input_frame_rate: 25 + only_mask_loss: True + token_mel_ratio: 2 + pre_lookahead_len: 3 + encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder + output_size: 512 + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 512 + use_cnn_module: False + macaron_style: False + decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM + in_channels: 240 + n_spks: 1 + spk_emb_dim: 80 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder + in_channels: 320 + out_channels: 80 + causal: True + channels: [256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 12 + num_heads: 8 + act_fn: 'gelu' + +hift: !new:cosyvoice.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 5, 3] + upsample_kernel_sizes: [16, 11, 7] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +# gan related module +mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 80 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: null + center: False +hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan + generator: !ref + discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator + mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator + mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator + mel_spec_transform: [ + !ref + ] + +# processor functions +parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener +get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe + multilingual: True + num_languages: 100 + language: 'en' + task: 'transcribe' +allowed_special: 'all' +tokenize: !name:cosyvoice.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:cosyvoice.dataset.processor.filter + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:cosyvoice.dataset.processor.resample + resample_rate: !ref +truncate: !name:cosyvoice.dataset.processor.truncate + truncate_length: 24576 # must be a multiplier of hop_size +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 80 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 8000 + center: False +compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank + feat_extractor: !ref +compute_f0: !name:cosyvoice.dataset.processor.compute_f0 + sample_rate: !ref + hop_size: 256 +parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding + normalize: True +shuffle: !name:cosyvoice.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:cosyvoice.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:cosyvoice.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g +padding: !name:cosyvoice.dataset.processor.padding + use_spk_embedding: True # change to True during sft + dpo: True + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] +data_pipeline_gan: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + +# llm flow train conf +train_conf: + optim: adam + optim_conf: + lr: 0.00001 # change to 1e-5 during sft + scheduler: warmuplr # change to constantlr during sft + scheduler_conf: + warmup_steps: 25000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: -1 + +# gan train conf +train_conf_gan: + optim: adam + optim_conf: + lr: 0.0002 # use small lr for gan training + scheduler: constantlr + optim_d: adam + optim_conf_d: + lr: 0.0002 # use small lr for gan training + scheduler_d: constantlr + max_epoch: 200 + grad_clip: 5 + accum_grad: 1 # in gan training, accum_grad must be 1 + log_interval: 100 + save_per_step: -1 \ No newline at end of file diff --git a/tools/make_parquet_list_dpo.py b/tools/make_parquet_list_dpo.py new file mode 100755 index 0000000..c6ee6f5 --- /dev/null +++ b/tools/make_parquet_list_dpo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import json +from tqdm import tqdm +import pandas as pd +import multiprocessing +import time +import torch + + +def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file): + start_time = time.time() + data_list = [] + for utt in tqdm(utt_list): + data = open(utt2wav[utt], 'rb').read() + data_list.append(data) + wav_list = [utt2wav[utt] for utt in utt_list] + text_list = [utt2text[utt] for utt in utt_list] + spk_list = [utt2spk[utt] for utt in utt_list] + uttembedding_list = [utt2embedding[utt] for utt in utt_list] + spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list] + speech_token_list = [utt2speech_token[utt] for utt in utt_list] + if utt2reject_speech_token: + reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list] + + # 保存到parquet,utt2parquet_file,spk2parquet_file + df = pd.DataFrame() + df['utt'] = utt_list + df['wav'] = wav_list + df['audio_data'] = data_list + df['text'] = text_list + df['spk'] = spk_list + df['utt_embedding'] = uttembedding_list + df['spk_embedding'] = spkembedding_list + df['speech_token'] = speech_token_list + if utt2reject_speech_token: + df['reject_speech_token'] = reject_speech_token_list + df.to_parquet(parquet_file) + with open(utt2parquet_file, 'w') as f: + json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2) + with open(spk2parquet_file, 'w') as f: + json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2) + logging.info('spend time {}'.format(time.time() - start_time)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--num_utts_per_parquet', + type=int, + default=1000, + help='num utts per parquet') + parser.add_argument('--num_processes', + type=int, + default=1, + help='num processes for make parquets') + parser.add_argument('--src_dir', + type=str) + parser.add_argument('--des_dir', + type=str) + parser.add_argument('--dpo', + action='store_true', + default=False, + help='Use Direct Preference Optimization') + args = parser.parse_args() + + utt2wav, utt2text, utt2spk = {}, {}, {} + with open('{}/wav.scp'.format(args.src_dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2wav[l[0]] = l[1] + with open('{}/text'.format(args.src_dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2text[l[0]] = ' '.join(l[1:]) + with open('{}/utt2spk'.format(args.src_dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2spk[l[0]] = l[1] + utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) + spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) + utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) + if args.dpo: + utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir)) + else: + utt2reject_speech_token = None + utts = list(utt2wav.keys()) + + # Using process pool to speedup + pool = multiprocessing.Pool(processes=args.num_processes) + parquet_list, utt2parquet_list, spk2parquet_list = [], [], [] + for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)): + parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i)) + utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i)) + spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) + parquet_list.append(parquet_file) + utt2parquet_list.append(utt2parquet_file) + spk2parquet_list.append(spk2parquet_file) + pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file)) + pool.close() + pool.join() + + with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ + open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ + open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: + for name in parquet_list: + f1.write(name + '\n') + for name in utt2parquet_list: + f2.write(name + '\n') + for name in spk2parquet_list: + f3.write(name + '\n') From 6d876f573cff4a665ca3010a805ed2589b78e364 Mon Sep 17 00:00:00 2001 From: ShengqiangLi Date: Mon, 17 Feb 2025 12:04:48 +0800 Subject: [PATCH 2/6] feat: Support DPO --- cosyvoice/bin/train_dpo.py | 187 ++++++ cosyvoice/dataset/processor_dpo.py | 443 ++++++++++++++ cosyvoice/llm/llm_dpo.py | 556 ++++++++++++++++++ cosyvoice/utils/executor_dpo.py | 184 ++++++ cosyvoice/utils/losses_dpo.py | 57 ++ cosyvoice/utils/train_utils_dpo.py | 364 ++++++++++++ .../cosyvoice/conf/cosyvoice_dpo.yaml | 226 +++++++ tools/make_parquet_list_dpo.py | 125 ++++ 8 files changed, 2142 insertions(+) create mode 100644 cosyvoice/bin/train_dpo.py create mode 100644 cosyvoice/dataset/processor_dpo.py create mode 100644 cosyvoice/llm/llm_dpo.py create mode 100644 cosyvoice/utils/executor_dpo.py create mode 100644 cosyvoice/utils/losses_dpo.py create mode 100644 cosyvoice/utils/train_utils_dpo.py create mode 100644 examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml create mode 100755 tools/make_parquet_list_dpo.py diff --git a/cosyvoice/bin/train_dpo.py b/cosyvoice/bin/train_dpo.py new file mode 100644 index 0000000..b5b282f --- /dev/null +++ b/cosyvoice/bin/train_dpo.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import argparse +import datetime +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +from copy import deepcopy +import os +import torch +import torch.distributed as dist +import deepspeed + +from hyperpyyaml import load_hyperpyyaml + +from torch.distributed.elastic.multiprocessing.errors import record + +from cosyvoice.utils.executor_dpo import Executor +from cosyvoice.utils.train_utils_dpo import ( + init_distributed, + init_dataset_and_dataloader, + init_optimizer_and_scheduler, + init_summarywriter, save_model, + wrap_cuda_model, check_modify_and_save_config) + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'deepspeed'], + help='Engine for paralleled training') + parser.add_argument('--model', required=True, help='model which will be trained') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--use_amp', + action='store_true', + default=False, + help='Use automatic mixed precision training') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + parser.add_argument('--timeout', + default=60, + type=int, + help='timeout (in seconds) of cosyvoice_join.') + parser.add_argument('--dpo', + action='store_true', + default=False, + help='Use Direct Preference Optimization') + parser.add_argument('--beta', + default=0.01, + type=float, + help='beta of dpo training') + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + # gan train has some special initialization logic + gan = True if args.model == 'hifigan' else False + + override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} + if gan is True: + override_dict.pop('hift') + with open(args.config, 'r') as f: + configs = load_hyperpyyaml(f, overrides=override_dict) + if gan is True: + configs['train_conf'] = configs['train_conf_gan'] + configs['train_conf'].update(vars(args)) + + # Init env for ddp + init_distributed(args) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs, gan) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs) + + # Tensorboard summary + writer = init_summarywriter(args) + + # load checkpoint + model = configs[args.model] + ref_model = None + if args.dpo: + ref_model = deepcopy(model) + start_step, start_epoch = 0, -1 + if args.checkpoint is not None: + if os.path.exists(args.checkpoint): + state_dict = torch.load(args.checkpoint, map_location='cpu') + model.load_state_dict(state_dict, strict=False) + if args.dpo: + ref_model.load_state_dict(state_dict, strict=False) + if 'step' in state_dict: + start_step = state_dict['step'] + if 'epoch' in state_dict: + start_epoch = state_dict['epoch'] + else: + logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint)) + + # Dispatch model from cpu to gpu + model = wrap_cuda_model(args, model) + if args.dpo: + ref_model = wrap_cuda_model(args, ref_model) + + # Get optimizer & scheduler + model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan) + if args.dpo: + ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan) + scheduler.set_step(start_step) + if scheduler_d is not None: + scheduler_d.set_step(start_step) + + # Save init checkpoints + info_dict = deepcopy(configs['train_conf']) + info_dict['step'] = start_step + info_dict['epoch'] = start_epoch + save_model(model, 'init', info_dict) + + # Get executor + executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta) + executor.step = start_step + + # Init scaler, used for pytorch amp mixed precision training + scaler = torch.cuda.amp.GradScaler() if args.use_amp else None + print('start step {} start epoch {}'.format(start_step, start_epoch)) + # Start training loop + for epoch in range(start_epoch + 1, info_dict['max_epoch']): + executor.epoch = epoch + train_dataset.set_epoch(epoch) + dist.barrier() + group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + if gan is True: + executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, + writer, info_dict, scaler, group_join) + else: + executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model) + dist.destroy_process_group(group_join) + + +if __name__ == '__main__': + main() diff --git a/cosyvoice/dataset/processor_dpo.py b/cosyvoice/dataset/processor_dpo.py new file mode 100644 index 0000000..719b474 --- /dev/null +++ b/cosyvoice/dataset/processor_dpo.py @@ -0,0 +1,443 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random + +import pyarrow.parquet as pq +from io import BytesIO +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +import pyworld as pw + + +AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'} + + +def parquet_opener(data, mode='train', tts_data={}): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + url = sample['src'] + try: + for df in pq.ParquetFile(url).iter_batches(batch_size=64): + df = df.to_pandas() + for i in range(len(df)): + if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: + continue + sample.update(dict(df.loc[i])) + if mode == 'train': + # NOTE do not return sample directly, must initialize a new dict + yield {**sample} + else: + for index, text in enumerate(tts_data[df.loc[i, 'utt']]): + yield {**sample, 'tts_index': index, 'tts_text': text} + except Exception as ex: + logging.warning('Failed to open {}, ex info {}'.format(url, ex)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1, + mode='train'): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data'])) + sample['speech'] = sample['speech'].mean(dim=0, keepdim=True) + del sample['audio_data'] + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['text_token']) < token_min_length: + continue + if len(sample['text_token']) > token_max_length: + continue + if len(sample['speech_token']) == 0: + continue + if num_frames != 0: + if len(sample['text_token']) / num_frames < min_output_input_ratio: + continue + if len(sample['text_token']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + sample_rate = sample['sample_rate'] + waveform = sample['speech'] + if sample_rate != resample_rate: + if sample_rate < min_sample_rate: + continue + sample['sample_rate'] = resample_rate + sample['speech'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + max_val = sample['speech'].abs().max() + if max_val > 1: + sample['speech'] /= max_val + yield sample + + +def truncate(data, truncate_length=24576, mode='train'): + """ Truncate data. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + truncate_length: truncate length + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + waveform = sample['speech'] + if waveform.shape[1] > truncate_length: + start = random.randint(0, waveform.shape[1] - truncate_length) + waveform = waveform[:, start: start + truncate_length] + else: + waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1) + sample['speech'] = waveform + yield sample + + +def compute_fbank(data, + feat_extractor, + mode='train'): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + sample['speech_feat'] = mat + yield sample + + +def compute_f0(data, sample_rate, hop_size, mode='train'): + """ Extract f0 + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + frame_period = hop_size * 1000 / sample_rate + for sample in data: + assert 'sample_rate' in sample + assert 'speech' in sample + assert 'utt' in sample + assert 'text_token' in sample + waveform = sample['speech'] + _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) + if sum(_f0 != 0) < 5: # this happens when the algorithm fails + _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio + f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate) + f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1) + sample['pitch_feat'] = f0 + yield sample + + +def parse_embedding(data, normalize, mode='train'): + """ Parse utt_embedding/spk_embedding + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) + sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) + if normalize: + sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) + sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) + yield sample + + +def tokenize(data, get_tokenizer, allowed_special, mode='train'): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + tokenizer = get_tokenizer() + for sample in data: + assert 'text' in sample + sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special) + if mode == 'inference': + sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special) + yield sample + + +def shuffle(data, shuffle_size=10000, mode='train'): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500, mode='train'): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['speech_feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000, mode='train'): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'speech_feat' in sample + assert isinstance(sample['speech_feat'], torch.Tensor) + new_sample_frames = sample['speech_feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'): + """ Wrapper for static/dynamic batch + """ + if mode == 'inference': + return static_batch(data, 1) + else: + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], + dtype=torch.int32) + order = torch.argsort(speech_feat_len, descending=True) + + utts = [sample[i]['utt'] for i in order] + speech = [sample[i]['speech'].squeeze(dim=0) for i in order] + speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32) + speech = pad_sequence(speech, batch_first=True, padding_value=0) + speech_token = [torch.tensor(sample[i]['speech_token']) for i in order] + speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32) + speech_token = pad_sequence(speech_token, + batch_first=True, + padding_value=0) + speech_feat = [sample[i]['speech_feat'] for i in order] + speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32) + speech_feat = pad_sequence(speech_feat, + batch_first=True, + padding_value=0) + text = [sample[i]['text'] for i in order] + text_token = [torch.tensor(sample[i]['text_token']) for i in order] + text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) + text_token = pad_sequence(text_token, batch_first=True, padding_value=0) + utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) + spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) + batch = { + "utts": utts, + "speech": speech, + "speech_len": speech_len, + "speech_token": speech_token, + "speech_token_len": speech_token_len, + "speech_feat": speech_feat, + "speech_feat_len": speech_feat_len, + "text": text, + "text_token": text_token, + "text_token_len": text_token_len, + "utt_embedding": utt_embedding, + "spk_embedding": spk_embedding, + } + if dpo: + reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order] + reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32) + reject_speech_token = pad_sequence(reject_speech_token, + batch_first=True, + padding_value=0) + batch['reject_speech_token'] = reject_speech_token + batch['reject_speech_token_len'] = reject_speech_token_len + if gan is True: + # in gan train, we need pitch_feat + pitch_feat = [sample[i]['pitch_feat'] for i in order] + pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32) + pitch_feat = pad_sequence(pitch_feat, + batch_first=True, + padding_value=0) + batch["pitch_feat"] = pitch_feat + batch["pitch_feat_len"] = pitch_feat_len + else: + # only gan train needs speech, delete it to save memory + del batch["speech"] + del batch["speech_len"] + if mode == 'inference': + tts_text = [sample[i]['tts_text'] for i in order] + tts_index = [sample[i]['tts_index'] for i in order] + tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order] + tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32) + tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1) + batch.update({'tts_text': tts_text, + 'tts_index': tts_index, + 'tts_text_token': tts_text_token, + 'tts_text_token_len': tts_text_token_len}) + if use_spk_embedding is True: + batch["embedding"] = batch["spk_embedding"] + else: + batch["embedding"] = batch["utt_embedding"] + yield batch diff --git a/cosyvoice/llm/llm_dpo.py b/cosyvoice/llm/llm_dpo.py new file mode 100644 index 0000000..6e0dc2d --- /dev/null +++ b/cosyvoice/llm/llm_dpo.py @@ -0,0 +1,556 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Callable, List, Generator +import torch +from torch import nn +import torch.nn.functional as F +from transformers import Qwen2ForCausalLM +from torch.nn.utils.rnn import pad_sequence, unpad_sequence +from cosyvoice.utils.common import IGNORE_ID +from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss +from cosyvoice.utils.common import th_accuracy +from cosyvoice.utils.file_utils import logging +from cosyvoice.utils.mask import make_pad_mask + + +class TransformerLM(torch.nn.Module): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + text_token_size: int, + speech_token_size: int, + text_encoder: torch.nn.Module, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + spk_embed_dim: int = 192, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.speech_token_size = speech_token_size + # 1. build text token inputs related modules + self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size) + self.text_encoder = text_encoder + self.text_encoder_affine_layer = nn.Linear( + self.text_encoder.output_size(), + llm_input_size + ) + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size) + self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size) + + # 4. sampling method + self.sampling = sampling + + def encode( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + ): + encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + return encoder_out, encoder_out_lens + + def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len): + text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) + speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True) + lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) + for i in range(len(text_token))] + lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32) + lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID) + return lm_input, lm_input_len + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + embedding = batch['embedding'].to(device) + + # 1. prepare llm_target + lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 1. encode text_token + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + + # 2. embedding projection + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(1) + + # 3. eos and task_id + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + # 4. encode speech_token + speech_token = self.speech_embedding(speech_token) + + # 5. unpad and pad + lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, + task_id_emb, speech_token, speech_token_len) + + # 6. run lm forward + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID) + return {'loss': loss, 'acc': acc} + + def sampling_ids( + self, + weighted_scores: torch.Tensor, + decoded_tokens: List, + sampling: int, + ignore_eos: bool = True, + ): + num_trials, max_trials = 0, 100 + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + if (not ignore_eos) or (self.speech_token_size not in top_ids): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) + return top_ids + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + if self.fp16 is True: + embedding = embedding.half() + + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.text_embedding(text) + + # 1. encode text + text, text_len = self.encode(text, text_len) + + # 2. encode embedding + if embedding.shape[0] != 0: + embedding = F.normalize(embedding, dim=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(dim=1) + else: + embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + out_tokens = [] + offset = 0 + att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device) + for i in range(max_len): + y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, + att_cache=att_cache, cnn_cache=cnn_cache, + att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), + device=lm_input.device)).to(torch.bool)) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + # force continue decode first token + if i == 0: + logp[:, self.speech_token_size] = -float('inf') + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + +class Qwen2Encoder(torch.nn.Module): + def __init__(self, pretrain_path): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + + def forward_one_step(self, xs, masks, cache=None): + input_masks = masks[:, -1, :] + outs = self.model( + inputs_embeds=xs, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + use_cache=True, + past_key_values=cache, + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache + + +class Qwen2LM(TransformerLM): + def __init__( + self, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + llm: torch.nn.Module, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + mix_ratio: List[int] = [5, 15], + dpo: bool = False, + ): + torch.nn.Module.__init__(self) + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + self.llm_embedding = torch.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3) + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 3, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size) + + # 4. sampling method + self.sampling = sampling + self.mix_ratio = mix_ratio + + # 5. [Optional] set dpo + self.dpo = dpo + + + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + text_token = batch['text_token'].to(device) + text_token_len = batch['text_token_len'].to(device) + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + if self.dpo: + reject_speech_token = batch['reject_speech_token'].to(device) + reject_speech_token_len = batch['reject_speech_token_len'].to(device) + # 1. prepare llm_target + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + if self.dpo: + reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() + + [self.speech_token_size]) for i in range(text_token.size(0))] + target_ids.extend(reject_target_ids) + target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device) + + # 2. speech token projection + speech_emb = self.speech_embedding(speech_token) + if self.dpo: + reject_speech_emb = self.speech_embedding(reject_speech_token) + + # 3. text token projection + text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True) + text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst] + + # 4. prepare llm_input + speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True) + input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0) + for i in range(len(text_emb))] + if self.dpo: + reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True) + reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0) + for i in range(len(text_emb))] + input_emb.extend(reject_input_emb) + input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device) + input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device) + + attention_mask = ~make_pad_mask(input_emb_lengths) + + result = self.llm.model( + inputs_embeds=input_emb, + attention_mask=attention_mask, + return_dict=True + ) + hidden_states = result.hidden_states + logits = self.llm_decoder(hidden_states[-1]) + loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]]) + acc = th_accuracy( + logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3), + target_ids[: speech_token.shape[0]], + ignore_label=IGNORE_ID, + ) + if not self.dpo: + return { + "loss": loss, + "acc": acc, + } + else: + all_logps_sum, all_logps_mean = self.get_batch_logps( + logits, target_ids, attention_mask, text_token_len, average_log_prob=False, ignore_id=IGNORE_ID + ) + chosen_logps = all_logps_sum[: speech_token.shape[0]] + rejected_logps = all_logps_sum[speech_token.shape[0]:] + return { + "loss": loss, + "acc": acc, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps + } + + + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_token_lens, + average_log_prob: bool = False, + ignore_id: int = -1, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + assert average_log_prob == False + assert logits.shape[:-1] == labels.shape + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_masks = attention_mask.clone().bool() + # mask prompts + for mask, text_token_len in zip(loss_masks, prompt_token_lens): + mask[:text_token_len + 1] = False + loss_masks = loss_masks[:, 1:] + labels[loss_masks == False] = 0 + # dummy token; we'll ignore the losses on these tokens later + ignore = labels == ignore_id + labels = labels.masked_fill(ignore, 0) # avoid -1 index + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) # (bs, time,) + logprobs_sums = (per_token_logps * loss_masks).sum(-1) + logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return logprobs_sums, logprobs_means + + + @torch.inference_mode() + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + device = text.device + text = torch.concat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.llm.model.model.embed_tokens(text) + + # 3. concat llm_input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) + + # 4. cal min/max_length + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + + # 5. step by step decode + out_tokens = [] + cache = None + for i in range(max_len): + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + if top_ids > self.speech_token_size: + continue + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + @torch.inference_mode() + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + + device = prompt_text.device + # 1. prepare input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb], dim=1) + + # 2. iterate text + out_tokens = [] + cache = None + # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5 + text_cache = self.llm.model.model.embed_tokens(prompt_text) + next_fill_index = -1 + for this_text in text: + text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) + # prompt_speech_token_emb not empty, try append to lm_input + while prompt_speech_token_emb.size(1) != 0: + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]] + logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) + lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1) + text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:] + else: + logging.info('not enough text token to decode, wait for more') + break + # no prompt_speech_token_emb remain, can decode some speech token + if prompt_speech_token_emb.size(1) == 0: + if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1): + logging.info('get fill token, need to append more text token') + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text = text_cache[:, :self.mix_ratio[0]] + logging.info('append {} text token'.format(lm_input_text.size(1))) + if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2: + lm_input = lm_input_text + else: + lm_input = torch.concat([lm_input, lm_input_text], dim=1) + text_cache = text_cache[:, self.mix_ratio[0]:] + else: + logging.info('not enough text token to decode, wait for more') + continue + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if next_fill_index != -1 and len(out_tokens) == next_fill_index: + top_ids = self.speech_token_size + 2 + next_fill_index += (self.mix_ratio[1] + 1) + else: + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item() + if top_ids == self.speech_token_size + 2: + next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1 + logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index)) + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size + 2: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + # 3. final decode + lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1) + logging.info('no more text token, decode until met eos') + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) diff --git a/cosyvoice/utils/executor_dpo.py b/cosyvoice/utils/executor_dpo.py new file mode 100644 index 0000000..89bb528 --- /dev/null +++ b/cosyvoice/utils/executor_dpo.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from contextlib import nullcontext +import os + +import torch +import torch.distributed as dist + +from cosyvoice.utils.train_utils_dpo import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join +from cosyvoice.utils.losses_dpo import DPOLoss + + +class Executor: + + def __init__(self, gan: bool = False, dpo: bool = False, beta: float = 0.01, label_smoothing: float = 0.0, ipo: bool = False): + self.gan = gan + self.step = 0 + self.epoch = 0 + self.rank = int(os.environ.get('RANK', 0)) + self.device = torch.device('cuda:{}'.format(self.rank)) + self.dpo = dpo + if self.dpo: + self.dpo_loss = DPOLoss(beta, label_smoothing, ipo) + else: + self.dpo_loss = None + + def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + if self.dpo: + assert ref_model is not None + ref_model.eval() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + with context(): + info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model, self.dpo_loss) + info_dict = batch_backward(model, scaler, info_dict) + + info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, ref_model=ref_model, dpo_loss=self.dpo_loss) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=ref_model, dpo_loss=self.dpo_loss) + + def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, + writer, info_dict, scaler, group_join): + ''' Train one epoch + ''' + + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(info_dict['accum_grad'])) + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. + model.train() + model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext + with model_context(): + for batch_idx, batch_dict in enumerate(train_data_loader): + info_dict["tag"] = "TRAIN" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + if cosyvoice_join(group_join, info_dict): + break + + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + + with context(): + batch_dict['turn'] = 'discriminator' + info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_backward(model, scaler, info_dict) + info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict) + optimizer.zero_grad() + log_per_step(writer, info_dict) + with context(): + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, scaler, info_dict) + info_dict = batch_backward(model, scaler, info_dict) + info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) + optimizer_d.zero_grad() + log_per_step(writer, info_dict) + # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save + if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ + (batch_idx + 1) % info_dict["accum_grad"] == 0: + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) + model.train() + if (batch_idx + 1) % info_dict["accum_grad"] == 0: + self.step += 1 + dist.barrier() + self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) + + @torch.inference_mode() + def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=None, dpo_loss=None): + ''' Cross validation on + ''' + logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) + model.eval() + if self.dpo: + assert ref_model is not None + ref_model.eval() + total_num_utts, total_loss_dict = 0, {} # avoid division by 0 + for batch_idx, batch_dict in enumerate(cv_data_loader): + info_dict["tag"] = "CV" + info_dict["step"] = self.step + info_dict["epoch"] = self.epoch + info_dict["batch_idx"] = batch_idx + + num_utts = len(batch_dict["utts"]) + total_num_utts += num_utts + + if self.gan is True: + batch_dict['turn'] = 'generator' + info_dict = batch_forward(model, batch_dict, None, info_dict, ref_model, dpo_loss) + + for k, v in info_dict['loss_dict'].items(): + if k not in total_loss_dict: + total_loss_dict[k] = [] + total_loss_dict[k].append(v.item() * num_utts) + log_per_step(None, info_dict) + for k, v in total_loss_dict.items(): + total_loss_dict[k] = sum(v) / total_num_utts + info_dict['loss_dict'] = total_loss_dict + log_per_save(writer, info_dict) + model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) + save_model(model, model_name, info_dict) diff --git a/cosyvoice/utils/losses_dpo.py b/cosyvoice/utils/losses_dpo.py new file mode 100644 index 0000000..2429fdc --- /dev/null +++ b/cosyvoice/utils/losses_dpo.py @@ -0,0 +1,57 @@ +import torch +import torch.nn.functional as F +from typing import Tuple + + +def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): + loss = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + m_DG = torch.median((dr - dg)) + L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) + loss += tau - F.relu(tau - L_rel) + return loss + + +def mel_loss(real_speech, generated_speech, mel_transforms): + loss = 0 + for transform in mel_transforms: + mel_r = transform(real_speech) + mel_g = transform(generated_speech) + loss += F.l1_loss(mel_g, mel_r) + return loss + + +class DPOLoss(torch.nn.Module): + """ + DPO Loss + """ + + def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: + super().__init__() + self.beta = beta + self.label_smoothing = label_smoothing + self.ipo = ipo + + def forward( + self, + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + if self.ipo: + losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf + else: + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + loss = losses.mean() + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return loss, chosen_rewards, rejected_rewards diff --git a/cosyvoice/utils/train_utils_dpo.py b/cosyvoice/utils/train_utils_dpo.py new file mode 100644 index 0000000..fa1529e --- /dev/null +++ b/cosyvoice/utils/train_utils_dpo.py @@ -0,0 +1,364 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Horizon Inc. (authors: Xingchen Song) +# 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import torch +import json +import re +import datetime +import yaml + +import deepspeed +import torch.optim as optim +import torch.distributed as dist + +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ + +from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live + +from cosyvoice.dataset.dataset import Dataset +from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR + + +def init_distributed(args): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + logging.info('training on multiple gpus, this gpu {}'.format(local_rank) + + ', rank {}, world_size {}'.format(rank, world_size)) + if args.train_engine == 'torch_ddp': + torch.cuda.set_device(local_rank) + dist.init_process_group(args.dist_backend) + else: + deepspeed.init_distributed(dist_backend=args.dist_backend) + return world_size, local_rank, rank + + +def init_dataset_and_dataloader(args, configs, gan): + data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline'] + train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True) + cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False) + + # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + return train_dataset, cv_dataset, train_data_loader, cv_data_loader + + +def check_modify_and_save_config(args, configs): + if args.train_engine == "torch_ddp": + configs['train_conf']["dtype"] = 'fp32' + else: + with open(args.deepspeed_config, 'r') as fin: + ds_configs = json.load(fin) + if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]: + configs['train_conf']["dtype"] = "fp16" + elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]: + configs['train_conf']["dtype"] = "bf16" + else: + configs['train_conf']["dtype"] = "fp32" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + # if use deepspeed, override ddp config + configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * + configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"]) + configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"] + configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"] + configs['train_conf']['log_interval'] = ds_configs["steps_per_print"] + return configs + + +def wrap_cuda_model(args, model): + local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + if args.train_engine == "torch_ddp": # native pytorch ddp + assert (torch.cuda.is_available()) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) + else: + if int(os.environ.get('RANK', 0)) == 0: + logging.info("Estimating model states memory needs (zero2)...") + estimate_zero2_model_states_mem_needs_all_live( + model, + num_gpus_per_node=local_world_size, + num_nodes=world_size // local_world_size) + return model + + +def init_optimizer_and_scheduler(args, configs, model, gan): + if gan is False: + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + # use deepspeed optimizer for speedup + if args.train_engine == "deepspeed": + def scheduler(opt): + return scheduler_type(opt, **configs['train_conf']['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, + model=model, + optimizer=None, + lr_scheduler=scheduler, + model_parameters=model.parameters()) + + optimizer_d, scheduler_d = None, None + + else: + # currently we wrap generator and discriminator in one model, so we cannot use deepspeed + if configs['train_conf']['optim'] == 'adam': + optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim'] == 'adamw': + optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler = ConstantLR(optimizer) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + + if configs['train_conf']['optim_d'] == 'adam': + optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + elif configs['train_conf']['optim_d'] == 'adamw': + optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['train_conf']) + + if configs['train_conf']['scheduler_d'] == 'warmuplr': + scheduler_type = WarmupLR + scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing + scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf']) + elif configs['train_conf']['scheduler'] == 'constantlr': + scheduler_type = ConstantLR + scheduler_d = ConstantLR(optimizer_d) + else: + raise ValueError("unknown scheduler: " + configs['train_conf']) + return model, optimizer, scheduler, optimizer_d, scheduler_d + + +def init_summarywriter(args): + writer = None + if int(os.environ.get('RANK', 0)) == 0: + os.makedirs(args.model_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + return writer + + +def save_model(model, model_name, info_dict): + rank = int(os.environ.get('RANK', 0)) + model_dir = info_dict["model_dir"] + save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name)) + + if info_dict["train_engine"] == "torch_ddp": + if rank == 0: + torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path) + else: + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag=model_name, + client_state=info_dict) + if rank == 0: + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path)) + + +def cosyvoice_join(group_join, info_dict): + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + rank = int(os.environ.get('RANK', 0)) + + if info_dict["batch_idx"] != 0: + # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr + try: + dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + return False + except RuntimeError as e: + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". + format(world_size, rank, local_rank)) + return True + else: + return False + + +def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None): + device = int(os.environ.get('LOCAL_RANK', 0)) + + dtype = info_dict["dtype"] + if dtype == "fp16": + dtype = torch.float16 + elif dtype == "bf16": + dtype = torch.bfloat16 + else: # fp32 + dtype = torch.float32 + + if info_dict['train_engine'] == 'torch_ddp': + autocast = torch.cuda.amp.autocast(enabled=scaler is not None) + else: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False) + + with autocast: + info_dict['loss_dict'] = model(batch, device) + if ref_model and dpo_loss: + chosen_logps = info_dict['loss_dict']["chosen_logps"] + rejected_logps = info_dict['loss_dict']["rejected_logps"] + sft_loss = info_dict['loss_dict']['loss'] + with torch.no_grad(): + ref_model = ref_model.to(device) + ref_loss_dict = ref_model(batch, device) + reference_chosen_logps = ref_loss_dict["chosen_logps"] + reference_rejected_logps = ref_loss_dict["rejected_logps"] + preference_loss, chosen_reward, reject_reward = dpo_loss( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + dpo_acc = (chosen_reward > reject_reward).float().mean() + info_dict['loss_dict']["loss"] = preference_loss + sft_loss + info_dict['loss_dict']["sft_loss"] = sft_loss + info_dict['loss_dict']["dpo_loss"] = preference_loss + info_dict['loss_dict']["dpo_acc"] = dpo_acc + info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean() + info_dict['loss_dict']["reject_reward"] = reject_reward.mean() + return info_dict + + +def batch_backward(model, scaler, info_dict): + if info_dict["train_engine"] == "deepspeed": + scaled_loss = model.backward(info_dict['loss_dict']['loss']) + else: + scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad'] + if scaler is not None: + scaler.scale(scaled_loss).backward() + else: + scaled_loss.backward() + + info_dict['loss_dict']['loss'] = scaled_loss + return info_dict + + +def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict): + grad_norm = 0.0 + if info_dict['train_engine'] == "deepspeed": + info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary() + model.step() + grad_norm = model.get_global_grad_norm() + elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0: + # Use mixed precision training + if scaler is not None: + scaler.unscale_(optimizer) + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + # We don't check grad here since that if the gradient + # has inf/nan values, scaler.step will skip + # optimizer.step(). + if torch.isfinite(grad_norm): + scaler.step(optimizer) + scaler.update() + else: + grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + info_dict["lr"] = optimizer.param_groups[0]['lr'] + info_dict["grad_norm"] = grad_norm + return info_dict + + +def log_per_step(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict.get('epoch', 0) + step = info_dict["step"] + batch_idx = info_dict["batch_idx"] + loss_dict = info_dict['loss_dict'] + rank = int(os.environ.get('RANK', 0)) + + # only rank 0 write to tensorboard to avoid multi-process write + if writer is not None: + if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \ + (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0): + for k in ['epoch', 'lr', 'grad_norm']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) + + # TRAIN & CV, Shell log (stdout) + if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0: + log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1) + for name, value in loss_dict.items(): + log_str += '{} {:.6f} '.format(name, value) + if tag == "TRAIN": + log_str += 'lr {:.8f} grad_norm {:.6f}'.format( + info_dict["lr"], info_dict['grad_norm']) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + + +def log_per_save(writer, info_dict): + tag = info_dict["tag"] + epoch = info_dict["epoch"] + step = info_dict["step"] + loss_dict = info_dict["loss_dict"] + lr = info_dict['lr'] + rank = int(os.environ.get('RANK', 0)) + logging.info( + 'Epoch {} Step {} CV info lr {} {} rank {}'.format( + epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]))) + + if writer is not None: + for k in ['epoch', 'lr']: + writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1) + for k, v in loss_dict.items(): + writer.add_scalar('{}/{}'.format(tag, k), v, step + 1) diff --git a/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml b/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml new file mode 100644 index 0000000..d811026 --- /dev/null +++ b/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml @@ -0,0 +1,226 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1986] +__set_seed2: !apply:numpy.random.seed [1986] +__set_seed3: !apply:torch.manual_seed [1986] +__set_seed4: !apply:torch.cuda.manual_seed_all [1986] + +# fixed params +sample_rate: 24000 # 16000 for llm, 24000 for cfm +llm_input_size: 896 +llm_output_size: 896 +spk_embed_dim: 192 +qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN' + +# model params +# for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. +# for system/third_party class/function, we do not require this. +llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM + llm_input_size: !ref + llm_output_size: !ref + speech_token_size: 6561 + length_normalized_loss: True + lsm_weight: 0 + dpo: True + llm: !new:cosyvoice.llm.llm.Qwen2Encoder + pretrain_path: !ref + sampling: !name:cosyvoice.utils.common.ras_sampling + top_p: 0.8 + top_k: 25 + win_size: 10 + tau_r: 0.1 +flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec + input_size: 512 + output_size: 80 + spk_embed_dim: !ref + output_type: 'mel' + vocab_size: 6561 + input_frame_rate: 25 + only_mask_loss: True + token_mel_ratio: 2 + pre_lookahead_len: 3 + encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder + output_size: 512 + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + normalize_before: True + input_layer: 'linear' + pos_enc_layer_type: 'rel_pos_espnet' + selfattention_layer_type: 'rel_selfattn' + input_size: 512 + use_cnn_module: False + macaron_style: False + decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM + in_channels: 240 + n_spks: 1 + spk_emb_dim: 80 + cfm_params: !new:omegaconf.DictConfig + content: + sigma_min: 1e-06 + solver: 'euler' + t_scheduler: 'cosine' + training_cfg_rate: 0.2 + inference_cfg_rate: 0.7 + reg_loss_type: 'l1' + estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder + in_channels: 320 + out_channels: 80 + causal: True + channels: [256] + dropout: 0.0 + attention_head_dim: 64 + n_blocks: 4 + num_mid_blocks: 12 + num_heads: 8 + act_fn: 'gelu' + +hift: !new:cosyvoice.hifigan.generator.HiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 5, 3] + upsample_kernel_sizes: [16, 11, 7] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 + +# gan related module +mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 80 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: null + center: False +hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan + generator: !ref + discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator + mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator + mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator + mel_spec_transform: [ + !ref + ] + +# processor functions +parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener +get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe + multilingual: True + num_languages: 100 + language: 'en' + task: 'transcribe' +allowed_special: 'all' +tokenize: !name:cosyvoice.dataset.processor.tokenize + get_tokenizer: !ref + allowed_special: !ref +filter: !name:cosyvoice.dataset.processor.filter + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 +resample: !name:cosyvoice.dataset.processor.resample + resample_rate: !ref +truncate: !name:cosyvoice.dataset.processor.truncate + truncate_length: 24576 # must be a multiplier of hop_size +feat_extractor: !name:matcha.utils.audio.mel_spectrogram + n_fft: 1024 + num_mels: 80 + sampling_rate: !ref + hop_size: 256 + win_size: 1024 + fmin: 0 + fmax: 8000 + center: False +compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank + feat_extractor: !ref +compute_f0: !name:cosyvoice.dataset.processor.compute_f0 + sample_rate: !ref + hop_size: 256 +parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding + normalize: True +shuffle: !name:cosyvoice.dataset.processor.shuffle + shuffle_size: 1000 +sort: !name:cosyvoice.dataset.processor.sort + sort_size: 500 # sort_size should be less than shuffle_size +batch: !name:cosyvoice.dataset.processor.batch + batch_type: 'dynamic' + max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g +padding: !name:cosyvoice.dataset.processor.padding + use_spk_embedding: True # change to True during sft + dpo: True + +# dataset processor pipeline +data_pipeline: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] +data_pipeline_gan: [ + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , + !ref , +] + +# llm flow train conf +train_conf: + optim: adam + optim_conf: + lr: 0.00001 # change to 1e-5 during sft + scheduler: warmuplr # change to constantlr during sft + scheduler_conf: + warmup_steps: 25000 + max_epoch: 200 + grad_clip: 5 + accum_grad: 2 + log_interval: 100 + save_per_step: -1 + +# gan train conf +train_conf_gan: + optim: adam + optim_conf: + lr: 0.0002 # use small lr for gan training + scheduler: constantlr + optim_d: adam + optim_conf_d: + lr: 0.0002 # use small lr for gan training + scheduler_d: constantlr + max_epoch: 200 + grad_clip: 5 + accum_grad: 1 # in gan training, accum_grad must be 1 + log_interval: 100 + save_per_step: -1 \ No newline at end of file diff --git a/tools/make_parquet_list_dpo.py b/tools/make_parquet_list_dpo.py new file mode 100755 index 0000000..c6ee6f5 --- /dev/null +++ b/tools/make_parquet_list_dpo.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import json +from tqdm import tqdm +import pandas as pd +import multiprocessing +import time +import torch + + +def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file): + start_time = time.time() + data_list = [] + for utt in tqdm(utt_list): + data = open(utt2wav[utt], 'rb').read() + data_list.append(data) + wav_list = [utt2wav[utt] for utt in utt_list] + text_list = [utt2text[utt] for utt in utt_list] + spk_list = [utt2spk[utt] for utt in utt_list] + uttembedding_list = [utt2embedding[utt] for utt in utt_list] + spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list] + speech_token_list = [utt2speech_token[utt] for utt in utt_list] + if utt2reject_speech_token: + reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list] + + # 保存到parquet,utt2parquet_file,spk2parquet_file + df = pd.DataFrame() + df['utt'] = utt_list + df['wav'] = wav_list + df['audio_data'] = data_list + df['text'] = text_list + df['spk'] = spk_list + df['utt_embedding'] = uttembedding_list + df['spk_embedding'] = spkembedding_list + df['speech_token'] = speech_token_list + if utt2reject_speech_token: + df['reject_speech_token'] = reject_speech_token_list + df.to_parquet(parquet_file) + with open(utt2parquet_file, 'w') as f: + json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2) + with open(spk2parquet_file, 'w') as f: + json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2) + logging.info('spend time {}'.format(time.time() - start_time)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--num_utts_per_parquet', + type=int, + default=1000, + help='num utts per parquet') + parser.add_argument('--num_processes', + type=int, + default=1, + help='num processes for make parquets') + parser.add_argument('--src_dir', + type=str) + parser.add_argument('--des_dir', + type=str) + parser.add_argument('--dpo', + action='store_true', + default=False, + help='Use Direct Preference Optimization') + args = parser.parse_args() + + utt2wav, utt2text, utt2spk = {}, {}, {} + with open('{}/wav.scp'.format(args.src_dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2wav[l[0]] = l[1] + with open('{}/text'.format(args.src_dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2text[l[0]] = ' '.join(l[1:]) + with open('{}/utt2spk'.format(args.src_dir)) as f: + for l in f: + l = l.replace('\n', '').split() + utt2spk[l[0]] = l[1] + utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) + spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) + utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) + if args.dpo: + utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir)) + else: + utt2reject_speech_token = None + utts = list(utt2wav.keys()) + + # Using process pool to speedup + pool = multiprocessing.Pool(processes=args.num_processes) + parquet_list, utt2parquet_list, spk2parquet_list = [], [], [] + for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)): + parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i)) + utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i)) + spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) + parquet_list.append(parquet_file) + utt2parquet_list.append(utt2parquet_file) + spk2parquet_list.append(spk2parquet_file) + pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file)) + pool.close() + pool.join() + + with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ + open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ + open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: + for name in parquet_list: + f1.write(name + '\n') + for name in utt2parquet_list: + f2.write(name + '\n') + for name in spk2parquet_list: + f3.write(name + '\n') From a442317d171fc0ffa0f381149919d7467ec0f20b Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Wed, 16 Apr 2025 17:57:02 +0800 Subject: [PATCH 3/6] add flow trt wrapper --- cosyvoice/cli/cosyvoice.py | 4 +- cosyvoice/cli/model.py | 33 ++- cosyvoice/flow/flow_matching.py | 98 ++++---- cosyvoice/llm/llm_vllm.py | 212 +++++++++++++++++ cosyvoice/llm/vllm_use_cosyvoice2_model.py | 263 +++++++++++++++++++++ cosyvoice/utils/common.py | 19 ++ cosyvoice/utils/file_utils.py | 2 +- requirements_vllm.txt | 40 ++++ 8 files changed, 615 insertions(+), 56 deletions(-) create mode 100644 cosyvoice/llm/llm_vllm.py create mode 100644 cosyvoice/llm/vllm_use_cosyvoice2_model.py create mode 100644 requirements_vllm.txt diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index efebe4d..1f17620 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -137,7 +137,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 @@ -159,7 +159,7 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache) + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 20ddad0..104c217 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) +# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,7 @@ # limitations under the License. import os from typing import Generator +import queue import torch import numpy as np import threading @@ -22,6 +24,7 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.file_utils import convert_onnx_to_trt +from cosyvoice.utils.common import TrtContextWrapper class CosyVoiceModel: @@ -89,9 +92,12 @@ class CosyVoiceModel: del self.flow.decoder.estimator import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: - self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) - assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() + estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) + if isinstance(self, CosyVoice2Model): + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) + else: + self.flow.decoder.estimator = estimator_engine.create_execution_context() def get_trt_kwargs(self): min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] @@ -231,7 +237,9 @@ class CosyVoiceModel: self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.current_stream().synchronize() class CosyVoice2Model(CosyVoiceModel): @@ -241,13 +249,15 @@ class CosyVoice2Model(CosyVoiceModel): flow: torch.nn.Module, hift: torch.nn.Module, fp16: bool = False, - use_flow_cache: bool = False): + use_flow_cache: bool = False, + trt_concurrent: int = 1): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 self.use_flow_cache = use_flow_cache + self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() @@ -261,12 +271,16 @@ class CosyVoice2Model(CosyVoiceModel): self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + for _ in range(trt_concurrent): + self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} self.flow_cache_dict = {} self.hift_cache_dict = {} + self.trt_context_dict = {} def init_flow_cache(self): encoder_cache = {'offset': 0, @@ -304,7 +318,7 @@ class CosyVoice2Model(CosyVoiceModel): return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16): + with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]: tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), @@ -349,6 +363,7 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None self.flow_cache_dict[this_uuid] = self.init_flow_cache() + self.trt_context_dict[this_uuid] = self.trt_context_pool.get() if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) else: @@ -405,4 +420,8 @@ class CosyVoice2Model(CosyVoiceModel): self.llm_end_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() + self.trt_context_pool.put(self.trt_context_dict[this_uuid]) + self.trt_context_dict.pop(this_uuid) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.current_stream().synchronize() diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index ccf235b..47e6961 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -1,4 +1,5 @@ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -290,50 +291,55 @@ class CausalConditionalCFM(ConditionalCFM): x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache) cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7) else: - with self.lock: - self.estimator.set_input_shape('x', (2, 80, x.size(2))) - self.estimator.set_input_shape('mask', (2, 1, x.size(2))) - self.estimator.set_input_shape('mu', (2, 80, x.size(2))) - self.estimator.set_input_shape('t', (2,)) - self.estimator.set_input_shape('spks', (2, 80)) - self.estimator.set_input_shape('cond', (2, 80, x.size(2))) - self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape) - self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape) - self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape) - self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape) - self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape) - self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape) - self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape) - # run trt engine - down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) - mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x) - up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) - assert self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - cache['down_blocks_conv_cache'].contiguous().data_ptr(), - cache['down_blocks_kv_cache'].contiguous().data_ptr(), - cache['mid_blocks_conv_cache'].contiguous().data_ptr(), - cache['mid_blocks_kv_cache'].contiguous().data_ptr(), - cache['up_blocks_conv_cache'].contiguous().data_ptr(), - cache['up_blocks_kv_cache'].contiguous().data_ptr(), - cache['final_blocks_conv_cache'].contiguous().data_ptr(), - x.data_ptr(), - cache['down_blocks_conv_cache'].data_ptr(), - down_blocks_kv_cache_out.data_ptr(), - cache['mid_blocks_conv_cache'].data_ptr(), - mid_blocks_kv_cache_out.data_ptr(), - cache['up_blocks_conv_cache'].data_ptr(), - up_blocks_kv_cache_out.data_ptr(), - cache['final_blocks_conv_cache'].data_ptr()]) is True - cache = (cache['down_blocks_conv_cache'], - down_blocks_kv_cache_out, - cache['mid_blocks_conv_cache'], - mid_blocks_kv_cache_out, - cache['up_blocks_conv_cache'], - up_blocks_kv_cache_out, - cache['final_blocks_conv_cache']) + estimator, trt_engine = self.estimator.acquire_estimator() + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape) + estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape) + estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape) + estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape) + estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape) + estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape) + estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape) + down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) + mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x) + up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + cache['down_blocks_conv_cache'].contiguous().data_ptr(), + cache['down_blocks_kv_cache'].contiguous().data_ptr(), + cache['mid_blocks_conv_cache'].contiguous().data_ptr(), + cache['mid_blocks_kv_cache'].contiguous().data_ptr(), + cache['up_blocks_conv_cache'].contiguous().data_ptr(), + cache['up_blocks_kv_cache'].contiguous().data_ptr(), + cache['final_blocks_conv_cache'].contiguous().data_ptr(), + x.data_ptr(), + cache['down_blocks_conv_cache'].data_ptr(), + down_blocks_kv_cache_out.data_ptr(), + cache['mid_blocks_conv_cache'].data_ptr(), + mid_blocks_kv_cache_out.data_ptr(), + cache['up_blocks_conv_cache'].data_ptr(), + up_blocks_kv_cache_out.data_ptr(), + cache['final_blocks_conv_cache'].data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator) + cache = (cache['down_blocks_conv_cache'], + down_blocks_kv_cache_out, + cache['mid_blocks_conv_cache'], + mid_blocks_kv_cache_out, + cache['up_blocks_conv_cache'], + up_blocks_kv_cache_out, + cache['final_blocks_conv_cache']) return x, cache diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py new file mode 100644 index 0000000..a864a04 --- /dev/null +++ b/cosyvoice/llm/llm_vllm.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +import queue +import asyncio +import threading +from typing import List, Generator, AsyncGenerator +import torch +from cosyvoice.utils.file_utils import logging +from cosyvoice.llm.llm import Qwen2LM + +# 启用vllm V1版本 +import os +os.environ["VLLM_USE_V1"] = '1' +from vllm import ModelRegistry +from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput +from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.sampling_params import SamplingParams + +from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM +ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM) + +# EngineArgs +ENGINE_ARGS = { + "block_size": 16, + "swap_space": 0, + # "enforce_eager": True, + "gpu_memory_utilization": 0.4, + "max_num_batched_tokens": 1024, + "max_model_len": 1024, + "max_num_seqs": 256, + "disable_log_requests": True, + "disable_log_stats": True, + "dtype": "float16" +} + +from vllm.sampling_params import RequestOutputKind +# SamplingParams +SAMPLING_PARAMS = { + "temperature": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token + "top_p": 1, # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token + "top_k": 25, + # "min_tokens": 80, # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动 + # "presence_penalty": 1.0, # 不支持设置 + # "frequency_penalty": 0.0, # 不支持设置 + "max_tokens": 1024, + "detokenize": False, # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算 + "ignore_eos": False, + "output_kind": RequestOutputKind.DELTA # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码 +} + +def tensor_to_list(tensor: torch.tensor): + return tensor.view(-1).cpu().numpy().tolist() + +class VllmQwen2LM(Qwen2LM): + def __init__( + self, + model_dir, + mix_ratio: List[int] = [5, 15], + ): + self.fp16 = False + self.half = lambda: None + self.mix_ratio = mix_ratio + # --------------------------------------------- + # vllm engine 的参数配置 + engine_args = AsyncEngineArgs( + model=model_dir, + **ENGINE_ARGS, + ) + self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args) + + self.speech_token_size = 6564 # 6561 + 3 + self.llm_token_size = 151936 # llm vocab_size + self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1 + self.task_token_id = self.sos_eos_token_id + 1 + self.zero_token_id = self.task_token_id + 1 + + # vllm 的推理任务需要在一个固定的事件循环中,因此启动一个后台线程运行转用于推理任务 + self.loop = asyncio.new_event_loop() + self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True) + self.loop_thread.start() + + def _run_event_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens): + sampling_params = SamplingParams(**SAMPLING_PARAMS) + sampling_params.stop_token_ids = stop_token_ids or [6561] + if max_tokens: + sampling_params.max_tokens = max_tokens + async for output in self.llm_engine.generate( + { + "prompt_token_ids": prompt_token_ids, + }, + sampling_params=sampling_params, + request_id=request_id or f"{time.time()}", + ): + out_queue.put((output.outputs[0], output.finished)) + + def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None): + out_queue = queue.Queue() + asyncio.run_coroutine_threadsafe( + self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop + ) + # 接收 out_queue 返回的结果 + finished = False + while not finished: + (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get() + yield output + + def inference( + self, + text: torch.Tensor, + text_len: torch.Tensor, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor|int, None, None]: + prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) + prompt_speech_token = tensor_to_list(prompt_speech_token) + + text = tensor_to_list(text + torch.tensor(6564)) + prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \ + [self.task_token_id] + prompt_speech_token + max_tokens = len(text) * 20 + for output in self.llm_inference( + prompt_token_ids, + stop_token_ids=[6561], + max_tokens=max_tokens, + ): + if output.token_ids[-1] == 6561: + need_add_tokens = output.token_ids[:-1] + else: + need_add_tokens = output.token_ids + for token in need_add_tokens: + yield token + + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + prompt_text = tensor_to_list(prompt_text + torch.tensor(6564)) + prompt_speech_token = tensor_to_list(prompt_speech_token) + + last_tokens = [] + prompt_token_ids = [self.sos_eos_token_id] + text_tokens_cache = prompt_text + for this_text in text: + this_text = tensor_to_list(this_text + torch.tensor(6564)) + # text need tokens + assert isinstance(this_text, list), "text need token ids List[int]." + text_tokens_cache += this_text + while len(prompt_speech_token) != 0: + if len(text_tokens_cache) >= self.mix_ratio[0]: + text_input_token = text_tokens_cache[:self.mix_ratio[0]] + speech_input_token = prompt_speech_token[:self.mix_ratio[1]] + prompt_token_ids += text_input_token + speech_input_token + # reset the last cache + text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] + prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:] + else: + break + if len(prompt_speech_token) == 0: + if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1: + if len(text_tokens_cache) >= self.mix_ratio[0]: + text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]] + prompt_token_ids += text_tokens_temp + text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:] + else: + continue + for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]): + last_tokens = output.token_ids + if last_tokens[-1] == 6563: + need_add_tokens = last_tokens[:-1] + else: + need_add_tokens = last_tokens + for token in need_add_tokens: + yield token + prompt_token_ids.extend(need_add_tokens) + prompt_token_ids += text_tokens_cache + [self.task_token_id] + for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]): + if output.token_ids[-1] == 6561: + need_add_tokens = output.token_ids[:-1] + else: + need_add_tokens = output.token_ids + for token in need_add_tokens: + yield token diff --git a/cosyvoice/llm/vllm_use_cosyvoice2_model.py b/cosyvoice/llm/vllm_use_cosyvoice2_model.py new file mode 100644 index 0000000..6e36ef3 --- /dev/null +++ b/cosyvoice/llm/vllm_use_cosyvoice2_model.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any +from typing_extensions import TypeVar + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import T +from vllm.model_executor.models.qwen2 import Qwen2Model + +from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings + +logger = init_logger(__name__) + +IGNORE_ID = -1 + + +class CosyVoice2Model(nn.Module): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.llm_input_size = 896 + self.llm_output_size = 896 + + self.speech_token_size = 6561+3 + self.llm_token_size = config.vocab_size + + # 2. build speech token language model related modules + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + + + self.allow_patterns_overrides = ["llm.*"] + self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size) + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size) + self.llm_decoder = ParallelLMHead(self.speech_token_size, + self.llm_output_size, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "llm_decoder")) + self.logits_processor = LogitsProcessor(self.speech_token_size) + + # length_normalized_loss: bool = True, + # lsm_weight: float = 0.0, + # self.criterion_ce = LabelSmoothingLoss( + # size=self.speech_token_size, + # padding_idx=IGNORE_ID, + # smoothing=lsm_weight, + # normalize_length=length_normalized_loss, + # ) + + # 3. [Optional] build speech token related modules + self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size) + + # 4. sampling method + ## use vllm sampling method + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.mix_ratio: List[int] = [5, 15] + + # 定义特殊token常量 + self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32) + self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32) # 163840 + 6564 = 170404 + self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32) # 170405 + self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32) + + self.zero_embed_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size), + dtype=self.llm_embedding.weight.dtype, + device=self.llm_embedding.weight.device + ) + self.inputs_embed_buffer = torch.zeros( + (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size), + dtype=self.llm_embedding.weight.dtype, + device=self.llm_embedding.weight.device, + ) + + def get_sos_eos_emb(self): + return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + + def get_task_id_emb(self): + return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> torch.Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ + # 创建掩码,标记哪些 token_id 属于音频 Token + mask = input_ids < self.speech_token_size + + # 获取 input_ids 的原始形状 + input_shape = input_ids.shape + # 展平 input_ids 和掩码以便统一处理 + flat_input_ids = input_ids.view(-1) + flat_mask = mask.view(-1) + + inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]] + inputs_embeds.zero_() + + # Process speech tokens + if flat_mask.any(): + speech_token_ids = flat_input_ids[flat_mask] + inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids) + + # 处理大于 delta 的 token_id + if (~flat_mask).any(): + llm_token_ids = flat_input_ids[~flat_mask] + llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask]) + + sos_eos_mask = llm_token_ids == self.sos_eos_token_id + task_mask = llm_token_ids == self.task_token_id + zero_mask = llm_token_ids == self.zero_token_id + normal_mask = ~(sos_eos_mask | task_mask | zero_mask) + + # 分层处理逻辑 + # 第一优先级:SOS/EOS标记 + if sos_eos_mask.any(): + llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0) + + # 第二优先级:任务标记 + if task_mask.any(): + llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0) + + # 第二优先级:空音频标记 + if zero_mask.any(): + llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])] + + # 常规LLM token + if normal_mask.any(): + original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta + # print('original_ids: ',original_ids) + llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids) + + inputs_embeds[~flat_mask] = llm_embeds + + inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size) + + # 合并多模态嵌入(如果有) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.audio_token_index + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings( + input_ids, + attn_metadata=attn_metadata, + ) + return self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.llm_decoder, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + @staticmethod + def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]: + for name, param in weights: + # 处理Qwen2Model核心参数 + if name.startswith("llm."): + if name.startswith("llm.model.model."): + name = name.replace("llm.model.model.", "model.") + else: + continue + # print('weights name: ', name) + yield name, param + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.convert_weights(weights) + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 3e61a8c..088ca69 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -1,5 +1,6 @@ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) # 2024 Alibaba Inc (authors: Xiang Lyu) +# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +16,7 @@ # Modified from ESPnet(https://github.com/espnet/espnet) """Unility functions for Transformer.""" +import queue import random from typing import List @@ -164,3 +166,20 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min mask = (1.0 - mask) * -1.0e+10 return mask + + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1): + self.trt_context_pool = queue.Queue() + self.trt_engine = trt_engine + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) + self.trt_context_pool.put(trt_context) + assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context): + self.trt_context_pool.put(context) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index f0a450c..80eafaf 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() diff --git a/requirements_vllm.txt b/requirements_vllm.txt new file mode 100644 index 0000000..f3dcb25 --- /dev/null +++ b/requirements_vllm.txt @@ -0,0 +1,40 @@ +vllm==0.7.3 +pydantic==2.10.6 +torch==2.5.1 +torchaudio==2.5.1 + +conformer==0.3.2 + +diffusers==0.32.2 +gdown==5.1.0 +grpcio==1.57.0 +grpcio-tools==1.57.0 +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 + +lightning==2.5.0.post0 +matplotlib==3.7.5 +modelscope==1.15.0 + +networkx==3.4.2 +omegaconf==2.3.0 +onnx==1.17.0 + +onnxruntime-gpu==1.19.0; sys_platform == 'linux' + +#openai-whisper==20231117 +openai-whisper==20240930 +protobuf==4.25 +pyworld==0.3.4 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +wget==3.2 +WeTextProcessing==1.0.3 + +# trt use +tensorrt-cu12==10.0.1 +tensorrt-cu12-bindings==10.0.1 +tensorrt-cu12-libs==10.0.1 \ No newline at end of file From 65ad448714c60fc6d4133ea3a0439e2ed5320b43 Mon Sep 17 00:00:00 2001 From: burkliu Date: Thu, 24 Apr 2025 17:14:49 +0800 Subject: [PATCH 4/6] [debug] a better solution for mismatch of speech feat len and speech token len, refer to https://github.com/FunAudioLLM/CosyVoice/issues/1051 --- cosyvoice/dataset/processor.py | 12 ++++++++++-- cosyvoice/flow/flow.py | 2 -- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 8424ada..8ac82a1 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'): def compute_fbank(data, feat_extractor, + token_mel_ratio=2, mode='train'): """ Extract fbank @@ -174,8 +175,15 @@ def compute_fbank(data, assert 'utt' in sample assert 'text_token' in sample waveform = sample['speech'] - mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) - sample['speech_feat'] = mat + feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) + + # padding with replicate mode (align to speech_token len * token_mel_ratio) + pad_len = sample["speech_token"].shape[0] * token_mel_ratio - feat.shape[0] + if pad_len > 0: + feat_to_pad = feat[-1:].repeat((pad_len, 1)) + feat = torch.cat([feat, feat_to_pad], dim=0) + + sample['speech_feat'] = feat yield sample diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 9c642ee..e1cf429 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module): mask = (~make_pad_mask(feat_len)).to(h) # NOTE this is unnecessary, feat/h already same shape - feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) loss, _ = self.decoder.compute_loss( feat.transpose(1, 2).contiguous(), mask.unsqueeze(1), @@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): h = self.encoder_proj(h) # get conditions - feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) conds = torch.zeros(feat.shape, device=token.device) for i, j in enumerate(feat_len): if random.random() < 0.5: From 038ff9f353b21c98c54b744eaa19ba9b3674c35a Mon Sep 17 00:00:00 2001 From: burkliu Date: Fri, 25 Apr 2025 10:31:43 +0800 Subject: [PATCH 5/6] [feature] modify pad to trim --- cosyvoice/dataset/processor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 8ac82a1..08030d6 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -177,11 +177,10 @@ def compute_fbank(data, waveform = sample['speech'] feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) - # padding with replicate mode (align to speech_token len * token_mel_ratio) - pad_len = sample["speech_token"].shape[0] * token_mel_ratio - feat.shape[0] - if pad_len > 0: - feat_to_pad = feat[-1:].repeat((pad_len, 1)) - feat = torch.cat([feat, feat_to_pad], dim=0) + # trim to align speech_token and speech_feat + token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0]) + feat = feat[:token_mel_ratio * token_len] + sample["speech_token"] = sample["speech_token"][:token_len] sample['speech_feat'] = feat yield sample From 68100c267a0a4a01e88bb52511f64d1bd97c21fd Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Fri, 23 May 2025 12:50:47 +0800 Subject: [PATCH 6/6] remove flow_cache --- README.md | 2 +- cosyvoice/bin/export_jit.py | 7 +- cosyvoice/bin/export_onnx.py | 169 +++----- cosyvoice/cli/cosyvoice.py | 6 +- cosyvoice/cli/model.py | 100 ++--- cosyvoice/dataset/processor.py | 13 +- cosyvoice/flow/decoder.py | 459 ++-------------------- cosyvoice/flow/flow.py | 16 +- cosyvoice/flow/flow_matching.py | 168 ++------ cosyvoice/hifigan/generator.py | 170 +++++++- cosyvoice/transformer/upsample_encoder.py | 132 +------ cosyvoice/utils/file_utils.py | 2 +- cosyvoice/utils/mask.py | 39 +- test1.py | 37 -- 14 files changed, 365 insertions(+), 955 deletions(-) delete mode 100644 test1.py diff --git a/README.md b/README.md index 4a1dbd3..c7a724d 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ import torchaudio **CosyVoice2 Usage** ```python -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) +cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False) # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference # zero_shot usage diff --git a/cosyvoice/bin/export_jit.py b/cosyvoice/bin/export_jit.py index 1e89005..4eedc1a 100644 --- a/cosyvoice/bin/export_jit.py +++ b/cosyvoice/bin/export_jit.py @@ -61,8 +61,7 @@ def main(): model = CosyVoice(args.model_dir) except Exception: try: - # NOTE set use_flow_cache=True when export jit for cache inference - model = CosyVoice2(args.model_dir, use_flow_cache=True) + model = CosyVoice2(args.model_dir) except Exception: raise TypeError('no valid model_type!') @@ -93,9 +92,9 @@ def main(): else: # 3. export flow encoder flow_encoder = model.model.flow.encoder - script = get_optimized_script(flow_encoder, ['forward_chunk']) + script = get_optimized_script(flow_encoder) script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) - script = get_optimized_script(flow_encoder.half(), ['forward_chunk']) + script = get_optimized_script(flow_encoder.half()) script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) logging.info('successfully export flow_encoder') diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index fcb1594..dd9f009 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -62,135 +62,58 @@ def main(): model = CosyVoice(args.model_dir) except Exception: try: - # NOTE set use_flow_cache=True when export jit for cache inference - model = CosyVoice2(args.model_dir, use_flow_cache=True) + model = CosyVoice2(args.model_dir) except Exception: raise TypeError('no valid model_type!') - if not isinstance(model, CosyVoice2): - # 1. export flow decoder estimator - estimator = model.model.flow.decoder.estimator - estimator.eval() + # 1. export flow decoder estimator + estimator = model.model.flow.decoder.estimator + estimator.eval() - device = model.model.device - batch_size, seq_len = 2, 256 - out_channels = model.model.flow.decoder.estimator.out_channels - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) - torch.onnx.export( - estimator, - (x, mask, mu, t, spks, cond), - '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], - output_names=['estimator_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'estimator_out': {2: 'seq_len'}, - } - ) + device = model.model.device + batch_size, seq_len = 2, 256 + out_channels = model.model.flow.decoder.estimator.out_channels + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) + torch.onnx.export( + estimator, + (x, mask, mu, t, spks, cond), + '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], + output_names=['estimator_out'], + dynamic_axes={ + 'x': {2: 'seq_len'}, + 'mask': {2: 'seq_len'}, + 'mu': {2: 'seq_len'}, + 'cond': {2: 'seq_len'}, + 'estimator_out': {2: 'seq_len'}, + } + ) - # 2. test computation consistency - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] - estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - sess_options=option, providers=providers) + # 2. test computation consistency + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] + estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), + sess_options=option, providers=providers) - for _ in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) - output_pytorch = estimator(x, mask, mu, t, spks, cond) - ort_inputs = { - 'x': x.cpu().numpy(), - 'mask': mask.cpu().numpy(), - 'mu': mu.cpu().numpy(), - 't': t.cpu().numpy(), - 'spks': spks.cpu().numpy(), - 'cond': cond.cpu().numpy() - } - output_onnx = estimator_onnx.run(None, ort_inputs)[0] - torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) - logging.info('successfully export estimator') - else: - # 1. export flow decoder estimator - estimator = model.model.flow.decoder.estimator - estimator.forward = estimator.forward_chunk - estimator.eval() - - device = model.model.device - batch_size, seq_len = 2, 256 - out_channels = model.model.flow.decoder.estimator.out_channels - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) - cache = model.model.init_flow_cache()['decoder_cache'] - cache.pop('offset') - cache = {k: v[0] for k, v in cache.items()} - torch.onnx.export( - estimator, - (x, mask, mu, t, spks, cond, - cache['down_blocks_conv_cache'], - cache['down_blocks_kv_cache'], - cache['mid_blocks_conv_cache'], - cache['mid_blocks_kv_cache'], - cache['up_blocks_conv_cache'], - cache['up_blocks_kv_cache'], - cache['final_blocks_conv_cache']), - '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - export_params=True, - opset_version=18, - do_constant_folding=True, - input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache', - 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'], - output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out', - 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'], - dynamic_axes={ - 'x': {2: 'seq_len'}, - 'mask': {2: 'seq_len'}, - 'mu': {2: 'seq_len'}, - 'cond': {2: 'seq_len'}, - 'down_blocks_kv_cache': {3: 'cache_in_len'}, - 'mid_blocks_kv_cache': {3: 'cache_in_len'}, - 'up_blocks_kv_cache': {3: 'cache_in_len'}, - 'estimator_out': {2: 'seq_len'}, - 'down_blocks_kv_cache_out': {3: 'cache_out_len'}, - 'mid_blocks_kv_cache_out': {3: 'cache_out_len'}, - 'up_blocks_kv_cache_out': {3: 'cache_out_len'}, - } - ) - - # 2. test computation consistency - option = onnxruntime.SessionOptions() - option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - option.intra_op_num_threads = 1 - providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] - estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), - sess_options=option, providers=providers) - - for iter in tqdm(range(10)): - x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) - cache = model.model.init_flow_cache()['decoder_cache'] - cache.pop('offset') - cache = {k: v[0] for k, v in cache.items()} - output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()}) - ort_inputs = { - 'x': x.cpu().numpy(), - 'mask': mask.cpu().numpy(), - 'mu': mu.cpu().numpy(), - 't': t.cpu().numpy(), - 'spks': spks.cpu().numpy(), - 'cond': cond.cpu().numpy(), - } - output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}}) - if iter == 0: - # NOTE why can not pass first iteration check? - continue - for i, j in zip(output_pytorch, output_onnx): - torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4) - logging.info('successfully export estimator') + for _ in tqdm(range(10)): + x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) + output_pytorch = estimator(x, mask, mu, t, spks, cond) + ort_inputs = { + 'x': x.cpu().numpy(), + 'mask': mask.cpu().numpy(), + 'mu': mu.cpu().numpy(), + 't': t.cpu().numpy(), + 'spks': spks.cpu().numpy(), + 'cond': cond.cpu().numpy() + } + output_onnx = estimator_onnx.run(None, ort_inputs)[0] + torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) + logging.info('successfully export estimator') if __name__ == "__main__": diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 3b9a7d5..b95a9e0 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -140,7 +140,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 @@ -162,9 +162,9 @@ class CosyVoice2(CosyVoice): if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') - self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent) + self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent) self.model.load('{}/llm.pt'.format(model_dir), - '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir), + '{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir)) if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 104c217..aa110b1 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -33,12 +33,14 @@ class CosyVoiceModel: llm: torch.nn.Module, flow: torch.nn.Module, hift: torch.nn.Module, - fp16: bool = False): + fp16: bool = False, + trt_concurrent: int = 1): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow self.hift = hift self.fp16 = fp16 + self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() @@ -85,23 +87,18 @@ class CosyVoiceModel: def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' - if not os.path.exists(flow_decoder_estimator_model): + if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0: convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16) - if os.path.getsize(flow_decoder_estimator_model) == 0: - raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model)) del self.flow.decoder.estimator import tensorrt as trt with open(flow_decoder_estimator_model, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - if isinstance(self, CosyVoice2Model): - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) - else: - self.flow.decoder.estimator = estimator_engine.create_execution_context() + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) def get_trt_kwargs(self): min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] - opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)] + opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] input_names = ["x", "mask", "mu", "cond"] return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} @@ -249,21 +246,21 @@ class CosyVoice2Model(CosyVoiceModel): flow: torch.nn.Module, hift: torch.nn.Module, fp16: bool = False, - use_flow_cache: bool = False, trt_concurrent: int = 1): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.llm = llm self.flow = flow + # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference + self.flow.encoder.streaming = True + self.flow.decoder.estimator.streaming = True self.hift = hift self.fp16 = fp16 - self.use_flow_cache = use_flow_cache self.trt_concurrent = trt_concurrent if self.fp16 is True: self.llm.half() self.flow.half() - # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml + # NOTE must matching training static_chunk_size self.token_hop_len = 25 - self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio # hift cache self.mel_cache_len = 8 self.source_cache_len = int(self.mel_cache_len * 480) @@ -278,56 +275,24 @@ class CosyVoice2Model(CosyVoiceModel): # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} - self.flow_cache_dict = {} self.hift_cache_dict = {} self.trt_context_dict = {} - def init_flow_cache(self): - encoder_cache = {'offset': 0, - 'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device), - 'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device), - 'upsample_offset': 0, - 'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device), - 'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)} - decoder_cache = {'offset': 0, - 'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device), - 'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device), - 'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device), - 'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device), - 'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device), - 'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device), - 'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)} - if self.fp16 is True: - for cache in [encoder_cache, decoder_cache]: - for k, v in cache.items(): - if isinstance(v, torch.Tensor): - cache[k] = v.half() - cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache} - return cache - def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def get_trt_kwargs(self): - min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)] - opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)] - max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)] - input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache'] - assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape" - return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names} - - def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0): with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]: - tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - cache=self.flow_cache_dict[uuid], - finalize=finalize) + tts_mel, _ = self.flow.inference(token=token.to(self.device), + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + finalize=finalize) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] @@ -362,7 +327,6 @@ class CosyVoice2Model(CosyVoiceModel): with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None - self.flow_cache_dict[this_uuid] = self.init_flow_cache() self.trt_context_dict[this_uuid] = self.trt_context_pool.get() if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) @@ -370,27 +334,23 @@ class CosyVoice2Model(CosyVoiceModel): p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid)) p.start() if stream is True: - assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM" - # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size - flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):] - prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:] + token_offset = 0 + prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1]) while True: time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len + if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, + token_offset=token_offset, uuid=this_uuid, finalize=False) - # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk - flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device) - prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device) + token_offset += this_token_hop_len yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:] - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len: + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len: break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None @@ -399,18 +359,19 @@ class CosyVoice2Model(CosyVoiceModel): prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, + token_offset=token_offset, uuid=this_uuid, finalize=True) yield {'tts_speech': this_tts_speech.cpu()} else: # deal with all tokens - assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference" p.join() this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, embedding=flow_embedding, + token_offset=0, uuid=this_uuid, finalize=True, speed=speed) @@ -419,7 +380,6 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - self.flow_cache_dict.pop(this_uuid) self.trt_context_pool.put(self.trt_context_dict[this_uuid]) self.trt_context_dict.pop(this_uuid) if torch.cuda.is_available(): diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 08030d6..a94eb15 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -159,7 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'): def compute_fbank(data, feat_extractor, - token_mel_ratio=2, + token_mel_ratio=0, mode='train'): """ Extract fbank @@ -176,12 +176,11 @@ def compute_fbank(data, assert 'text_token' in sample waveform = sample['speech'] feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1) - - # trim to align speech_token and speech_feat - token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0]) - feat = feat[:token_mel_ratio * token_len] - sample["speech_token"] = sample["speech_token"][:token_len] - + if token_mel_ratio != 0: + # trim to align speech_token and speech_feat + token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0])) + feat = feat[:token_mel_ratio * token_len] + sample["speech_token"] = sample["speech_token"][:token_len] sample['speech_feat'] = feat yield sample diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py index 4a89fb1..9e28c3f 100644 --- a/cosyvoice/flow/decoder.py +++ b/cosyvoice/flow/decoder.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Optional, Dict, Any +from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import pack, rearrange, repeat -from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate from cosyvoice.utils.common import mask_to_bias from cosyvoice.utils.mask import add_optional_chunk_mask from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D -from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph +from matcha.models.components.transformer import BasicTransformerBlock class Transpose(torch.nn.Module): @@ -29,7 +28,7 @@ class Transpose(torch.nn.Module): self.dim0 = dim0 self.dim1 = dim1 - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.transpose(x, self.dim0, self.dim1) return x @@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d): assert stride == 1 self.causal_padding = kernel_size - 1 - def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: - if cache.size(2) == 0: - x = F.pad(x, (self.causal_padding, 0), value=0.0) - else: - assert cache.size(2) == self.causal_padding - x = torch.concat([cache, x], dim=2) - cache = x[:, :, -self.causal_padding:] + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, (self.causal_padding, 0), value=0.0) x = super(CausalConv1d, self).forward(x) - return x, cache + return x class CausalBlock1D(Block1D): @@ -79,11 +73,9 @@ class CausalBlock1D(Block1D): nn.Mish(), ) - def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: - output, cache = self.block[0](x * mask, cache) - for i in range(1, len(self.block)): - output = self.block[i](output) - return output * mask, cache + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.block(x * mask) + return output * mask class CausalResnetBlock1D(ResnetBlock1D): @@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D): self.block1 = CausalBlock1D(dim, dim_out) self.block2 = CausalBlock1D(dim_out, dim_out) - def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor, - block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0) - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - h, block1_cache = self.block1(x, mask, block1_cache) - h += self.mlp(time_emb).unsqueeze(-1) - h, block2_cache = self.block2(h, mask, block2_cache) - output = h + self.res_conv(x * mask) - return output, block1_cache, block2_cache - - -class CausalAttnProcessor2_0(AttnProcessor2_0): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - super(CausalAttnProcessor2_0, self).__init__() - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - *args, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.Tensor]: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \ - `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key_cache = attn.to_k(encoder_hidden_states) - value_cache = attn.to_v(encoder_hidden_states) - # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2) - if cache.size(0) != 0: - key = torch.concat([cache[:, :, :, 0], key_cache], dim=1) - value = torch.concat([cache[:, :, :, 1], value_cache], dim=1) - else: - key, value = key_cache, value_cache - cache = torch.stack([key_cache, value_cache], dim=3) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states, cache - - -@maybe_allow_in_graph -class CausalAttention(Attention): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor2_0"] = None, - out_dim: int = None, - ): - super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, - cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups, - spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, - _from_deprecated_attn_block, processor, out_dim) - processor = CausalAttnProcessor2_0() - self.set_processor(processor) - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - **cross_attention_kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. - """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] - if len(unused_kwargs) > 0: - logger.warning( - f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cache=cache, - **cross_attention_kwargs, - ) - - -@maybe_allow_in_graph -class CausalBasicTransformerBlock(BasicTransformerBlock): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout, - cross_attention_dim, activation_fn, num_embeds_ada_norm, - attention_bias, only_cross_attention, double_self_attention, - upcast_attention, norm_elementwise_affine, norm_type, final_dropout) - self.attn1 = CausalAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - attn_output, cache = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, - cache=cache, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \ - {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.") - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states, cache - class ConditionalDecoder(nn.Module): def __init__( @@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder): resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ - CausalBasicTransformerBlock( + BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, @@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder): transformer_blocks = nn.ModuleList( [ - CausalBasicTransformerBlock( + BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, @@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder): ) transformer_blocks = nn.ModuleList( [ - CausalBasicTransformerBlock( + BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, @@ -724,6 +419,9 @@ class CausalConditionalDecoder(ConditionalDecoder): Returns: _type_: _description_ """ + if hasattr(self, 'streaming'): + assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode' + streaming = self.streaming t = self.time_embeddings(t).to(t.dtype) t = self.time_mlp(t) @@ -740,36 +438,36 @@ class CausalConditionalDecoder(ConditionalDecoder): masks = [mask] for resnet, transformer_blocks, downsample in self.down_blocks: mask_down = masks[-1] - x, _, _ = resnet(x, mask_down, t) + x = resnet(x, mask_down, t) x = rearrange(x, "b c t -> b t c").contiguous() if streaming is True: - attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) else: attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: - x, _ = transformer_block( + x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() hiddens.append(x) # Save hidden states for skip connections - x, _ = downsample(x * mask_down) + x = downsample(x * mask_down) masks.append(mask_down[:, :, ::2]) masks = masks[:-1] mask_mid = masks[-1] for resnet, transformer_blocks in self.mid_blocks: - x, _, _ = resnet(x, mask_mid, t) + x = resnet(x, mask_mid, t) x = rearrange(x, "b c t -> b t c").contiguous() if streaming is True: - attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) else: attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: - x, _ = transformer_block( + x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, @@ -780,124 +478,21 @@ class CausalConditionalDecoder(ConditionalDecoder): mask_up = masks.pop() skip = hiddens.pop() x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] - x, _, _ = resnet(x, mask_up, t) + x = resnet(x, mask_up, t) x = rearrange(x, "b c t -> b t c").contiguous() if streaming is True: - attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks) + attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) else: attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1) attn_mask = mask_to_bias(attn_mask, x.dtype) for transformer_block in transformer_blocks: - x, _ = transformer_block( + x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() - x, _ = upsample(x * mask_up) - x, _ = self.final_block(x, mask_up) + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) output = self.final_proj(x * mask_up) return output * mask - - @torch.inference_mode() - def forward_chunk(self, x, mask, mu, t, spks=None, cond=None, - down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), - mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), - up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0), - final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0) - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Forward pass of the UNet1DConditional model. - - Args: - x (torch.Tensor): shape (batch_size, in_channels, time) - mask (_type_): shape (batch_size, 1, time) - t (_type_): shape (batch_size) - spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. - cond (_type_, optional): placeholder for future use. Defaults to None. - - Raises: - ValueError: _description_ - ValueError: _description_ - - Returns: - _type_: _description_ - """ - - t = self.time_embeddings(t).to(t.dtype) - t = self.time_mlp(t) - - x = pack([x, mu], "b * t")[0] - - if spks is not None: - spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) - x = pack([x, spks], "b * t")[0] - if cond is not None: - x = pack([x, cond], "b * t")[0] - - hiddens = [] - masks = [mask] - - down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device) - mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device) - up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device) - for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks): - mask_down = masks[-1] - x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \ - resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576]) - x = rearrange(x, "b c t -> b t c").contiguous() - attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool() - attn_mask = mask_to_bias(attn_mask, x.dtype) - for i, transformer_block in enumerate(transformer_blocks): - x, down_blocks_kv_cache_new[index, i] = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - cache=down_blocks_kv_cache[index, i], - ) - x = rearrange(x, "b t c -> b c t").contiguous() - hiddens.append(x) # Save hidden states for skip connections - x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:]) - masks.append(mask_down[:, :, ::2]) - masks = masks[:-1] - mask_mid = masks[-1] - - for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks): - x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \ - resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:]) - x = rearrange(x, "b c t -> b t c").contiguous() - attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool() - attn_mask = mask_to_bias(attn_mask, x.dtype) - for i, transformer_block in enumerate(transformer_blocks): - x, mid_blocks_kv_cache_new[index, i] = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - cache=mid_blocks_kv_cache[index, i] - ) - x = rearrange(x, "b t c -> b c t").contiguous() - - for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks): - mask_up = masks.pop() - skip = hiddens.pop() - x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] - x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \ - resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768]) - x = rearrange(x, "b c t -> b t c").contiguous() - attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool() - attn_mask = mask_to_bias(attn_mask, x.dtype) - for i, transformer_block in enumerate(transformer_blocks): - x, up_blocks_kv_cache_new[index, i] = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - cache=up_blocks_kv_cache[index, i] - ) - x = rearrange(x, "b t c -> b c t").contiguous() - x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:]) - x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache) - output = self.final_proj(x * mask_up) - return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \ - up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index e1cf429..d9e832b 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -241,7 +241,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): prompt_feat, prompt_feat_len, embedding, - cache, finalize): assert token.shape[0] == 1 # xvec projection @@ -255,16 +254,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): # text encode if finalize is True: - h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache']) + h, h_lengths = self.encoder(token, token_len) else: token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:] - h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache']) - cache['encoder_cache']['offset'] = encoder_cache[0] - cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1] - cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2] - cache['encoder_cache']['upsample_offset'] = encoder_cache[3] - cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4] - cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5] + h, h_lengths = self.encoder(token, token_len, context=context) mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] h = self.encoder_proj(h) @@ -274,14 +267,13 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): conds = conds.transpose(1, 2) mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) - feat, cache['decoder_cache'] = self.decoder( + feat, _ = self.decoder( mu=h.transpose(1, 2).contiguous(), mask=mask.unsqueeze(1), spks=embedding, cond=conds, n_timesteps=10, - cache=cache['decoder_cache'] ) feat = feat[:, :, mel_len1:] assert feat.shape[2] == mel_len2 - return feat.float(), cache + return feat.float(), None diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 47e6961..735889f 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -126,21 +126,26 @@ class ConditionalCFM(BASECFM): if isinstance(self.estimator, torch.nn.Module): return self.estimator(x, mask, mu, t, spks, cond) else: - with self.lock: - self.estimator.set_input_shape('x', (2, 80, x.size(2))) - self.estimator.set_input_shape('mask', (2, 1, x.size(2))) - self.estimator.set_input_shape('mu', (2, 80, x.size(2))) - self.estimator.set_input_shape('t', (2,)) - self.estimator.set_input_shape('spks', (2, 80)) - self.estimator.set_input_shape('cond', (2, 80, x.size(2))) - # run trt engine - assert self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) is True + estimator, trt_engine = self.estimator.acquire_estimator() + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator) return x def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): @@ -191,7 +196,7 @@ class CausalConditionalCFM(ConditionalCFM): self.rand_noise = torch.randn([1, 80, 50 * 300]) @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}): + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """Forward diffusion Args: @@ -210,136 +215,9 @@ class CausalConditionalCFM(ConditionalCFM): shape: (batch_size, n_feats, mel_timesteps) """ - offset = cache.pop('offset') - z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature - z = z[:, :, offset:] - offset += mu.size(2) + z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature # fix prompt and overlap part mu and z t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache) - cache['offset'] = offset - return mel, cache - - def solve_euler(self, x, t_span, mu, mask, spks, cond, cache): - """ - Fixed euler solver for ODEs. - Args: - x (torch.Tensor): random noise - t_span (torch.Tensor): n_timesteps interpolated - shape: (n_timesteps + 1,) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - """ - t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) - - # I am storing this because I can later plot it by putting a debugger here and saving it to a file - # Or in future might add like a return_all_steps flag - sol = [] - - # Do not use concat, it may cause memory format changed and trt infer with wrong results! - x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) - mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - t_in = torch.zeros([2], device=x.device, dtype=x.dtype) - spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) - cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - flow_cache_size = cache['down_blocks_kv_cache'].shape[4] - for step in range(1, len(t_span)): - # Classifier-Free Guidance inference introduced in VoiceBox - x_in[:] = x - mask_in[:] = mask - mu_in[0] = mu - t_in[:] = t.unsqueeze(0) - spks_in[0] = spks - cond_in[0] = cond - cache_step = {k: v[step - 1] for k, v in cache.items()} - dphi_dt, cache_step = self.forward_estimator( - x_in, mask_in, - mu_in, t_in, - spks_in, - cond_in, - cache_step - ) - # NOTE if smaller than flow_cache_size, means last chunk, no need to cache - if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size: - cache['down_blocks_conv_cache'][step - 1] = cache_step[0] - cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:] - cache['mid_blocks_conv_cache'][step - 1] = cache_step[2] - cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:] - cache['up_blocks_conv_cache'][step - 1] = cache_step[4] - cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:] - cache['final_blocks_conv_cache'][step - 1] = cache_step[6] - dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) - dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - return sol[-1].float(), cache - - def forward_estimator(self, x, mask, mu, t, spks, cond, cache): - if isinstance(self.estimator, torch.nn.Module): - x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache) - cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7) - else: - estimator, trt_engine = self.estimator.acquire_estimator() - estimator.set_input_shape('x', (2, 80, x.size(2))) - estimator.set_input_shape('mask', (2, 1, x.size(2))) - estimator.set_input_shape('mu', (2, 80, x.size(2))) - estimator.set_input_shape('t', (2,)) - estimator.set_input_shape('spks', (2, 80)) - estimator.set_input_shape('cond', (2, 80, x.size(2))) - estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape) - estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape) - estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape) - estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape) - estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape) - estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape) - estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape) - down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) - mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x) - up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x) - data_ptrs = [x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - cache['down_blocks_conv_cache'].contiguous().data_ptr(), - cache['down_blocks_kv_cache'].contiguous().data_ptr(), - cache['mid_blocks_conv_cache'].contiguous().data_ptr(), - cache['mid_blocks_kv_cache'].contiguous().data_ptr(), - cache['up_blocks_conv_cache'].contiguous().data_ptr(), - cache['up_blocks_kv_cache'].contiguous().data_ptr(), - cache['final_blocks_conv_cache'].contiguous().data_ptr(), - x.data_ptr(), - cache['down_blocks_conv_cache'].data_ptr(), - down_blocks_kv_cache_out.data_ptr(), - cache['mid_blocks_conv_cache'].data_ptr(), - mid_blocks_kv_cache_out.data_ptr(), - cache['up_blocks_conv_cache'].data_ptr(), - up_blocks_kv_cache_out.data_ptr(), - cache['final_blocks_conv_cache'].data_ptr()] - for i, j in enumerate(data_ptrs): - estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) - # run trt engine - assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() - self.estimator.release_estimator(estimator) - cache = (cache['down_blocks_conv_cache'], - down_blocks_kv_cache_out, - cache['mid_blocks_conv_cache'], - mid_blocks_kv_cache_out, - cache['up_blocks_conv_cache'], - up_blocks_kv_cache_out, - cache['final_blocks_conv_cache']) - return x, cache + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py index 50d7f99..326a1a7 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module): return sine_merge, noise, uv +class SineGen2(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen2, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2), + scale_factor=1 / self.upsample_scale, + mode="linear").transpose(1, 2) + + phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale, + scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) + sines = torch.sin(phase) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF2(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF2, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + class HiFTGenerator(nn.Module): """ HiFTNet Generator: Neural Source Filter + ISTFTNet @@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module): self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) - self.m_source = SourceModuleHnNSF( + # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation + this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2 + self.m_source = this_SourceModuleHnNSF( sampling_rate=sampling_rate, upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], harmonic_num=nb_harmonics, diff --git a/cosyvoice/transformer/upsample_encoder.py b/cosyvoice/transformer/upsample_encoder.py index 0d98406..e17b188 100644 --- a/cosyvoice/transformer/upsample_encoder.py +++ b/cosyvoice/transformer/upsample_encoder.py @@ -56,16 +56,11 @@ class Upsample1D(nn.Module): # In this mode, first repeat interpolate, than conv with stride=1 self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) - def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") - if conv_cache.size(2) == 0: - outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) - else: - assert conv_cache.size(2) == self.stride * 2 - outputs = torch.concat([conv_cache, outputs], dim=2) - conv_cache_new = outputs[:, :, -self.stride * 2:] + outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) outputs = self.conv(outputs) - return outputs, input_lengths * self.stride, conv_cache_new + return outputs, input_lengths * self.stride class PreLookaheadLayer(nn.Module): @@ -83,7 +78,7 @@ class PreLookaheadLayer(nn.Module): kernel_size=3, stride=1, padding=0, ) - def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor: """ inputs: (batch_size, seq_len, channels) """ @@ -93,22 +88,18 @@ class PreLookaheadLayer(nn.Module): if context.size(2) == 0: outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) else: + assert self.training is False, 'you have passed context, make sure that you are running inference mode' assert context.size(2) == self.pre_lookahead_len outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0) outputs = F.leaky_relu(self.conv1(outputs)) # outputs - if conv2_cache.size(2) == 0: - outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) - else: - assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1 - outputs = torch.concat([conv2_cache, outputs], dim=2) - conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):] + outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0) outputs = self.conv2(outputs) outputs = outputs.transpose(1, 2).contiguous() # residual connection outputs = outputs + inputs - return outputs, conv2_cache_new + return outputs class UpsampleConformerEncoder(torch.nn.Module): @@ -253,6 +244,7 @@ class UpsampleConformerEncoder(torch.nn.Module): self, xs: torch.Tensor, xs_lens: torch.Tensor, + context: torch.Tensor = torch.zeros(0, 0, 0), decoding_chunk_size: int = 0, num_decoding_left_chunks: int = -1, streaming: bool = False, @@ -280,20 +272,27 @@ class UpsampleConformerEncoder(torch.nn.Module): checkpointing API because `__call__` attaches all the hooks of the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 """ + if hasattr(self, 'streaming'): + assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode' + streaming = self.streaming T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) if self.global_cmvn is not None: xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) + if context.size(1) != 0: + assert self.training is False, 'you have passed context, make sure that you are running inference mode' + context_masks = torch.ones(1, 1, context.size(1)).to(masks) + context, _, _ = self.embed(context, context_masks, offset=xs.size(1)) mask_pad = masks # (B, 1, T/subsample_rate) chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1) # lookahead + conformer encoder - xs, _ = self.pre_lookahead_layer(xs) + xs = self.pre_lookahead_layer(xs, context=context) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) # upsample + conformer encoder xs = xs.transpose(1, 2).contiguous() - xs, xs_lens, _ = self.up_layer(xs, xs_lens) + xs, xs_lens = self.up_layer(xs, xs_lens) xs = xs.transpose(1, 2).contiguous() T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) @@ -322,100 +321,3 @@ class UpsampleConformerEncoder(torch.nn.Module): for layer in self.up_encoders: xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) return xs - - @torch.jit.export - def forward_chunk( - self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - offset: int = 0, - context: torch.Tensor = torch.zeros(0, 0, 0), - pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0), - encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0), - upsample_offset: int = 0, - upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0), - upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0) - ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]: - """Embed positions in tensor. - - Args: - xs: padded input tensor (B, T, D) - xs_lens: input length (B) - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. - >=0: use num_decoding_left_chunks - <0: use all left chunks - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - NOTE(xcsong): - We pass the `__call__` method of the modules instead of `forward` to the - checkpointing API because `__call__` attaches all the hooks of the module. - https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 - """ - assert xs.size(0) == 1 - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) - offset += xs.size(1) - tmp_masks = torch.ones(1, - context.size(1), - device=context.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if context.size(1) != 0: - context, _, _ = self.embed(context, tmp_masks, offset) - - # lookahead + conformer encoder - xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache) - # NOTE in cache mode we do not need to call add_optional_chunk_mask - chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device) - mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device) - encoders_kv_cache_list = [] - for index, layer in enumerate(self.encoders): - xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index]) - encoders_kv_cache_list.append(encoders_kv_cache_new) - encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0) - - # upsample - xs = xs.transpose(1, 2).contiguous() - xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache) - xs = xs.transpose(1, 2).contiguous() - - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset) - upsample_offset += xs.size(1) - - # conformer encoder - chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device) - mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device) - upsample_kv_cache_list = [] - for index, layer in enumerate(self.up_encoders): - xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index]) - upsample_kv_cache_list.append(upsample_kv_cache_new) - upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0) - - if self.normalize_before: - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index 80eafaf..ae860c9 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py index 35dcd69..c966cc9 100644 --- a/cosyvoice/utils/mask.py +++ b/cosyvoice/utils/mask.py @@ -86,7 +86,7 @@ def subsequent_mask( return mask -def subsequent_chunk_mask( +def subsequent_chunk_mask_deprecated( size: int, chunk_size: int, num_left_chunks: int = -1, @@ -124,6 +124,40 @@ def subsequent_chunk_mask( return ret +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks + pos_idx = torch.arange(size, device=device) + block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size + ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) + return ret + + def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, use_dynamic_chunk: bool, @@ -196,9 +230,6 @@ def add_optional_chunk_mask(xs: torch.Tensor, else: chunk_masks = masks assert chunk_masks.dtype == torch.bool - if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: - print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!') - chunk_masks[chunk_masks.sum(dim=-1) == 0] = True return chunk_masks diff --git a/test1.py b/test1.py deleted file mode 100644 index a1243e4..0000000 --- a/test1.py +++ /dev/null @@ -1,37 +0,0 @@ -import sys -sys.path.append('third_party/Matcha-TTS') -from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 -from cosyvoice.utils.file_utils import load_wav -import torchaudio # type: ignore - -cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False) - -# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference -# zero_shot usage -prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# save zero_shot spk for future usage -assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True -for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) -cosyvoice.save_spkinfo() - -# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248 -for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)): - torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# instruct usage -for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)): - torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - -# bistream usage, you can use generator as input, this is useful when using text llm model as input -# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length -def text_generator(): - yield '收到好友从远方寄来的生日礼物,' - yield '那份意外的惊喜与深深的祝福' - yield '让我心中充满了甜蜜的快乐,' - yield '笑容如花儿般绽放。' -for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): - torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) \ No newline at end of file