commit c93d3dda01
parent 84e41729ea
Author: lyuxiang.lx
Date:   2026-01-29 17:31:04 +00:00


@@ -405,7 +405,7 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
             batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
         if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
-            whisper_feat = [torch.tensor(sample[i]['whisper_feat']) for i in order]
+            whisper_feat = [sample[i]['whisper_feat'] for i in order]
             batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
             batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
         if torch.tensor(['speech_token' in sample[i] for i in order]).all():
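
Below is a minimal, standalone sketch of how pad_sequence(batch_first=True, padding_value=0) collates variable-length feature tensors like whisper_feat. The shapes and the 128-dim feature size are illustrative only, and the list entries are assumed to already be torch.Tensor objects, which is the assumption behind dropping the torch.tensor(...) wrapper in the diff above.

import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative whisper_feat entries: two tensors of shape (T, 128) with different T.
whisper_feat = [torch.randn(50, 128), torch.randn(80, 128)]

# Per-item lengths before padding, mirroring the whisper_feat_len line in the diff.
whisper_feat_len = torch.tensor([f.size(0) for f in whisper_feat], dtype=torch.int32)

# pad_sequence zero-pads along the time axis and stacks to (batch, max_T, 128).
padded = pad_sequence(whisper_feat, batch_first=True, padding_value=0)

print(padded.shape)        # torch.Size([2, 80, 128])
print(whisper_feat_len)    # tensor([50, 80], dtype=torch.int32)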