modified dataloader.py and inference.py for training and inference

This commit is contained in:
Shounak Banerjee
2024-06-03 11:09:12 +00:00
parent 7254ca6306
commit b4a592d7f3
6 changed files with 106 additions and 58 deletions

View File

@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
import sys
sys.path.append("./")
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -234,13 +237,17 @@ def parse_args():
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def print_model_dtypes(model, model_name):
for name, param in model.named_parameters():
if(param.dtype!=torch.float32):
print(f"{name}: {param.dtype}")
def main():
args = parse_args()
print(args)
args.output_dir = f"output/{args.output_dir}"
args.val_out_dir = f"val/{args.val_out_dir}"
os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +339,7 @@ def main():
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters())
itertools.chain(unet.parameters()))
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -348,7 +355,6 @@ def main():
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("train_dataset:",train_dataset.__len__())
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True,
num_workers=8)
@@ -359,7 +365,6 @@ def main():
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("val_dataset:",val_dataset.__len__())
val_data_loader = data_utils.DataLoader(
val_dataset, batch_size=1, shuffle=False,
num_workers=8)
@@ -388,6 +393,7 @@ def main():
vae_fp32.requires_grad_(False)
weight_dtype = torch.float32
# weight_dtype = torch.float16
vae_fp32.to(accelerator.device, dtype=weight_dtype)
vae_fp32.encoder = None
if accelerator.mixed_precision == "fp16":
@@ -412,6 +418,8 @@ def main():
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print(f" Num batches each epoch = {len(train_data_loader)}")
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +441,9 @@ def main():
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
# path="../models/pytorch_model.bin"
#TODO change path
# path=None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +469,11 @@ def main():
# caluate the elapsed time
elapsed_time = []
start = time.time()
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# for step, batch in enumerate(train_dataloader):
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +482,23 @@ def main():
continue
dataloader_time = time.time() - start
start = time.time()
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
"""
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
print("ref_image: ",ref_image.shape)
print("masks: ", masks.shape)
print("masked_image: ", masked_image.shape)
print("audio feature: ", audio_feature.shape)
print("image: ", image.shape)
"""
# """
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
# print("ref_image: ",ref_image.shape)
# print("masks: ", masks.shape)
# print("masked_image: ", masked_image.shape)
# print("audio feature: ", audio_feature.shape)
# print("image: ", image.shape)
# """
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
img_process_time = time.time() - start
start = time.time()
with accelerator.accumulate(unet):
vae = vae.half()
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
@@ -592,12 +603,23 @@ def main():
f"Running validation... epoch={epoch}, global_step={global_step}"
)
print("===========start validation==========")
# Use the helper function to check the data types for each model
vae_new = vae.float()
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
print(f"weight_dtype: {weight_dtype}")
print(f"epoch type: {type(epoch)}")
print(f"global_step type: {type(global_step)}")
validation(
vae=accelerator.unwrap_model(vae),
# vae=accelerator.unwrap_model(vae),
vae=accelerator.unwrap_model(vae_new),
vae_fp32=accelerator.unwrap_model(vae_fp32),
unet=accelerator.unwrap_model(unet),
unet_config=unet_config,
weight_dtype=weight_dtype,
# weight_dtype=weight_dtype,
weight_dtype=torch.float32,
epoch=epoch,
global_step=global_step,
val_data_loader=val_data_loader,