Update draft training codes

This commit is contained in:
czk32611
2024-04-28 11:34:49 +08:00
parent 6e32247cb1
commit d73daf1808
9 changed files with 1547 additions and 0 deletions

View File

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
)
from torch import Tensor
import logging
import json
RESIZED_IMG = 256
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding as used in the Transformer.
    """
    def __init__(self, d_model=384, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the precomputed encoding for the first seq_len positions
        b, seq_len, d_model = x.size()
        pe = self.pe[:, :seq_len, :]
        x = x + pe.to(x.device)
        return x
def validation(vae: torch.nn.Module,
               vae_fp32: torch.nn.Module,
               unet: torch.nn.Module,
               unet_config,
               weight_dtype: torch.dtype,
               epoch: int,
               global_step: int,
               val_data_loader,
               output_dir,
               whisper_model_type,
               UNet2DConditionModel=UNet2DConditionModel
               ):
    # Build a frozen copy of the UNet for validation
    unet_copy = UNet2DConditionModel(**unet_config)
    unet_copy.load_state_dict(unet.state_dict())
    unet_copy.to(vae.device).to(dtype=weight_dtype)
    unet_copy.eval()

    if whisper_model_type == "tiny":
        pe = PositionalEncoding(d_model=384)
    elif whisper_model_type == "largeV2":
        pe = PositionalEncoding(d_model=1280)
    elif whisper_model_type == "tiny-conv":
        pe = PositionalEncoding(d_model=384)
        print(f"whisper_model_type: {whisper_model_type}, validation does not need positional encoding")
    else:
        raise ValueError(f"Unsupported whisper_model_type: {whisper_model_type}")
    pe.to(vae.device, dtype=weight_dtype)
    start = time.time()
    with torch.no_grad():
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            # Convert the target images to latent space
            latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            # Convert the masked images to latent space
            masked_latents = vae.encode(
                masked_image.reshape(image.shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            masked_latents = masked_latents * vae.config.scaling_factor
            # Convert the reference images to latent space
            ref_latents = vae.encode(
                ref_image.reshape(image.shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            ref_latents = ref_latents * vae.config.scaling_factor

            # Downsample the masks to the latent resolution (1/8 of the image size)
            mask = torch.stack(
                [
                    torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
                    for mask in masks
                ]
            )
            mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])

            bsz = latents.shape[0]
            timesteps = torch.tensor([0], device=latents.device)

            if unet_config['in_channels'] == 9:
                latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
            else:
                latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
            image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states=audio_feature).sample

            # Stitch masked / reference / ground-truth / predicted frames side by side
            image = Image.new('RGB', (RESIZED_IMG * 4, RESIZED_IMG))
            image.paste(decode_latents(vae_fp32, masked_latents), (0, 0))
            image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
            image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG * 2, 0))
            image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG * 3, 0))

            val_img_dir = f"images/{output_dir}/{global_step}"
            os.makedirs(val_img_dir, exist_ok=True)
            image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step, step))
            print("validation step: {0}, time: {1}".format(step, time.time() - start))
    print("validation done in epoch: {0}, time: {1}".format(epoch, time.time() - start))

View File

@@ -0,0 +1,74 @@
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
import torchvision.transforms as transforms
from diffusers import AutoencoderKL
import os
import cv2
from glob import glob
def preprocess_img_tensor(image_tensor):
    # Expect a PyTorch tensor of shape (N, C, H, W) with values in [0, 1]
    N, C, H, W = image_tensor.shape
    # Round the width and height down to the nearest multiple of 32
    new_w = W - W % 32
    new_h = H - H % 32
    # Normalize to the [-1, 1] range expected by the VAE
    transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # Resize each image and apply the normalization
    preprocessed_images = torch.empty((N, C, new_h, new_w), dtype=torch.float32)
    for i in range(N):
        # Use F.interpolate instead of transforms.Resize
        resized_image = F.interpolate(image_tensor[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear', align_corners=False)
        preprocessed_images[i] = transform(resized_image.squeeze(0))
    return preprocessed_images


def prepare_mask_and_masked_image(image_tensor, mask_tensor):
    # Expect image_tensor of shape [C, H, W] and mask_tensor of shape [H, W]
    # Normalize the image tensor to [-1, 1]
    image_tensor_ori = (image_tensor.to(dtype=torch.float32) / 127.5) - 1.0
    # Binarize the mask (normalization to [0, 1] is assumed to be done by the caller)
    # mask_tensor = (mask_tensor.to(dtype=torch.float32) / 255.0).unsqueeze(0)
    mask_tensor[mask_tensor < 0.5] = 0
    mask_tensor[mask_tensor >= 0.5] = 1
    # Build the masked image
    masked_image_tensor = image_tensor * (mask_tensor > 0.5)
    return mask_tensor, masked_image_tensor
def encode_latents(vae, image):
    # Encode an image batch into VAE latents, applying the SD scaling factor
    init_latent_dist = vae.encode(image.to(vae.dtype)).latent_dist
    init_latents = 0.18215 * init_latent_dist.sample()
    return init_latents


def decode_latents(vae, latents, ref_images=None):
    # Decode latents back to a PIL image, undoing the scaling factor
    latents = (1 / 0.18215) * latents
    image = vae.decode(latents.to(vae.dtype)).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    image = (image * 255).round().astype("uint8")
    if ref_images is not None:
        # Optionally paste the top half back from the reference frames
        ref_images = ref_images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        ref_images = (ref_images * 255).round().astype("uint8")
        h = image.shape[1]
        image[:, :h // 2] = ref_images[:, :h // 2]
    image = [Image.fromarray(im) for im in image]
    return image[0]
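For reference, a minimal round-trip sketch of these helpers on a dummy batch. The VAE checkpoint name is an assumption; any AutoencoderKL trained with the 0.18215 scaling factor would behave the same way.

# Hedged sketch: preprocess a batch, encode it to latents, and decode a preview image.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint
batch = torch.rand(1, 3, 256, 256)        # dummy images in [0, 1]
pixels = preprocess_img_tensor(batch)     # resize to multiples of 32, scale to [-1, 1]
latents = encode_latents(vae, pixels)     # shape (1, 4, 32, 32), scaled by 0.18215
preview = decode_latents(vae, latents)    # PIL.Image of the first sample
preview.save("roundtrip_preview.png")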