Merge pull request #983 from Shengqiang-Li/main

feat: Support DPO
This commit is contained in:
Xiang Lyu
2025-04-08 12:14:51 +08:00
committed by GitHub
12 changed files with 2156 additions and 3 deletions

187
cosyvoice/bin/train_dpo.py Normal file
View File

@@ -0,0 +1,187 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed
from hyperpyyaml import load_hyperpyyaml
from torch.distributed.elastic.multiprocessing.errors import record
from cosyvoice.utils.executor_dpo import Executor
from cosyvoice.utils.train_utils_dpo import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--use_amp',
action='store_true',
default=False,
help='Use automatic mixed precision training')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=60,
type=int,
help='timeout (in seconds) of cosyvoice_join.')
parser.add_argument('--dpo',
action='store_true',
default=False,
help='Use Direct Preference Optimization')
parser.add_argument('--beta',
default=0.01,
type=float,
help='beta of dpo training')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args
@record
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
# gan train has some special initialization logic
gan = True if args.model == 'hifigan' else False
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
if gan is True:
override_dict.pop('hift')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if gan is True:
configs['train_conf'] = configs['train_conf_gan']
configs['train_conf'].update(vars(args))
# Init env for ddp
init_distributed(args)
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs, gan)
# Do some sanity checks and save config to arsg.model_dir
configs = check_modify_and_save_config(args, configs)
# Tensorboard summary
writer = init_summarywriter(args)
# load checkpoint
model = configs[args.model]
ref_model = None
if args.dpo:
ref_model = deepcopy(model)
start_step, start_epoch = 0, -1
if args.checkpoint is not None:
if os.path.exists(args.checkpoint):
state_dict = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(state_dict, strict=False)
if args.dpo:
ref_model.load_state_dict(state_dict, strict=False)
if 'step' in state_dict:
start_step = state_dict['step']
if 'epoch' in state_dict:
start_epoch = state_dict['epoch']
else:
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
if args.dpo:
ref_model = wrap_cuda_model(args, ref_model)
# Get optimizer & scheduler
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
if args.dpo:
ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan)
scheduler.set_step(start_step)
if scheduler_d is not None:
scheduler_d.set_step(start_step)
# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
info_dict['step'] = start_step
info_dict['epoch'] = start_epoch
save_model(model, 'init', info_dict)
# Get executor
executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta)
executor.step = start_step
# Init scaler, used for pytorch amp mixed precision training
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
print('start step {} start epoch {}'.format(start_step, start_epoch))
# Start training loop
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
if gan is True:
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
writer, info_dict, scaler, group_join)
else:
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model)
dist.destroy_process_group(group_join)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,443 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def parquet_opener(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
df = df.to_pandas()
for i in range(len(df)):
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
continue
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
del sample['audio_data']
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['text_token']) < token_min_length:
continue
if len(sample['text_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['text_token']) / num_frames > max_output_input_ratio:
continue
yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
waveform = sample['speech']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
sample['speech'] = waveform
yield sample
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
yield sample
def compute_f0(data, sample_rate, hop_size, mode='train'):
""" Extract f0
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
frame_period = hop_size * 1000 / sample_rate
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
sample['pitch_feat'] = f0
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if mode == 'inference':
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'speech_feat' in sample
assert isinstance(sample['speech_feat'], torch.Tensor)
new_sample_frames = sample['speech_feat'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
speech = pad_sequence(speech, batch_first=True, padding_value=0)
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"utts": utts,
"speech": speech,
"speech_len": speech_len,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
if dpo:
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
reject_speech_token = pad_sequence(reject_speech_token,
batch_first=True,
padding_value=0)
batch['reject_speech_token'] = reject_speech_token
batch['reject_speech_token_len'] = reject_speech_token_len
if gan is True:
# in gan train, we need pitch_feat
pitch_feat = [sample[i]['pitch_feat'] for i in order]
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
pitch_feat = pad_sequence(pitch_feat,
batch_first=True,
padding_value=0)
batch["pitch_feat"] = pitch_feat
batch["pitch_feat_len"] = pitch_feat_len
else:
# only gan train needs speech, delete it to save memory
del batch["speech"]
del batch["speech_len"]
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"]
else:
batch["embedding"] = batch["utt_embedding"]
yield batch

556
cosyvoice/llm/llm_dpo.py Normal file
View File

@@ -0,0 +1,556 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
class TransformerLM(torch.nn.Module):
def __init__(
self,
text_encoder_input_size: int,
llm_input_size: int,
llm_output_size: int,
text_token_size: int,
speech_token_size: int,
text_encoder: torch.nn.Module,
llm: torch.nn.Module,
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
spk_embed_dim: int = 192,
):
super().__init__()
self.llm_input_size = llm_input_size
self.speech_token_size = speech_token_size
# 1. build text token inputs related modules
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
self.text_encoder = text_encoder
self.text_encoder_affine_layer = nn.Linear(
self.text_encoder.output_size(),
llm_input_size
)
# 2. build speech token language model related modules
self.sos_eos = 0
self.task_id = 1
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 1,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
# 4. sampling method
self.sampling = sampling
def encode(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
):
encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out = self.text_encoder_affine_layer(encoder_out)
return encoder_out, encoder_out_lens
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
return lm_input, lm_input_len
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
embedding = batch['embedding'].to(device)
# 1. prepare llm_target
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
[self.speech_token_size]) for i in range(text_token.size(0))]
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
# 1. encode text_token
text_token = self.text_embedding(text_token)
text_token, text_token_len = self.encode(text_token, text_token_len)
# 2. embedding projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(1)
# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 4. encode speech_token
speech_token = self.speech_embedding(speech_token)
# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
task_id_emb, speech_token, speech_token_len)
# 6. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target)
acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
def sampling_ids(
self,
weighted_scores: torch.Tensor,
decoded_tokens: List,
sampling: int,
ignore_eos: bool = True,
):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
break
num_trials += 1
if num_trials > max_trials:
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
return top_ids
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
if self.fp16 is True:
embedding = embedding.half()
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.text_embedding(text)
# 1. encode text
text, text_len = self.encode(text, text_len)
# 2. encode embedding
if embedding.shape[0] != 0:
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(dim=1)
else:
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
out_tokens = []
offset = 0
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
for i in range(max_len):
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
att_cache=att_cache, cnn_cache=cnn_cache,
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
# force continue decode first token
if i == 0:
logp[:, self.speech_token_size] = -float('inf')
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
offset += lm_input.size(1)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
class Qwen2Encoder(torch.nn.Module):
def __init__(self, pretrain_path):
super().__init__()
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
def forward_one_step(self, xs, masks, cache=None):
input_masks = masks[:, -1, :]
outs = self.model(
inputs_embeds=xs,
attention_mask=input_masks,
output_hidden_states=True,
return_dict=True,
use_cache=True,
past_key_values=cache,
)
xs = outs.hidden_states[-1]
new_cache = outs.past_key_values
return xs, new_cache
class Qwen2LM(TransformerLM):
def __init__(
self,
llm_input_size: int,
llm_output_size: int,
speech_token_size: int,
llm: torch.nn.Module,
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
mix_ratio: List[int] = [5, 15],
dpo: bool = False,
):
torch.nn.Module.__init__(self)
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.task_id = 1
self.fill_token = 2
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 3,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
# 5. [Optional] set dpo
self.dpo = dpo
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
if self.dpo:
reject_speech_token = batch['reject_speech_token'].to(device)
reject_speech_token_len = batch['reject_speech_token_len'].to(device)
# 1. prepare llm_target
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
[self.speech_token_size]) for i in range(text_token.size(0))]
if self.dpo:
reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() +
[self.speech_token_size]) for i in range(text_token.size(0))]
target_ids.extend(reject_target_ids)
target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device)
# 2. speech token projection
speech_emb = self.speech_embedding(speech_token)
if self.dpo:
reject_speech_emb = self.speech_embedding(reject_speech_token)
# 3. text token projection
text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True)
text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst]
# 4. prepare llm_input
speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True)
input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0)
for i in range(len(text_emb))]
if self.dpo:
reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True)
reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0)
for i in range(len(text_emb))]
input_emb.extend(reject_input_emb)
input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device)
input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device)
attention_mask = ~make_pad_mask(input_emb_lengths)
result = self.llm.model(
inputs_embeds=input_emb,
attention_mask=attention_mask,
return_dict=True
)
hidden_states = result.hidden_states
logits = self.llm_decoder(hidden_states[-1])
loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]])
acc = th_accuracy(
logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3),
target_ids[: speech_token.shape[0]],
ignore_label=IGNORE_ID,
)
if not self.dpo:
return {
"loss": loss,
"acc": acc,
}
else:
all_logps_sum, all_logps_mean = self.get_batch_logps(
logits, target_ids, attention_mask, text_token_len, average_log_prob=False, ignore_id=IGNORE_ID
)
chosen_logps = all_logps_sum[: speech_token.shape[0]]
rejected_logps = all_logps_sum[speech_token.shape[0]:]
return {
"loss": loss,
"acc": acc,
"chosen_logps": chosen_logps,
"rejected_logps": rejected_logps
}
def get_batch_logps(
self,
logits: torch.FloatTensor,
labels: torch.LongTensor,
attention_mask,
prompt_token_lens,
average_log_prob: bool = False,
ignore_id: int = -1,
) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
assert average_log_prob == False
assert logits.shape[:-1] == labels.shape
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_masks = attention_mask.clone().bool()
# mask prompts
for mask, text_token_len in zip(loss_masks, prompt_token_lens):
mask[:text_token_len + 1] = False
loss_masks = loss_masks[:, 1:]
labels[loss_masks == False] = 0
# dummy token; we'll ignore the losses on these tokens later
ignore = labels == ignore_id
labels = labels.masked_fill(ignore, 0) # avoid -1 index
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) # (bs, time,)
logprobs_sums = (per_token_logps * loss_masks).sum(-1)
logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
return logprobs_sums, logprobs_means
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
out_tokens = []
cache = None
for i in range(max_len):
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
if top_ids > self.speech_token_size:
continue
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
@torch.inference_mode()
def inference_bistream(
self,
text: Generator,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> Generator[torch.Tensor, None, None]:
device = prompt_text.device
# 1. prepare input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb], dim=1)
# 2. iterate text
out_tokens = []
cache = None
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
text_cache = self.llm.model.model.embed_tokens(prompt_text)
next_fill_index = -1
for this_text in text:
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
# prompt_speech_token_emb not empty, try append to lm_input
while prompt_speech_token_emb.size(1) != 0:
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
else:
logging.info('not enough text token to decode, wait for more')
break
# no prompt_speech_token_emb remain, can decode some speech token
if prompt_speech_token_emb.size(1) == 0:
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
logging.info('get fill token, need to append more text token')
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text = text_cache[:, :self.mix_ratio[0]]
logging.info('append {} text token'.format(lm_input_text.size(1)))
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
lm_input = lm_input_text
else:
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
text_cache = text_cache[:, self.mix_ratio[0]:]
else:
logging.info('not enough text token to decode, wait for more')
continue
while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2
next_fill_index += (self.mix_ratio[1] + 1)
else:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
if top_ids == self.speech_token_size + 2:
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size + 2:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
# 3. final decode
lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
logging.info('no more text token, decode until met eos')
while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
# in stream mode, yield token one by one
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

View File

@@ -0,0 +1,184 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
import os
import torch
import torch.distributed as dist
from cosyvoice.utils.train_utils_dpo import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
from cosyvoice.utils.losses_dpo import DPOLoss
class Executor:
def __init__(self, gan: bool = False, dpo: bool = False, beta: float = 0.01, label_smoothing: float = 0.0, ipo: bool = False):
self.gan = gan
self.step = 0
self.epoch = 0
self.rank = int(os.environ.get('RANK', 0))
self.device = torch.device('cuda:{}'.format(self.rank))
self.dpo = dpo
if self.dpo:
self.dpo_loss = DPOLoss(beta, label_smoothing, ipo)
else:
self.dpo_loss = None
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
''' Train one epoch
'''
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(info_dict['accum_grad']))
# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
model.train()
if self.dpo:
assert ref_model is not None
ref_model.eval()
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
with model_context():
for batch_idx, batch_dict in enumerate(train_data_loader):
info_dict["tag"] = "TRAIN"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
if cosyvoice_join(group_join, info_dict):
break
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
else:
context = nullcontext
with context():
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model, self.dpo_loss)
info_dict = batch_backward(model, scaler, info_dict)
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
log_per_step(writer, info_dict)
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
(batch_idx + 1) % info_dict["accum_grad"] == 0:
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, ref_model=ref_model, dpo_loss=self.dpo_loss)
model.train()
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
self.step += 1
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=ref_model, dpo_loss=self.dpo_loss)
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
writer, info_dict, scaler, group_join):
''' Train one epoch
'''
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(info_dict['accum_grad']))
# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
model.train()
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
with model_context():
for batch_idx, batch_dict in enumerate(train_data_loader):
info_dict["tag"] = "TRAIN"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
if cosyvoice_join(group_join, info_dict):
break
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
else:
context = nullcontext
with context():
batch_dict['turn'] = 'discriminator'
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
info_dict = batch_backward(model, scaler, info_dict)
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
optimizer.zero_grad()
log_per_step(writer, info_dict)
with context():
batch_dict['turn'] = 'generator'
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
info_dict = batch_backward(model, scaler, info_dict)
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
optimizer_d.zero_grad()
log_per_step(writer, info_dict)
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
(batch_idx + 1) % info_dict["accum_grad"] == 0:
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
model.train()
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
self.step += 1
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
@torch.inference_mode()
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=None, dpo_loss=None):
''' Cross validation on
'''
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
model.eval()
if self.dpo:
assert ref_model is not None
ref_model.eval()
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
for batch_idx, batch_dict in enumerate(cv_data_loader):
info_dict["tag"] = "CV"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
num_utts = len(batch_dict["utts"])
total_num_utts += num_utts
if self.gan is True:
batch_dict['turn'] = 'generator'
info_dict = batch_forward(model, batch_dict, None, info_dict, ref_model, dpo_loss)
for k, v in info_dict['loss_dict'].items():
if k not in total_loss_dict:
total_loss_dict[k] = []
total_loss_dict[k].append(v.item() * num_utts)
log_per_step(None, info_dict)
for k, v in total_loss_dict.items():
total_loss_dict[k] = sum(v) / total_num_utts
info_dict['loss_dict'] = total_loss_dict
log_per_save(writer, info_dict)
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
save_model(model, model_name, info_dict)

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F
from typing import Tuple
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
loss = 0
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
m_DG = torch.median((dr - dg))
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
loss += tau - F.relu(tau - L_rel)
return loss
def mel_loss(real_speech, generated_speech, mel_transforms):
loss = 0
for transform in mel_transforms:
mel_r = transform(real_speech)
mel_g = transform(generated_speech)
loss += F.l1_loss(mel_g, mel_r)
return loss
class DPOLoss(torch.nn.Module):
"""
DPO Loss
"""
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
super().__init__()
self.beta = beta
self.label_smoothing = label_smoothing
self.ipo = ipo
def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios
if self.ipo:
losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
else:
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
loss = losses.mean()
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
return loss, chosen_rewards, rejected_rewards

View File

@@ -0,0 +1,364 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Horizon Inc. (authors: Xingchen Song)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import torch
import json
import re
import datetime
import yaml
import deepspeed
import torch.optim as optim
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
from cosyvoice.dataset.dataset import Dataset
from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
def init_distributed(args):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
', rank {}, world_size {}'.format(rank, world_size))
if args.train_engine == 'torch_ddp':
torch.cuda.set_device(local_rank)
dist.init_process_group(args.dist_backend)
else:
deepspeed.init_distributed(dist_backend=args.dist_backend)
return world_size, local_rank, rank
def init_dataset_and_dataloader(args, configs, gan):
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
train_data_loader = DataLoader(train_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
cv_data_loader = DataLoader(cv_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
def check_modify_and_save_config(args, configs):
if args.train_engine == "torch_ddp":
configs['train_conf']["dtype"] = 'fp32'
else:
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
configs['train_conf']["dtype"] = "fp16"
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
configs['train_conf']["dtype"] = "bf16"
else:
configs['train_conf']["dtype"] = "fp32"
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
# if use deepspeed, override ddp config
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
return configs
def wrap_cuda_model(args, model):
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
world_size = int(os.environ.get('WORLD_SIZE', 1))
if args.train_engine == "torch_ddp": # native pytorch ddp
assert (torch.cuda.is_available())
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
else:
if int(os.environ.get('RANK', 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size)
return model
def init_optimizer_and_scheduler(args, configs, model, gan):
if gan is False:
if configs['train_conf']['optim'] == 'adam':
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
elif configs['train_conf']['optim'] == 'adamw':
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'constantlr':
scheduler_type = ConstantLR
scheduler = ConstantLR(optimizer)
else:
raise ValueError("unknown scheduler: " + configs['train_conf'])
# use deepspeed optimizer for speedup
if args.train_engine == "deepspeed":
def scheduler(opt):
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
optimizer=None,
lr_scheduler=scheduler,
model_parameters=model.parameters())
optimizer_d, scheduler_d = None, None
else:
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
if configs['train_conf']['optim'] == 'adam':
optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
elif configs['train_conf']['optim'] == 'adamw':
optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'constantlr':
scheduler_type = ConstantLR
scheduler = ConstantLR(optimizer)
else:
raise ValueError("unknown scheduler: " + configs['train_conf'])
if configs['train_conf']['optim_d'] == 'adam':
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
elif configs['train_conf']['optim_d'] == 'adamw':
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler_d'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'constantlr':
scheduler_type = ConstantLR
scheduler_d = ConstantLR(optimizer_d)
else:
raise ValueError("unknown scheduler: " + configs['train_conf'])
return model, optimizer, scheduler, optimizer_d, scheduler_d
def init_summarywriter(args):
writer = None
if int(os.environ.get('RANK', 0)) == 0:
os.makedirs(args.model_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
return writer
def save_model(model, model_name, info_dict):
rank = int(os.environ.get('RANK', 0))
model_dir = info_dict["model_dir"]
save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
if info_dict["train_engine"] == "torch_ddp":
if rank == 0:
torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
else:
with torch.no_grad():
model.save_checkpoint(save_dir=model_dir,
tag=model_name,
client_state=info_dict)
if rank == 0:
info_path = re.sub('.pt$', '.yaml', save_model_path)
info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)
logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
def cosyvoice_join(group_join, info_dict):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
if info_dict["batch_idx"] != 0:
# we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
try:
dist.monitored_barrier(group=group_join,
timeout=group_join.options._timeout)
return False
except RuntimeError as e:
logging.info("Detected uneven workload distribution: {}\n".format(e) +
"Break current worker to manually join all workers, " +
"world_size {}, current rank {}, current local_rank {}\n".
format(world_size, rank, local_rank))
return True
else:
return False
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
device = int(os.environ.get('LOCAL_RANK', 0))
dtype = info_dict["dtype"]
if dtype == "fp16":
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
else: # fp32
dtype = torch.float32
if info_dict['train_engine'] == 'torch_ddp':
autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
else:
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
with autocast:
info_dict['loss_dict'] = model(batch, device)
if ref_model and dpo_loss:
chosen_logps = info_dict['loss_dict']["chosen_logps"]
rejected_logps = info_dict['loss_dict']["rejected_logps"]
sft_loss = info_dict['loss_dict']['loss']
with torch.no_grad():
ref_model = ref_model.to(device)
ref_loss_dict = ref_model(batch, device)
reference_chosen_logps = ref_loss_dict["chosen_logps"]
reference_rejected_logps = ref_loss_dict["rejected_logps"]
preference_loss, chosen_reward, reject_reward = dpo_loss(
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
)
dpo_acc = (chosen_reward > reject_reward).float().mean()
info_dict['loss_dict']["loss"] = preference_loss + sft_loss
info_dict['loss_dict']["sft_loss"] = sft_loss
info_dict['loss_dict']["dpo_loss"] = preference_loss
info_dict['loss_dict']["dpo_acc"] = dpo_acc
info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
return info_dict
def batch_backward(model, scaler, info_dict):
if info_dict["train_engine"] == "deepspeed":
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
else:
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
if scaler is not None:
scaler.scale(scaled_loss).backward()
else:
scaled_loss.backward()
info_dict['loss_dict']['loss'] = scaled_loss
return info_dict
def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
grad_norm = 0.0
if info_dict['train_engine'] == "deepspeed":
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
model.step()
grad_norm = model.get_global_grad_norm()
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
# Use mixed precision training
if scaler is not None:
scaler.unscale_(optimizer)
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
# We don't check grad here since that if the gradient
# has inf/nan values, scaler.step will skip
# optimizer.step().
if torch.isfinite(grad_norm):
scaler.step(optimizer)
scaler.update()
else:
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
if torch.isfinite(grad_norm):
optimizer.step()
optimizer.zero_grad()
scheduler.step()
info_dict["lr"] = optimizer.param_groups[0]['lr']
info_dict["grad_norm"] = grad_norm
return info_dict
def log_per_step(writer, info_dict):
tag = info_dict["tag"]
epoch = info_dict.get('epoch', 0)
step = info_dict["step"]
batch_idx = info_dict["batch_idx"]
loss_dict = info_dict['loss_dict']
rank = int(os.environ.get('RANK', 0))
# only rank 0 write to tensorboard to avoid multi-process write
if writer is not None:
if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
(info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
for k in ['epoch', 'lr', 'grad_norm']:
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
for k, v in loss_dict.items():
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
# TRAIN & CV, Shell log (stdout)
if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
for name, value in loss_dict.items():
log_str += '{} {:.6f} '.format(name, value)
if tag == "TRAIN":
log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
info_dict["lr"], info_dict['grad_norm'])
log_str += ' rank {}'.format(rank)
logging.debug(log_str)
def log_per_save(writer, info_dict):
tag = info_dict["tag"]
epoch = info_dict["epoch"]
step = info_dict["step"]
loss_dict = info_dict["loss_dict"]
lr = info_dict['lr']
rank = int(os.environ.get('RANK', 0))
logging.info(
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
if writer is not None:
for k in ['epoch', 'lr']:
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
for k, v in loss_dict.items():
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)

View File

@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
ARG VENV_NAME="cosyvoice"
ENV VENV=$VENV_NAME

View File

@@ -0,0 +1,226 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 24000 # 16000 for llm, 24000 for cfm
llm_input_size: 896
llm_output_size: 896
spk_embed_dim: 192
qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN'
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
speech_token_size: 6561
length_normalized_loss: True
lsm_weight: 0
dpo: True
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
pretrain_path: !ref <qwen_pretrain_path>
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 6561
input_frame_rate: 25
only_mask_loss: True
token_mel_ratio: 2
pre_lookahead_len: 3
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
out_channels: 80
causal: True
channels: [256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 5, 3]
upsample_kernel_sizes: [16, 11, 7]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: null
center: False
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
truncate: !name:cosyvoice.dataset.processor.truncate
truncate_length: 24576 # must be a multiplier of hop_size
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 256
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: True # change to True during sft
dpo: True
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
data_pipeline_gan: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <truncate>,
!ref <compute_fbank>,
!ref <compute_f0>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# llm flow train conf
train_conf:
optim: adam
optim_conf:
lr: 0.00001 # change to 1e-5 during sft
scheduler: warmuplr # change to constantlr during sft
scheduler_conf:
warmup_steps: 25000
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1
# gan train conf
train_conf_gan:
optim: adam
optim_conf:
lr: 0.0002 # use small lr for gan training
scheduler: constantlr
optim_d: adam
optim_conf_d:
lr: 0.0002 # use small lr for gan training
scheduler_d: constantlr
max_epoch: 200
grad_clip: 5
accum_grad: 1 # in gan training, accum_grad must be 1
log_interval: 100
save_per_step: -1

View File

@@ -18,7 +18,7 @@ networkx==3.1
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'windows'
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
openai-whisper==20231117
protobuf==4.25
pyarrow==18.1.0

View File

@@ -72,6 +72,14 @@ async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instr
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
return StreamingResponse(generate_data(model_output))
@app.get("/inference_instruct2")
@app.post("/inference_instruct2")
async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()):
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k)
return StreamingResponse(generate_data(model_output))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -90,4 +98,4 @@ if __name__ == '__main__':
cosyvoice = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
uvicorn.run(app, host="0.0.0.0", port=args.port)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@@ -27,6 +27,9 @@ def single_job(utt):
audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
# Convert audio to mono
if audio.shape[0] > 1:
audio = audio.mean(dim=0, keepdim=True)
if audio.shape[1] / 16000 > 30:
logging.warning('do not support extract speech token for audio longer than 30s')
speech_token = []

125
tools/make_parquet_list_dpo.py Executable file
View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import json
from tqdm import tqdm
import pandas as pd
import multiprocessing
import time
import torch
def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
start_time = time.time()
data_list = []
for utt in tqdm(utt_list):
data = open(utt2wav[utt], 'rb').read()
data_list.append(data)
wav_list = [utt2wav[utt] for utt in utt_list]
text_list = [utt2text[utt] for utt in utt_list]
spk_list = [utt2spk[utt] for utt in utt_list]
uttembedding_list = [utt2embedding[utt] for utt in utt_list]
spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
speech_token_list = [utt2speech_token[utt] for utt in utt_list]
if utt2reject_speech_token:
reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
# 保存到parquet,utt2parquet_file,spk2parquet_file
df = pd.DataFrame()
df['utt'] = utt_list
df['wav'] = wav_list
df['audio_data'] = data_list
df['text'] = text_list
df['spk'] = spk_list
df['utt_embedding'] = uttembedding_list
df['spk_embedding'] = spkembedding_list
df['speech_token'] = speech_token_list
if utt2reject_speech_token:
df['reject_speech_token'] = reject_speech_token_list
df.to_parquet(parquet_file)
with open(utt2parquet_file, 'w') as f:
json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
with open(spk2parquet_file, 'w') as f:
json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
logging.info('spend time {}'.format(time.time() - start_time))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_utts_per_parquet',
type=int,
default=1000,
help='num utts per parquet')
parser.add_argument('--num_processes',
type=int,
default=1,
help='num processes for make parquets')
parser.add_argument('--src_dir',
type=str)
parser.add_argument('--des_dir',
type=str)
parser.add_argument('--dpo',
action='store_true',
default=False,
help='Use Direct Preference Optimization')
args = parser.parse_args()
utt2wav, utt2text, utt2spk = {}, {}, {}
with open('{}/wav.scp'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
with open('{}/text'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2text[l[0]] = ' '.join(l[1:])
with open('{}/utt2spk'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
if args.dpo:
utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir))
else:
utt2reject_speech_token = None
utts = list(utt2wav.keys())
# Using process pool to speedup
pool = multiprocessing.Pool(processes=args.num_processes)
parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
parquet_list.append(parquet_file)
utt2parquet_list.append(utt2parquet_file)
spk2parquet_list.append(spk2parquet_file)
pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
pool.close()
pool.join()
with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
for name in parquet_list:
f1.write(name + '\n')
for name in utt2parquet_list:
f2.write(name + '\n')
for name in spk2parquet_list:
f3.write(name + '\n')