Fixed bug in train.py where the pe(audio_feature) call was missing before the UNet forward pass

This commit is contained in:
czk32611
2024-08-08 14:56:25 +08:00
committed by GitHub
parent 1de8261491
commit 98f0e6f2b1

View File

@@ -546,8 +546,12 @@ def main():
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
audio_feature = audio_feature.to(dtype=weight_dtype)
# Predict the noise residual
image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
audio_feature = pe(audio_feature)
# Predict the noise residual
image_pred = unet(latent_model_input,
timesteps,
encoder_hidden_states=audio_feature).sample
if args.reconstruction: # decode the image from the predicted latents
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred