Update draft training codes

This commit is contained in:
czk32611
2024-04-28 11:34:49 +08:00
parent 6e32247cb1
commit d73daf1808
9 changed files with 1547 additions and 0 deletions

View File

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from torch import Tensor, nn
import logging
import json
RESIZED_IMG = 256
class PositionalEncoding(nn.Module):
"""
Transformer 中的位置编码positional encoding
"""
def __init__(self, d_model=384, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
b, seq_len, d_model = x.size()
pe = self.pe[:, :seq_len, :]
#print(b, seq_len, d_model)
x = x + pe.to(x.device)
return x
def validation(vae: torch.nn.Module,
vae_fp32: torch.nn.Module,
unet:torch.nn.Module,
unet_config,
weight_dtype: torch.dtype,
epoch: int,
global_step: int,
val_data_loader,
output_dir,
whisper_model_type,
UNet2DConditionModel=UNet2DConditionModel
):
# Get the validation pipeline
unet_copy = UNet2DConditionModel(**unet_config)
unet_copy.load_state_dict(unet.state_dict())
unet_copy.to(vae.device).to(dtype=weight_dtype)
unet_copy.eval()
if whisper_model_type == "tiny":
pe = PositionalEncoding(d_model=384)
elif whisper_model_type == "largeV2":
pe = PositionalEncoding(d_model=1280)
elif whisper_model_type == "tiny-conv":
pe = PositionalEncoding(d_model=384)
print(f" whisper_model_type: {whisper_model_type} Validation does not need PE")
else:
print(f"not support whisper_model_type {whisper_model_type}")
pe.to(vae.device, dtype=weight_dtype)
start = time.time()
with torch.no_grad():
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(val_data_loader):
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
# Convert masked images to latent space
masked_latents = vae.encode(
masked_image.reshape(image.shape).to(dtype=weight_dtype) # masked image
).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
# Convert ref images to latent space
ref_latents = vae.encode(
ref_image.reshape(image.shape).to(dtype=weight_dtype) # ref image
).latent_dist.sample()
ref_latents = ref_latents * vae.config.scaling_factor
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(mask.shape[-1] // 8, mask.shape[-1] // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, mask.shape[-1], mask.shape[-1])
bsz = latents.shape[0]
timesteps = torch.tensor([0], device=latents.device)
if unet_config['in_channels'] == 9:
latent_model_input = torch.cat([mask, masked_latents, ref_latents], dim=1)
else:
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
image_pred = unet_copy(latent_model_input, timesteps, encoder_hidden_states = audio_feature).sample
image = Image.new('RGB', (RESIZED_IMG*4, RESIZED_IMG))
image.paste(decode_latents(vae_fp32,masked_latents), (0, 0))
image.paste(decode_latents(vae_fp32, ref_latents), (RESIZED_IMG, 0))
image.paste(decode_latents(vae_fp32, latents), (RESIZED_IMG*2, 0))
image.paste(decode_latents(vae_fp32, image_pred), (RESIZED_IMG*3, 0))
val_img_dir = f"images/{output_dir}/{global_step}"
if not os.path.exists(val_img_dir):
os.makedirs(val_img_dir)
image.save('{0}/val_epoch_{1}_{2}_image.png'.format(val_img_dir, global_step,step))
print("valtion in step:{0}, time:{1}".format(step,time.time()-start))
print("valtion_done in epoch:{0}, time:{1}".format(epoch,time.time()-start))