Mirror of https://github.com/TMElyralab/MuseTalk.git (synced 2026-02-05 09:59:18 +08:00)
Update draft training codes
train_codes/utils/model_utils.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import json
import logging
import math
import os
import time

import torch
import torch.nn as nn
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union

from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
)

from utils import decode_latents, preprocess_img_tensor


RESIZED_IMG = 256

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding, as used in the Transformer.
    """
    def __init__(self, d_model=384, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        b, seq_len, d_model = x.size()
        pe = self.pe[:, :seq_len, :]
        x = x + pe.to(x.device)
        return x

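# --- Illustrative usage (annotation only, not part of this commit) ----------
# The encoding adds sinusoidal position information along the sequence axis;
# the shapes below are made-up assumptions for a "tiny" whisper feature:
#
#   pe = PositionalEncoding(d_model=384)
#   audio_feature = torch.randn(2, 50, 384)   # (batch, seq_len, d_model)
#   audio_feature = pe(audio_feature)         # same shape, positions added
# -----------------------------------------------------------------------------
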
def validation(vae: torch.nn.Module,
               vae_fp32: torch.nn.Module,
               unet: torch.nn.Module,
               unet_config,
               weight_dtype: torch.dtype,
               epoch: int,
               global_step: int,
               val_data_loader,
               output_dir,
               whisper_model_type,
               UNet2DConditionModel=UNet2DConditionModel
               ):

    # Build a frozen copy of the UNet for the validation pipeline
    unet_copy = UNet2DConditionModel(**unet_config)
    unet_copy.load_state_dict(unet.state_dict())
    unet_copy.to(vae.device).to(dtype=weight_dtype)
    unet_copy.eval()

    # Pick the positional encoding matching the whisper feature dimension
    if whisper_model_type == "tiny":
        pe = PositionalEncoding(d_model=384)
    elif whisper_model_type == "largeV2":
        pe = PositionalEncoding(d_model=1280)
    elif whisper_model_type == "tiny-conv":
        pe = PositionalEncoding(d_model=384)
        print(f"whisper_model_type: {whisper_model_type}, validation does not need PE")
    else:
        raise ValueError(f"unsupported whisper_model_type: {whisper_model_type}")
    pe.to(vae.device, dtype=weight_dtype)

    start = time.time()
    with torch.no_grad():
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):

            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            # Convert ground-truth images to latent space
            latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            # Convert masked images to latent space
            masked_latents = vae.encode(
                masked_image.reshape(image.shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            masked_latents = masked_latents * vae.config.scaling_factor
            # Convert reference images to latent space
            ref_latents = vae.encode(
                ref_image.reshape(image.shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            ref_latents = ref_latents * vae.config.scaling_factor

            # Downsample the masks to the latent resolution (assumes square masks)
            mask = torch.stack(
                [
                    torch.nn.functional.interpolate(m, size=(m.shape[-1] // 8, m.shape[-1] // 8))
                    for m in masks
                ]
            )
            mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
            bsz = latents.shape[0]
            timesteps = torch.tensor([0], device=latents.device)

            # 9 input channels: mask + masked latents + reference latents;
            # otherwise: masked latents + reference latents only
            if unet_config['in_channels'] == 9:
                latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
            else:
                latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)

            image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states=audio_feature).sample

            # Stitch masked / reference / ground-truth / predicted frames side by side
            image = Image.new('RGB', (RESIZED_IMG * 4, RESIZED_IMG))
            image.paste(decode_latents(vae_fp32, masked_latents), (0, 0))
            image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
            image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG * 2, 0))
            image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG * 3, 0))

            val_img_dir = f"images/{output_dir}/{global_step}"
            if not os.path.exists(val_img_dir):
                os.makedirs(val_img_dir)
            image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step, step))

            print("validation in step:{0}, time:{1}".format(step, time.time() - start))

    print("validation done in epoch:{0}, time:{1}".format(epoch, time.time() - start))

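For orientation, here is a minimal, self-contained sketch (all shapes are made-up assumptions, not taken from the training configuration) of how the UNet input above is assembled in the 9-channel case; the 8-channel branch simply omits the mask channel.

import torch

B, H8, W8 = 2, 32, 32                        # hypothetical batch size and latent resolution (256 / 8)
mask = torch.rand(B, 1, H8, W8)              # 1 channel: downsampled mouth-region mask
masked_latents = torch.randn(B, 4, H8, W8)   # 4 channels: VAE latents of the masked frame
ref_latents = torch.randn(B, 4, H8, W8)      # 4 channels: VAE latents of the reference frame

latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
print(latent_model_input.shape)              # torch.Size([2, 9, 32, 32])
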
train_codes/utils/utils.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from einops import rearrange

from diffusers import AutoencoderKL

def preprocess_img_tensor(image_tensor):
    # Input is assumed to be a PyTorch tensor of shape (N, C, H, W)
    N, C, H, W = image_tensor.shape
    # Round the width and height down to the nearest multiple of 32
    new_w = W - W % 32
    new_h = H - H % 32
    # Normalize to [-1, 1] after resizing
    transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # Apply the transform to each image and collect the results in a new tensor
    preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
    for i in range(N):
        # Use F.interpolate instead of transforms.Resize
        resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
        preprocessed_images[i] = transform(resized_image.squeeze(0))

    return preprocessed_images

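# --- Illustrative usage (annotation only, not part of this commit) ----------
# Assuming frames already scaled to [0, 1]; the sizes are made up:
#
#   frames = torch.rand(4, 3, 250, 250)
#   frames = preprocess_img_tensor(frames)   # -> (4, 3, 224, 224), values in [-1, 1]
# -----------------------------------------------------------------------------
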
def prepare_mask_and_masked_image(image_tensor, mask_tensor):
    # image_tensor is expected as [C, H, W], mask_tensor as [H, W]
    # Normalize the image tensor to [-1, 1]
    image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
    # Binarize the mask tensor
    # mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    # Build the masked image
    masked_image_tensor = image_tensor * (mask_tensor > 0.5)

    return mask_tensor, masked_image_tensor

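# --- Illustrative usage (annotation only, not part of this commit) ----------
# Assuming a mask already scaled to [0, 1]; the sizes are made up:
#
#   img = torch.randint(0, 256, (3, 256, 256)).float()
#   msk = torch.rand(256, 256)
#   msk, masked = prepare_mask_and_masked_image(img, msk)
#   # msk is binarized in place; masked keeps img only where msk >= 0.5
# -----------------------------------------------------------------------------
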
def encode_latents(vae, image):
    # Encode an image tensor into VAE latents, scaled by the SD factor 0.18215
    init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
    init_latents = 0.18215 * init_latent_dist.sample()
    return init_latents

def decode_latents(vae, latents, ref_images=None):
    # Decode latents back to a PIL image; optionally keep the top half of a reference image
    latents = (1 / 0.18215) * latents
    image = vae.decode(latents.to(vae.dtype)).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    image = (image * 255).round().astype("uint8")
    if ref_images is not None:
        ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        ref_images = (ref_images * 255).round().astype("uint8")
        h = image.shape[1]
        image[:, :h // 2] = ref_images[:, :h // 2]
    image = [Image.fromarray(im) for im in image]

    return image[0]
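As a rough round-trip check of the two helpers above, here is a sketch with assumed inputs; the VAE checkpoint name is an assumption (any Stable-Diffusion VAE using the 0.18215 scaling factor would do), and none of this is part of the commit.

import torch
from diffusers import AutoencoderKL

from utils import encode_latents, decode_latents    # assumes train_codes/utils is on PYTHONPATH

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")   # hypothetical choice of VAE
image = torch.rand(1, 3, 256, 256) * 2 - 1           # fake frame, already normalized to [-1, 1]

latents = encode_latents(vae, image)                 # -> (1, 4, 32, 32)
recon = decode_latents(vae, latents)                 # -> PIL.Image, 256 x 256
recon.save("roundtrip_check.png")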