Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)

fix cosyvoice3 training

@@ -26,7 +26,7 @@ import pyworld as pw
 
 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
-def parquet_opener(data, mode='train', tts_data={}):
+def parquet_opener(data, mode='train'):
     """ Give url or local file, return file descriptor
         Inplace operation.
@@ -44,12 +44,8 @@ def parquet_opener(data, mode='train', tts_data={}):
                 df = df.to_pandas()
                 for i in range(len(df)):
                     sample.update(dict(df.loc[i]))
-                    if mode == 'train':
-                        # NOTE do not return sample directly, must initialize a new dict
-                        yield {**sample}
-                    else:
-                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
-                            yield {**sample, 'tts_index': index, 'tts_text': text}
+                    # NOTE do not return sample directly, must initialize a new dict
+                    yield {**sample}
         except Exception as ex:
             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
 
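
The NOTE in this hunk matters because `sample` is one dict that keeps being mutated row after row inside the generator; yielding it directly would hand every downstream pipeline stage the same object. A minimal standalone sketch of the difference (not CosyVoice code, just the Python behaviour):

# Minimal sketch: why a generator must yield a copy of a dict it keeps mutating.
def bad_opener(rows):
    sample = {}
    for r in rows:
        sample.update(r)
        yield sample          # every item is the SAME dict object

def good_opener(rows):
    sample = {}
    for r in rows:
        sample.update(r)
        yield {**sample}      # fresh dict per row, as in parquet_opener

rows = [{'utt': 'a'}, {'utt': 'b'}]
print([s['utt'] for s in list(bad_opener(rows))])   # ['b', 'b'] - later rows overwrite earlier ones
print([s['utt'] for s in list(good_opener(rows))])  # ['a', 'b']
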
@@ -332,8 +332,9 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
-        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
-        h = self.encoder_proj(h)
+        h = self.pre_lookahead_layer(token)
+        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
+        mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
@@ -344,7 +345,6 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
             conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
-        mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),
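
The new code drops the text encoder call in favour of a pre-lookahead layer plus token-to-mel upsampling: every token embedding is repeated token_mel_ratio times along time, and the token mask is stretched the same way, which is why the h_lengths-based mask in the second hunk disappears and mask.unsqueeze(1) now receives the stretched token mask. A rough shape sketch, with made-up sizes and token_mel_ratio assumed to be 2:

import torch

# Sketch of the repeat_interleave upsampling, with made-up sizes.
batch, token_len, dim, token_mel_ratio = 2, 5, 8, 2
h = torch.randn(batch, token_len, dim)      # per-token hidden states
mask = torch.ones(batch, token_len, 1)      # (B, T, 1) token mask

h = h.repeat_interleave(token_mel_ratio, dim=1)                        # (B, T * ratio, dim)
mask = mask.repeat_interleave(token_mel_ratio, dim=1).squeeze(dim=-1)  # (B, T * ratio)

print(h.shape, mask.shape)        # torch.Size([2, 10, 8]) torch.Size([2, 10])
print(mask.unsqueeze(1).shape)    # (B, 1, T * ratio), matching the mel frame count
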
@@ -301,18 +301,23 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
 
-    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
         speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        # NOTE add instruct_token in CosyVoice3
+        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
+            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
-                this_lm_target, this_lm_input = [], []
-                this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(sos_emb.squeeze(dim=0))
+                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
+                if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+                    this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                    this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -333,8 +338,8 @@ class Qwen2LM(TransformerLM):
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
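
With the extra instruct arguments, a unistream training sequence is laid out as [sos, instruct, text, task_id, speech] on the input side, while the target ignores every position up to the task_id slot and then supervises the speech tokens plus eos; the bistream branch likewise prepends the instruct embeddings with IGNORE_ID targets. A toy length check (plain Python; the IGNORE_ID value, eos id, and token counts are made up):

# Toy length check for the unistream layout: [sos, instruct, text, task_id, speech].
IGNORE_ID = -1          # assumed sentinel for ignored target positions
instruct_len, text_len = 4, 7
speech = [101, 102, 103]
eos_token = 6562        # made-up id

# target: IGNORE for the sos/instruct/text positions, then the speech tokens
# (the first one is predicted at the task_id position), then eos
target = [IGNORE_ID] * (1 + instruct_len + text_len) + speech + [eos_token]
# input:  sos(1) + instruct + text + task_id(1) + speech
input_len = 1 + instruct_len + text_len + 1 + len(speech)

assert len(target) == input_len   # LM input and target stay the same length
print(len(target), input_len)     # 16 16
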
@@ -681,6 +686,7 @@ class CosyVoice3LM(Qwen2LM):
 
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
 
         # 3. sos and task_id
         sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
@@ -691,7 +697,7 @@ class CosyVoice3LM(Qwen2LM):
 
         # 3. prepare llm_input/target
         lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
+                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
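
Instruct tokens are embedded with the same Qwen2 embed_tokens table as the text tokens, so the forward pass only needs padded instruct ids and their lengths alongside the existing text fields. A small sketch of the assumed shapes (the batch construction itself is not part of this diff, and the vocabulary and embedding sizes here are illustrative):

import torch

# Assumed shapes for the new instruct fields, mirroring text_token / text_token_len.
batch_size, max_instruct_len = 2, 6
instruct_token = torch.randint(0, 1000, (batch_size, max_instruct_len))   # (B, T_instruct), padded
instruct_token_len = torch.tensor([6, 4], dtype=torch.int32)              # (B,)

embed = torch.nn.Embedding(1000, 16)   # stand-in for self.llm.model.model.embed_tokens
instruct_token_emb = embed(instruct_token)
print(instruct_token_emb.shape)        # torch.Size([2, 6, 16])
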
@@ -53,7 +53,7 @@ def init_distributed(args):
 def init_dataset_and_dataloader(args, configs, gan, dpo):
     data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
     train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
 
     # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
     train_data_loader = DataLoader(train_dataset,
@@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
             raise ValueError("unknown scheduler: " + configs['train_conf'])
 
         if configs['train_conf']['optim_d'] == 'adam':
-            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
         elif configs['train_conf']['optim_d'] == 'adamw':
-            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
         else:
             raise ValueError("unknown optimizer: " + configs['train_conf'])
 
         if configs['train_conf']['scheduler_d'] == 'warmuplr':
             scheduler_type = WarmupLR
-            scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
+            scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
         elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
             scheduler_type = NoamHoldAnnealing
-            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
+            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
         elif configs['train_conf']['scheduler'] == 'constantlr':
             scheduler_type = ConstantLR
             scheduler_d = ConstantLR(optimizer_d)
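
The optimizer change means the discriminator now reads its own optim_conf_d block instead of reusing the generator's optim_conf. A sketch of the train_conf keys the GAN branch would then expect; the key names follow the diff, the values are illustrative only:

# Sketch of the train_conf keys the GAN branch now reads (values illustrative, not from the repo).
train_conf = {
    'optim': 'adam',
    'optim_conf': {'lr': 2e-4},     # generator optimizer kwargs
    'optim_d': 'adam',
    'optim_conf_d': {'lr': 2e-4},   # discriminator optimizer kwargs, now a separate block
}
# mirrors optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
print(train_conf['optim_conf_d'])
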
@@ -136,7 +136,7 @@ filter: !name:cosyvoice.dataset.processor.filter
 resample: !name:cosyvoice.dataset.processor.resample
     resample_rate: !ref <sample_rate>
 truncate: !name:cosyvoice.dataset.processor.truncate
-    truncate_length: 24480 # must be a multiplier of hop_size
+    truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
 feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     n_fft: 1920
     num_mels: 80
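
The larger truncate_length keeps a truncated clip aligned with whole speech tokens rather than just whole mel frames. A quick divisibility check, assuming hop_size is 480 samples and token_mel_ratio is 2 (neither value appears in this hunk, so treat both as assumptions):

# Divisibility check for truncate_length (assumed hop_size=480, token_mel_ratio=2).
hop_size, token_mel_ratio = 480, 2

for n in (24480, 24960):
    mel_frames = n / hop_size
    tokens = mel_frames / token_mel_ratio
    print(n, mel_frames, tokens)
# 24480 -> 51.0 mel frames -> 25.5 tokens (frame count not divisible by token_mel_ratio)
# 24960 -> 52.0 mel frames -> 26.0 tokens (whole number of tokens)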