add llm train

Author: lyuxiang.lx
Date:   2025-02-07 17:17:12 +08:00
parent 2a3e033ee1
commit 79b7dff8d2
3 changed files with 68 additions and 3 deletions

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
@@ -21,6 +22,7 @@ from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
@@ -226,6 +228,17 @@ class Qwen2Encoder(torch.nn.Module):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
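
The new Qwen2Encoder.forward above turns per-utterance lengths into an attention mask before running the whole padded batch through Qwen2. Not part of the commit: a minimal sketch of the shapes involved, where make_pad_mask_sketch is a stand-in for cosyvoice.utils.mask.make_pad_mask (which marks padded positions with True, hence the negation in the encoder).

import torch

def make_pad_mask_sketch(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True where the position index is beyond the utterance length, i.e. padding
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

xs_lens = torch.tensor([3, 1])                 # two utterances, 3 and 1 embedded tokens
masks = ~make_pad_mask_sketch(xs_lens, 3)      # True at valid positions
print(masks.int())                             # tensor([[1, 1, 1], [1, 0, 0]]) -> attention_mask
print(masks.unsqueeze(1).shape)                # torch.Size([2, 1, 3]), returned alongside the hidden states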
@@ -280,6 +293,58 @@ class Qwen2LM(TransformerLM):
        self.sampling = sampling
        self.mix_ratio = mix_ratio

    def pad_unpad_sequence(self, sos_eos_emb, text_token, text_token_len, task_id_emb, speech_token, speech_token_len, bistream):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len
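
As an aside (not part of the diff), here is a small self-contained illustration of the unpad/concat/pad pattern used by pad_unpad_sequence: padding is stripped from the text and speech embeddings, each utterance is rebuilt as [sos_eos, text, task_id, speech], and the variable-length results are re-padded into one batch tensor. All shapes below are made up for the example.

import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

IGNORE_ID = -1
D = 8
text = torch.randn(2, 4, D)                    # (B, L_max, D), padded text embeddings
text_len = torch.tensor([4, 2])
speech = torch.randn(2, 5, D)                  # (B, T_max, D), padded speech-token embeddings
speech_len = torch.tensor([5, 3])
sos_eos = torch.randn(1, D)                    # already squeezed, cf. squeeze(dim=0) above
task_id = torch.randn(1, D)

text_u = unpad_sequence(text, text_len, batch_first=True)
speech_u = unpad_sequence(speech, speech_len, batch_first=True)
lm_input = [torch.concat([sos_eos, text_u[i], task_id, speech_u[i]], dim=0)
            for i in range(len(text_u))]
lm_input_len = torch.tensor([x.size(0) for x in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
print(lm_input.shape, lm_input_len)            # torch.Size([2, 11, 8]) tensor([11, 7], dtype=torch.int32)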

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)

        # 1. prepare llm_target
        bistream = True if random.random() < 0.5 else False
        lm_target = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 2. encode text_token
        text_token = self.llm.model.model.embed_tokens(text_token)

        # 3. eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 4. encode speech_token
        speech_token = self.speech_embedding(speech_token)

        # 5. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, text_token, text_token_len, task_id_emb, speech_token, speech_token_len, bistream)

        # 6. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
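
Again as an illustration rather than part of the commit: the target built in step 1 lines up one label per input position. The first 1 + text_token_len labels are IGNORE_ID (covering sos_eos and the text tokens), and the remaining positions carry the speech tokens plus an end-of-speech label (self.speech_token_size), so the CE loss only scores speech-token prediction. A toy example, with speech_token_size assumed to be 6561 as in CosyVoice2-0.5B:

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1
speech_token_size = 6561                       # assumed CosyVoice2 codebook size, for illustration only
text_token_len = [3, 1]
speech_tokens = [[17, 42], [5, 9, 11]]

lm_target = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_tokens[i] + [speech_token_size])
             for i in range(2)]
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
print(lm_target)
# tensor([[  -1,   -1,   -1,   -1,   17,   42, 6561],
#         [  -1,   -1,    5,    9,   11, 6561,   -1]])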

    @torch.inference_mode()
    def inference(
            self,

View File

@@ -169,7 +169,7 @@ sort: !name:cosyvoice.dataset.processor.sort
    sort_size: 500  # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
    batch_type: 'dynamic'
-   max_frames_in_batch: 2500
+   max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
    use_spk_embedding: False # change to True during sft
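
The config change above only lowers the dynamic-batch frame budget from 2500 to 2000. As a rough illustration of what batch_type: 'dynamic' with max_frames_in_batch means (a simplified stand-in, not the actual cosyvoice.dataset.processor.batch logic): utterances are grouped until the frame budget would be exceeded, so a smaller budget gives smaller batches and a lower peak memory footprint.

def dynamic_batch_sketch(samples, max_frames_in_batch=2000):
    # samples: iterable of (utt_id, num_frames); groups utterances under a frame budget
    buf, frames = [], 0
    for utt, n in samples:
        if buf and frames + n > max_frames_in_batch:
            yield buf
            buf, frames = [], 0
        buf.append(utt)
        frames += n
    if buf:
        yield buf

print(list(dynamic_batch_sketch([("a", 900), ("b", 800), ("c", 700)])))
# [['a', 'b'], ['c']]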

View File

@@ -7,7 +7,7 @@ stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
-pretrained_model_dir=/mnt/lyuxiang.lx/data/tts/models/IIC/CosyVoice2-0.5B/
+pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "Data Download"
@@ -86,7 +86,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
  cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
  # NOTE will update llm/hift training later
-  for model in flow; do
+  for model in llm flow; do
    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
      --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
      cosyvoice/bin/train.py \