modified dataloader.py and inference.py for training and inference

This commit is contained in:
Shounak Banerjee
2024-06-03 11:09:12 +00:00
parent 7254ca6306
commit b4a592d7f3
6 changed files with 106 additions and 58 deletions

View File

@@ -15,14 +15,27 @@ from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
from accelerate import Accelerator
# load model weights
audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(
mixed_precision="fp16",
)
unet = accelerator.prepare(
unet,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
global pe
if args.unet_checkpoint is not None:
print("unet ckpt loaded")
accelerator.load_state(args.unet_checkpoint)
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
@@ -63,8 +76,6 @@ def main(args):
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
#print(input_img_list)
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
@@ -79,24 +90,27 @@ def main(args):
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
crop_i=0
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
crop_i+=1
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
print("start inference")
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
@@ -107,7 +121,6 @@ def main(args):
dtype=unet.model.dtype) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
@@ -122,22 +135,29 @@ def main(args):
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
# print(bbox)
continue
combine_frame = get_image(ori_frame,res_frame,bbox)
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
print(cmd_img2video)
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
print(cmd_combine_audio)
os.system(cmd_combine_audio)
os.remove("temp.mp4")
shutil.rmtree(result_img_save_path)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
os.system(cmd_img2video)
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
# print(cmd_combine_audio)
# os.system(cmd_combine_audio)
# shutil.rmtree(result_img_save_path)
print(f"result is save to {output_vid_name}")
if __name__ == "__main__":
@@ -156,6 +176,7 @@ if __name__ == "__main__":
action="store_true",
help="Whether use float16 to speed up inference",
)
parser.add_argument("--unet_checkpoint", type=str, default=None)
args = parser.parse_args()
main(args)
main(args)

View File

@@ -57,13 +57,13 @@ class Dataset(object):
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
self.img_names_path = '...'
self.img_names_path = '../data'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
self.use_audio_length_right = use_audio_length_right
if self.whisper_model_type =="tiny":
self.whisper_path = '...'
self.whisper_path = '../data/audios'
self.whisper_feature_W = 5
self.whisper_feature_H = 384
elif self.whisper_model_type =="largeV2":
@@ -72,6 +72,10 @@ class Dataset(object):
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*2+2+1= 50
if(self.split=="train"):
self.all_videos=["../data/images/train"]
if(self.split=="val"):
self.all_videos=["../data/images/test"]
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
if not os.path.exists(json_path_names):
@@ -79,7 +83,6 @@ class Dataset(object):
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
with open(json_path_names, "w") as f:
json.dump(img_names,f)
print(f"save to {json_path_names}")
else:
with open(json_path_names, "r") as f:
img_names = json.load(f)
@@ -147,7 +150,6 @@ class Dataset(object):
vidname = self.all_videos[idx].split('/')[-1]
video_imgs = self.all_img_names[idx]
if len(video_imgs) == 0:
# print("video_imgs = 0:",vidname)
continue
img_name = random.choice(video_imgs)
img_idx = int(basename(img_name).split(".")[0])
@@ -205,7 +207,6 @@ class Dataset(object):
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
# check whether the audio feature index is out of range
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
if not os.path.exists(audio_feat_path):
is_index_out_of_range = True
break
@@ -226,8 +227,6 @@ class Dataset(object):
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
continue
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
return ref_image, image, masked_image, mask, audio_feature
@@ -243,10 +242,8 @@ if __name__ == "__main__":
val_data_loader = data_utils.DataLoader(
val_data, batch_size=4, shuffle=True,
num_workers=1)
print("val_dataset:",val_data_loader.__len__())
for i, data in enumerate(val_data_loader):
ref_image, image, masked_image, mask, audio_feature = data
print("ref_image: ", ref_image.shape)

View File

@@ -1,32 +1,35 @@
# Draft training codes
# Data preprocessing
We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.
Create two config yaml files, one for training and one for testing (both in the same format as configs/inference/test.yaml).
The train yaml file should list the training video paths and their corresponding audio paths.
The test yaml file should list the validation video paths and their corresponding audio paths.
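A minimal sketch of such a config, assuming the same `task_N`/`video_path`/`audio_path` layout as configs/inference/test.yaml (the task names and file paths here are purely illustrative):
```
task_0:
  video_path: "data/video/speaker1.mp4"
  audio_path: "data/audio/speaker1.wav"

task_1:
  video_path: "data/video/speaker2.mp4"
  audio_path: "data/audio/speaker2.wav"
```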
## Setup
Run:
```
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
```
This creates folders containing the extracted image frames and the corresponding audio feature npy files.
We trained our model on an NVIDIA A100 with `batch_size=8, gradient_accumulation_steps=4` (an effective batch size of 32) for 200k+ steps. Using multiple GPUs should accelerate training.
## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69); a condensed sketch is given below.
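For a rough idea, here is a condensed sketch of those two steps, adapted from the inference code; the import path for `get_landmark_and_bbox`, the input/output paths, `bbox_shift=0`, and the 25 fps value are assumptions for illustration only (the actual preprocessing is done by `scripts.data`):
```
import glob
import os
import cv2
import numpy as np
from musetalk.utils.preprocessing import get_landmark_and_bbox  # import path assumed
from musetalk.utils.utils import load_all_model

audio_processor, vae, unet, pe = load_all_model()

# hypothetical inputs: numbered raw frames and the matching audio track
input_img_list = sorted(glob.glob("../data/raw_frames/*.png"),
                        key=lambda p: int(os.path.basename(p).split(".")[0]))
audio_path = "../data/raw_audio/train.wav"
os.makedirs("../data/images/train", exist_ok=True)
os.makedirs("../data/audios/train", exist_ok=True)

# 1) crop the face region of every frame and resize to 256x256, as in inference.py
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)
for i, (bbox, frame) in enumerate(zip(coord_list, frame_list)):
    x1, y1, x2, y2 = bbox
    crop_frame = cv2.resize(frame[y1:y2, x1:x2], (256, 256),
                            interpolation=cv2.INTER_LANCZOS4)
    cv2.imwrite(f"../data/images/train/{i}.png", crop_frame)

# 2) extract whisper features and split them into per-frame chunks
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=25)
for i, chunk in enumerate(whisper_chunks):
    np.save(f"../data/audios/train/{i}.npy", chunk)  # assumed on-disk layout
```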
Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
│   └── RD_Radio10_000
│   └── train
│       └── 0.png
│       └── 1.png
│       └── xxx.png
│   └── RD_Radio11_000
│   └── test
│       └── 0.png
│       └── 1.png
│       └── xxx.png
├── audios
│   └── RD_Radio10_000
│   └── train
│       └── 0.npy
│       └── 1.npy
│       └── xxx.npy
│   └── RD_Radio11_000
│   └── test
│       └── 0.npy
│       └── 1.npy
│       └── xxx.npy
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
```
sh train.sh
```
## Inference with trained checkpoint
After training the model, simply run the command below. Model checkpoints are usually saved under train_codes/output; since the checkpoint is restored with `accelerator.load_state`, `--unet_checkpoint` should point to a saved checkpoint folder rather than a single weights file.
```
python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
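A minimal sketch of how the two sides pair up, assuming the training loop saves checkpoints with Accelerate's `save_state` (the stand-in model and folder name are illustrative, not the actual MuseTalk modules):
```
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # the MuseTalk scripts also pass mixed_precision="fp16"
# stand-in module; in the real scripts this is the MuseTalk unet
unet = accelerator.prepare(torch.nn.Linear(8, 8))

# training side: write the prepared model (and optimizer, if prepared) into a folder
accelerator.save_state("output/checkpoint-50000")

# inference side: --unet_checkpoint restores that folder via accelerator.load_state
accelerator.load_state("output/checkpoint-50000")
```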
## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)

View File

@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
import sys
sys.path.append("./")
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -234,13 +237,17 @@ def parse_args():
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def print_model_dtypes(model, model_name):
for name, param in model.named_parameters():
if param.dtype != torch.float32:
print(f"{model_name} {name}: {param.dtype}")
def main():
args = parse_args()
print(args)
args.output_dir = f"output/{args.output_dir}"
args.val_out_dir = f"val/{args.val_out_dir}"
os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +339,7 @@ def main():
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters())
itertools.chain(unet.parameters()))
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -348,7 +355,6 @@ def main():
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("train_dataset:",train_dataset.__len__())
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True,
num_workers=8)
@@ -359,7 +365,6 @@ def main():
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("val_dataset:",val_dataset.__len__())
val_data_loader = data_utils.DataLoader(
val_dataset, batch_size=1, shuffle=False,
num_workers=8)
@@ -388,6 +393,7 @@ def main():
vae_fp32.requires_grad_(False)
weight_dtype = torch.float32
# weight_dtype = torch.float16
vae_fp32.to(accelerator.device, dtype=weight_dtype)
vae_fp32.encoder = None
if accelerator.mixed_precision == "fp16":
@@ -412,6 +418,8 @@ def main():
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print(f" Num batches each epoch = {len(train_data_loader)}")
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +441,9 @@ def main():
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
# path="../models/pytorch_model.bin"
#TODO change path
# path=None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +469,11 @@ def main():
# caluate the elapsed time
elapsed_time = []
start = time.time()
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# for step, batch in enumerate(train_dataloader):
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +482,23 @@ def main():
continue
dataloader_time = time.time() - start
start = time.time()
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
"""
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
print("ref_image: ",ref_image.shape)
print("masks: ", masks.shape)
print("masked_image: ", masked_image.shape)
print("audio feature: ", audio_feature.shape)
print("image: ", image.shape)
"""
# """
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
# print("ref_image: ",ref_image.shape)
# print("masks: ", masks.shape)
# print("masked_image: ", masked_image.shape)
# print("audio feature: ", audio_feature.shape)
# print("image: ", image.shape)
# """
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
img_process_time = time.time() - start
start = time.time()
with accelerator.accumulate(unet):
vae = vae.half()
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
@@ -592,12 +603,23 @@ def main():
f"Running validation... epoch={epoch}, global_step={global_step}"
)
print("===========start validation==========")
# Use the helper function to check the data types for each model
vae_new = vae.float()
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
print(f"weight_dtype: {weight_dtype}")
print(f"epoch type: {type(epoch)}")
print(f"global_step type: {type(global_step)}")
validation(
vae=accelerator.unwrap_model(vae),
# vae=accelerator.unwrap_model(vae),
vae=accelerator.unwrap_model(vae_new),
vae_fp32=accelerator.unwrap_model(vae_fp32),
unet=accelerator.unwrap_model(unet),
unet_config=unet_config,
weight_dtype=weight_dtype,
# weight_dtype=weight_dtype,
weight_dtype=torch.float32,
epoch=epoch,
global_step=global_step,
val_data_loader=val_data_loader,

View File

@@ -1,8 +1,8 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export DATASET="..."
export UNET_CONFIG="./musetalk.json"
export VAE_MODEL="../models/sd-vae-ft-mse/"
export DATASET="../data"
export UNET_CONFIG="../models/musetalk/musetalk.json"
accelerate launch --multi_gpu train.py \
accelerate launch train.py \
--mixed_precision="fp16" \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
@@ -10,13 +10,13 @@ accelerate launch --multi_gpu train.py \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=200000 \
--max_train_steps=50000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="..." \
--val_out_dir='...' \
--output_dir="output" \
--val_out_dir='val' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union