This commit is contained in:
lyuxiang.lx
2026-01-29 10:29:22 +00:00
parent f26cde56df
commit 84e41729ea
4 changed files with 20 additions and 13 deletions

View File

@@ -189,7 +189,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
if 'speech_token' not in batch:
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
@@ -322,8 +322,11 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
if 'speech_token' not in batch:
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)