diff --git a/train_codes/train.py b/train_codes/train.py index fae0cb2..a7a96b3 100755 --- a/train_codes/train.py +++ b/train_codes/train.py @@ -546,8 +546,12 @@ def main(): latent_model_input = torch.cat([masked_latents, ref_latents], dim=1) audio_feature = audio_feature.to(dtype=weight_dtype) - # Predict the noise residual - image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample + audio_feature = pe(audio_feature) + + # Predict the noise residual + image_pred = unet(latent_model_input, + timesteps, + encoder_hidden_states=audio_feature).sample if args.reconstruction: # decode the image from the predicted latents image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred