mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
Fixed bug in train.py where pe was missing
This commit is contained in:
@@ -546,8 +546,12 @@ def main():
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
# Predict the noise residual
|
||||
image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
|
||||
audio_feature = pe(audio_feature)
|
||||
|
||||
# Predict the noise residual
|
||||
image_pred = unet(latent_model_input,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature).sample
|
||||
|
||||
if args.reconstruction: # decode the image from the predicted latents
|
||||
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
|
||||
|
||||
Reference in New Issue
Block a user