remove unnecessary code

This commit is contained in:
lyuxiang.lx
2024-09-05 14:26:12 +08:00
parent 11eacb810e
commit e141634da1
2 changed files with 1 additions and 3 deletions

View File

@@ -113,7 +113,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
# concat text and prompt_text
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding)
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode