# Mirror of https://github.com/TMElyralab/MuseTalk.git
# (synced 2026-02-05 09:59:18 +08:00; 130 lines, 4.9 KiB, Python)
import torch
|
||
import torch.nn as nn
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import time
|
||
import math
|
||
from utils import decode_latents, preprocess_img_tensor
|
||
import os
|
||
from PIL import Image
|
||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||
from diffusers import (
|
||
AutoencoderKL,
|
||
UNet2DConditionModel,
|
||
)
|
||
from torch import Tensor, nn
|
||
import logging
|
||
import json
|
||
|
||
# Side length, in pixels, of one tile in the validation preview image:
# `validation` creates a canvas of (RESIZED_IMG * 4, RESIZED_IMG) and pastes
# four decoded images side by side at multiples of this offset.
RESIZED_IMG = 256
|
||
|
||
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as used in the Transformer.

    A table of shape ``(1, max_len, d_model)`` is computed once at
    construction time and stored in the non-trainable buffer ``pe``.
    Even feature columns hold sines and odd feature columns hold cosines.
    ``forward`` adds the first ``seq_len`` rows of that table to its input.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both
            even and odd values are supported.
        max_len: Number of positions that are precomputed.
    """

    def __init__(self, d_model=384, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # Shape: (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the cosine term; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional table to ``x``.

        Args:
            x: Tensor with exactly three dimensions
                ``(batch, seq_len, d_model)``.

        Returns:
            ``x`` plus the first ``seq_len`` rows of the positional table,
            with the table moved to ``x``'s device before the addition.
        """
        # Unpacking three values also enforces that x is rank 3.
        _, seq_len, _ = x.size()
        return x + self.pe[:, :seq_len, :].to(x.device)
|
||
|
||
def validation(vae: torch.nn.Module,
               vae_fp32: torch.nn.Module,
               unet: torch.nn.Module,
               unet_config,
               weight_dtype: torch.dtype,
               epoch: int,
               global_step: int,
               val_data_loader,
               output_dir,
               whisper_model_type,
               UNet2DConditionModel=UNet2DConditionModel
               ):
    """Run the UNet over the validation set and save preview images.

    For every batch, a four-tile image is written to
    ``images/{output_dir}/{global_step}/val_epoch_{global_step}_{step}_image.png``.
    The tiles, left to right, are the decoded masked image, reference image,
    ground-truth image and UNet prediction. Progress is reported via ``print``.

    Args:
        vae: Autoencoder used to encode the images into latents; its
            ``device`` and ``config.scaling_factor`` are also read.
        vae_fp32: Autoencoder handed to ``decode_latents`` for decoding.
        unet: Model whose ``state_dict`` is copied into a fresh instance, so
            validation never touches the training module itself.
        unet_config: Mapping of keyword arguments for
            ``UNet2DConditionModel``; its ``'in_channels'`` entry decides
            whether the mask is concatenated to the UNet input.
        weight_dtype: dtype for the UNet copy and for the VAE encoder inputs.
        epoch: Only used in the final log line.
        global_step: Used in the output directory and the file names.
        val_data_loader: Iterable yielding
            ``(ref_image, image, masked_image, masks, audio_feature)``.
        output_dir: Path component placed under ``images/``.
        whisper_model_type: One of ``"tiny"``, ``"largeV2"``, ``"tiny-conv"``.
        UNet2DConditionModel: Class used to build the UNet copy; injectable.

    Raises:
        ValueError: If ``whisper_model_type`` is not a supported value.
    """
    # Fail fast with a clear error. Previously an unsupported type only
    # printed a message and then crashed on an unbound local variable.
    supported_types = ("tiny", "largeV2", "tiny-conv")
    if whisper_model_type not in supported_types:
        print(f"not support whisper_model_type {whisper_model_type}")
        raise ValueError(
            f"Unsupported whisper_model_type {whisper_model_type!r}; "
            f"expected one of {supported_types}"
        )
    # NOTE(review): a PositionalEncoding used to be constructed here but was
    # never applied to audio_feature, so it has been dropped. Only the type
    # check and this log line were observable.
    if whisper_model_type == "tiny-conv":
        print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")

    # Build an independent copy of the UNet for evaluation.
    unet_copy = UNet2DConditionModel(**unet_config)
    unet_copy.load_state_dict(unet.state_dict())
    unet_copy.to(vae.device).to(dtype=weight_dtype)
    unet_copy.eval()

    device = vae.device
    scaling_factor = vae.config.scaling_factor
    val_img_dir = f"images/{output_dir}/{global_step}"

    start = time.time()
    with torch.no_grad():
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
            masks = masks.unsqueeze(1).unsqueeze(1).to(device)
            ref_image = preprocess_img_tensor(ref_image).to(device)
            image = preprocess_img_tensor(image).to(device)
            masked_image = preprocess_img_tensor(masked_image).to(device)

            # Convert images to latent space
            latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * scaling_factor
            # Convert masked images to latent space
            masked_latents = vae.encode(
                masked_image.reshape(image.shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            masked_latents = masked_latents * scaling_factor
            # Convert ref images to latent space
            ref_latents = vae.encode(
                ref_image.reshape(image.shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            ref_latents = ref_latents * scaling_factor

            # Downsample each mask by a factor of 8.
            # NOTE(review): only the last dimension is used for both height
            # and width, so masks are assumed to be square — confirm.
            mask = torch.stack(
                [
                    torch.nn.functional.interpolate(m, size=(m.shape[-1] // 8, m.shape[-1] // 8))
                    for m in masks
                ]
            )
            mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
            timesteps = torch.tensor([0], device=latents.device)

            # A 9-channel UNet additionally receives the mask as input.
            if unet_config['in_channels'] == 9:
                latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
            else:
                latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)

            image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states=audio_feature).sample

            # Tiles left to right: masked input, reference, ground truth,
            # prediction. A separate name avoids shadowing the `image` tensor.
            canvas = Image.new('RGB', (RESIZED_IMG * 4, RESIZED_IMG))
            tiles = (masked_latents, ref_latents, latents, image_pred)
            for tile_idx, tile_latents in enumerate(tiles):
                canvas.paste(decode_latents(vae_fp32, tile_latents), (RESIZED_IMG * tile_idx, 0))

            # exist_ok avoids the check-then-create race between processes.
            os.makedirs(val_img_dir, exist_ok=True)
            canvas.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step, step))

            print("valtion in step:{0}, time:{1}".format(step, time.time() - start))

    print("valtion_done in epoch:{0}, time:{1}".format(epoch, time.time() - start))
|
||
|