From 98f0e6f2b1edadd81de6fcccb56f6fd0c554c9fa Mon Sep 17 00:00:00 2001 From: czk32611 Date: Thu, 8 Aug 2024 14:56:25 +0800 Subject: [PATCH] Fixed bug in train.py where pe was missing --- train_codes/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_codes/train.py b/train_codes/train.py index fae0cb2..a7a96b3 100755 --- a/train_codes/train.py +++ b/train_codes/train.py @@ -546,8 +546,12 @@ def main(): latent_model_input = torch.cat([masked_latents, ref_latents], dim=1) audio_feature = audio_feature.to(dtype=weight_dtype) - # Predict the noise residual - image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample + audio_feature = pe(audio_feature) + + # Predict the noise residual + image_pred = unet(latent_model_input, + timesteps, + encoder_hidden_states=audio_feature).sample if args.reconstruction: # decode the image from the predicted latents image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred