Mirror of https://github.com/TMElyralab/MuseTalk.git
modified dataloader.py and inference.py for training and inference
@@ -27,10 +27,13 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 
 import sys
 sys.path.append("./")
 
 from DataLoader import Dataset
 from utils.utils import preprocess_img_tensor
 from torch.utils import data as data_utils
-from model_utils import validation,PositionalEncoding
+from utils.model_utils import validation,PositionalEncoding
+import time
+import pandas as pd
+from PIL import Image
@@ -234,13 +237,17 @@ def parse_args():
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
     return args
 
 
+def print_model_dtypes(model, model_name):
+    for name, param in model.named_parameters():
+        if(param.dtype!=torch.float32):
+            print(f"{name}: {param.dtype}")
+
+
 def main():
     args = parse_args()
     print(args)
+    args.output_dir = f"output/{args.output_dir}"
+    args.val_out_dir = f"val/{args.val_out_dir}"
     os.makedirs(args.output_dir, exist_ok=True)
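Note: the new print_model_dtypes helper flags any parameter that is not float32 (its model_name argument is accepted but never printed). A minimal standalone sketch of the same audit, run against a hypothetical module:

    import torch
    import torch.nn as nn

    def print_model_dtypes(model, model_name):
        # Report every parameter whose dtype differs from float32.
        for name, param in model.named_parameters():
            if param.dtype != torch.float32:
                print(f"{model_name}.{name}: {param.dtype}")

    # Hypothetical target: a small layer cast to half precision.
    layer = nn.Linear(4, 4).half()
    print_model_dtypes(layer, "layer")  # prints layer.weight / layer.bias as torch.float16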
@@ -332,7 +339,7 @@ def main():
     optimizer_class = torch.optim.AdamW
 
     params_to_optimize = (
-        itertools.chain(unet.parameters())
+        itertools.chain(unet.parameters()))
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
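Note: with only the UNet trained, itertools.chain wraps a single iterator and adds nothing; it pays off once a second trainable module joins the optimizer. A sketch of that pattern, where the extra pe module is a hypothetical stand-in:

    import itertools
    import torch

    unet = torch.nn.Linear(8, 8)    # stand-in for the real UNet
    pe = torch.nn.Embedding(16, 8)  # hypothetical second trainable module
    # chain() merges both parameter iterators into one optimizer group.
    params_to_optimize = itertools.chain(unet.parameters(), pe.parameters())
    optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)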
@@ -348,7 +355,6 @@ def main():
         use_audio_length_right=args.use_audio_length_right,
         whisper_model_type=args.whisper_model_type
     )
     print("train_dataset:",train_dataset.__len__())
     train_data_loader = data_utils.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True,
         num_workers=8)
@@ -359,7 +365,6 @@ def main():
         use_audio_length_right=args.use_audio_length_right,
         whisper_model_type=args.whisper_model_type
     )
     print("val_dataset:",val_dataset.__len__())
     val_data_loader = data_utils.DataLoader(
         val_dataset, batch_size=1, shuffle=False,
         num_workers=8)
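Note: both loaders must yield the 5-tuple (ref_image, image, masked_image, masks, audio_feature) that the training loop unpacks below. A smoke-test sketch with made-up shapes (the repo's real Dataset defines the actual ones):

    import torch
    from torch.utils import data as data_utils

    class ToyDataset(data_utils.Dataset):
        # Stand-in yielding the same 5-tuple as the repo's DataLoader.Dataset.
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            ref_image = torch.rand(3, 256, 256)      # assumed shape
            image = torch.rand(3, 256, 256)          # assumed shape
            masked_image = torch.rand(3, 256, 256)   # assumed shape
            masks = torch.rand(256, 256)             # assumed shape
            audio_feature = torch.rand(50, 384)      # assumed shape
            return ref_image, image, masked_image, masks, audio_feature

    loader = data_utils.DataLoader(ToyDataset(), batch_size=2, shuffle=True, num_workers=0)
    ref_image, image, masked_image, masks, audio_feature = next(iter(loader))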
@@ -388,6 +393,7 @@ def main():
     vae_fp32.requires_grad_(False)
 
+    weight_dtype = torch.float32
+    # weight_dtype = torch.float16
     vae_fp32.to(accelerator.device, dtype=weight_dtype)
     vae_fp32.encoder = None
     if accelerator.mixed_precision == "fp16":
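Note: the commit keeps a frozen float32 VAE copy next to the training-precision one and drops its unused encoder to save memory; only the decoder is needed to turn latents back into images at validation time. A sketch of that split, where the model id is an assumption:

    import torch
    from diffusers import AutoencoderKL

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Model id is illustrative; the script loads the VAE from its own config.
    vae_fp32 = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae_fp32.requires_grad_(False)            # frozen: decode-only
    vae_fp32.to(device, dtype=torch.float32)
    vae_fp32.encoder = None                   # encoder unused, free its memory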
@@ -412,6 +418,8 @@ def main():
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
+    print(f" Num batches each epoch = {len(train_data_loader)}")
+
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(train_dataset)}")
     logger.info(f" Num batches each epoch = {len(train_data_loader)}")
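Note: total_batch_size is the effective number of samples per optimizer update across all devices and accumulation steps. For example:

    train_batch_size = 8               # per device
    num_processes = 4                  # GPUs
    gradient_accumulation_steps = 2
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
    print(total_batch_size)            # 64 samples per optimizer step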
@@ -433,6 +441,9 @@ def main():
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1] if len(dirs) > 0 else None
 
+            # path="../models/pytorch_model.bin"
+            #TODO change path
+            # path=None
         if path is None:
             accelerator.print(
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
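Note: resuming relies on checkpoint directories named checkpoint-<step>; the int(x.split("-")[1]) key sorts them numerically, since a plain lexical sort would rank "checkpoint-900" above "checkpoint-10000". A self-contained sketch:

    # Stand-in for os.listdir(args.output_dir) in the script.
    dirs = ["checkpoint-500", "checkpoint-10000", "checkpoint-900"]
    dirs = [d for d in dirs if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
    path = dirs[-1] if len(dirs) > 0 else None
    print(path)  # checkpoint-10000, not checkpoint-900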
@@ -458,10 +469,11 @@ def main():
     # calculate the elapsed time
     elapsed_time = []
     start = time.time()
 
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         # for step, batch in enumerate(train_dataloader):
         for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +482,23 @@ def main():
                 continue
             dataloader_time = time.time() - start
             start = time.time()
 
             masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
-            """
-            print("=============epoch:{0}=step:{1}=====".format(epoch,step))
-            print("ref_image: ",ref_image.shape)
-            print("masks: ", masks.shape)
-            print("masked_image: ", masked_image.shape)
-            print("audio feature: ", audio_feature.shape)
-            print("image: ", image.shape)
-            """
+            # """
+            # print("=============epoch:{0}=step:{1}=====".format(epoch,step))
+            # print("ref_image: ",ref_image.shape)
+            # print("masks: ", masks.shape)
+            # print("masked_image: ", masked_image.shape)
+            # print("audio feature: ", audio_feature.shape)
+            # print("image: ", image.shape)
+            # """
             ref_image = preprocess_img_tensor(ref_image).to(vae.device)
             image = preprocess_img_tensor(image).to(vae.device)
             masked_image = preprocess_img_tensor(masked_image).to(vae.device)
 
             img_process_time = time.time() - start
             start = time.time()
 
             with accelerator.accumulate(unet):
+                vae = vae.half()
                 # Convert images to latent space
                 latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
                 latents = latents * vae.config.scaling_factor
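Note: inside the accumulation block the images are mapped to VAE latents and scaled by the configured factor, the standard latent-diffusion setup. A minimal sketch with diffusers' AutoencoderKL (model id is an assumption):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae.requires_grad_(False)

    image = torch.rand(1, 3, 256, 256) * 2 - 1    # inputs normalized to [-1, 1]
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample()
    latents = latents * vae.config.scaling_factor  # 0.18215 for SD-family VAEs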
@@ -592,12 +603,23 @@ def main():
                         f"Running validation... epoch={epoch}, global_step={global_step}"
                     )
                     print("===========start validation==========")
+                    # Use the helper function to check the data types for each model
+                    vae_new = vae.float()
+                    print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
+                    print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
+                    print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
+
+                    print(f"weight_dtype: {weight_dtype}")
+                    print(f"epoch type: {type(epoch)}")
+                    print(f"global_step type: {type(global_step)}")
                     validation(
-                        vae=accelerator.unwrap_model(vae),
+                        # vae=accelerator.unwrap_model(vae),
+                        vae=accelerator.unwrap_model(vae_new),
                         vae_fp32=accelerator.unwrap_model(vae_fp32),
                         unet=accelerator.unwrap_model(unet),
                         unet_config=unet_config,
-                        weight_dtype=weight_dtype,
+                        # weight_dtype=weight_dtype,
+                        weight_dtype=torch.float32,
                         epoch=epoch,
                         global_step=global_step,
                         val_data_loader=val_data_loader,
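Note: nn.Module.float() casts the module in place and returns self, so vae_new = vae.float() does not copy the VAE; it undoes the earlier vae = vae.half() on the very same object before validation runs. A short demonstration:

    import torch.nn as nn

    vae = nn.Linear(2, 2).half()   # training-loop cast to fp16
    vae_new = vae.float()          # in-place cast back; returns the same object
    print(vae_new is vae)          # True: vae_new is an alias, now fp32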