mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 18:09:19 +08:00
Fixed bug in train.py where pe was missing
This commit is contained in:
@@ -546,8 +546,12 @@ def main():
|
|||||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||||
|
|
||||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||||
# Predict the noise residual
|
audio_feature = pe(audio_feature)
|
||||||
image_pred = unet(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
|
|
||||||
|
# Predict the noise residual
|
||||||
|
image_pred = unet(latent_model_input,
|
||||||
|
timesteps,
|
||||||
|
encoder_hidden_states=audio_feature).sample
|
||||||
|
|
||||||
if args.reconstruction: # decode the image from the predicted latents
|
if args.reconstruction: # decode the image from the predicted latents
|
||||||
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
|
image_pred_img = (1 / vae_fp32.config.scaling_factor) * image_pred
|
||||||
|
|||||||
Reference in New Issue
Block a user