Merge pull request #85 from shounakb1/train_codes

initial data script
czk32611
2024-08-06 18:49:07 +08:00
committed by GitHub
9 changed files with 603 additions and 58 deletions

View File

@@ -48,22 +48,22 @@ def get_image_list(data_root, split):
class Dataset(object):
def __init__(self,
data_root,
split,
json_path,
use_audio_length_left=1,
use_audio_length_right=1,
whisper_model_type = "tiny"
):
self.all_videos, self.all_imgNum = get_image_list(data_root, split)
# self.all_videos, self.all_imgNum = get_image_list(data_root, split)
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
self.img_names_path = '...'
# self.split = split
self.img_names_path = '../data'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
self.use_audio_length_right = use_audio_length_right
if self.whisper_model_type =="tiny":
self.whisper_path = '...'
self.whisper_path = '../data/audios'
self.whisper_feature_W = 5
self.whisper_feature_H = 384
elif self.whisper_model_type =="largeV2":
@@ -71,6 +71,8 @@ class Dataset(object):
self.whisper_feature_W = 33
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1)  # 5*2*(2+2+1) = 50
with open(json_path, 'r') as file:
self.all_videos = json.load(file)
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
@@ -79,7 +81,6 @@ class Dataset(object):
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
with open(json_path_names, "w") as f:
json.dump(img_names,f)
print(f"save to {json_path_names}")
else:
with open(json_path_names, "r") as f:
img_names = json.load(f)
@@ -135,7 +136,6 @@ class Dataset(object):
vidname = self.all_videos[idx].split('/')[-1]
video_imgs = self.all_img_names[idx]
if len(video_imgs) == 0:
# print("video_imgs = 0:",vidname)
continue
img_name = random.choice(video_imgs)
img_idx = int(basename(img_name).split(".")[0])
@@ -193,7 +193,6 @@ class Dataset(object):
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
# check whether the index is out of range
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
if not os.path.exists(audio_feat_path):
is_index_out_of_range = True
break
@@ -214,8 +213,6 @@ class Dataset(object):
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
continue
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
return ref_image, image, masked_image, mask, audio_feature
@@ -231,10 +228,8 @@ if __name__ == "__main__":
val_data_loader = data_utils.DataLoader(
val_data, batch_size=4, shuffle=True,
num_workers=1)
print("val_dataset:",val_data_loader.__len__())
for i, data in enumerate(val_data_loader):
ref_image, image, masked_image, mask, audio_feature = data
print("ref_image: ", ref_image.shape)

View File

@@ -1,15 +1,17 @@
# Draft training codes
# Data preprocessing
We provide the draft training codes here. Unfortunately, the data preprocessing code is still being reorganized.
Create two config yaml files, one for training and the other for testing (both in the same format as configs/inference/test.yaml).
The train yaml file should contain the training video paths and their corresponding audio paths.
The test yaml file should contain the validation video paths and their corresponding audio paths.
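
A minimal sketch of generating such a config programmatically, assuming the same `task_N` layout used by `configs/inference/test.yaml`; the output path and the video/audio paths below are illustrative only:
```
# Illustrative only: keys and paths are assumptions, mirroring configs/inference/test.yaml.
import yaml  # PyYAML

train_cfg = {
    "task_0": {"video_path": "./data/video/train_video1.mp4",
               "audio_path": "./data/audio/train_video1.wav"},
    "task_1": {"video_path": "./data/video/train_video2.mp4",
               "audio_path": "./data/audio/train_video2.wav"},
}
with open("train.yaml", "w") as f:
    yaml.safe_dump(train_cfg, f)
```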
## Setup
Run:
```
./data_new.sh train output train_video1.mp4 train_video2.mp4
./data_new.sh test output test_video1.mp4 test_video2.mp4
```
This creates folders containing the image frames and npy files, and also generates train.json and val.json, which are used during training.
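
The generated train.json/val.json are plain JSON lists of video paths; the `DataLoader` loads them with `json.load` and derives each per-video folder name from the file name. A quick way to inspect one (the printed paths are placeholders):
```
import json

with open("train.json") as f:
    all_videos = json.load(f)  # flat list of video paths
print(len(all_videos), all_videos[:2])
# e.g. 2 ['output/train_video1.mp4', 'output/train_video2.mp4']  <- placeholder output
```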
We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 200,000+ steps. Using multiple GPUs should accelerate training.
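
For reference, train.py computes the effective batch size as `train_batch_size * num_processes * gradient_accumulation_steps`; assuming a single process, the A100 setting above works out as follows:
```
# Effective batch size per optimizer update, using train.py's total_batch_size formula.
train_batch_size = 8
gradient_accumulation_steps = 4
num_processes = 1  # assumption: a single A100
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)  # 32
```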
## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
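
If you preprocess manually, here is a rough sketch (not the repository's script) of dumping numerically named frames in the layout the `DataLoader` expects under `./data/images/<video_name>/`; per-frame whisper features should likewise be saved as `<frame_idx>.npy` under `./data/audios/<video_name>/`. The `.png` extension is an assumption:
```
# Rough sketch only; face cropping (see scripts/inference.py) is omitted for brevity.
import os
import cv2

def dump_frames(video_path, out_root="./data/images"):
    name = os.path.splitext(os.path.basename(video_path))[0]
    out_dir = os.path.join(out_root, name)
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(out_dir, f"{idx}.png"), frame)
        idx += 1
    cap.release()
    return idx
```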
Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
@@ -35,9 +37,16 @@ Finally, the data should be organized as follows:
## Training
After preparing the preprocessed data, simply run:
```
sh train.sh
cd train_codes
sh train.sh  # --train_json="../train.json" (generated in the data preprocessing step)
             # --val_json="../val.json"
```
## Inference with trained checkpoint
After training the model, simply run the command below; the model checkpoints are usually saved under train_codes/output.
```
python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)

View File

@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
import sys
sys.path.append("./")
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -137,6 +140,8 @@ def parse_args():
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
parser.add_argument(
"--hub_model_id",
type=str,
@@ -234,13 +239,17 @@ def parse_args():
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def print_model_dtypes(model, model_name):
for name, param in model.named_parameters():
if(param.dtype!=torch.float32):
print(f"{name}: {param.dtype}")
def main():
args = parse_args()
print(args)
args.output_dir = f"output/{args.output_dir}"
args.val_out_dir = f"val/{args.val_out_dir}"
os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +341,7 @@ def main():
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters())
itertools.chain(unet.parameters()))
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -343,23 +352,21 @@ def main():
print("loading train_dataset ...")
train_dataset = Dataset(args.data_root,
'train',
args.train_json,
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("train_dataset:",train_dataset.__len__())
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True,
num_workers=8)
print("loading val_dataset ...")
val_dataset = Dataset(args.data_root,
'val',
args.val_json,
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("val_dataset:",val_dataset.__len__())
val_data_loader = data_utils.DataLoader(
val_dataset, batch_size=1, shuffle=False,
num_workers=8)
@@ -388,6 +395,7 @@ def main():
vae_fp32.requires_grad_(False)
weight_dtype = torch.float32
# weight_dtype = torch.float16
vae_fp32.to(accelerator.device, dtype=weight_dtype)
vae_fp32.encoder = None
if accelerator.mixed_precision == "fp16":
@@ -412,6 +420,8 @@ def main():
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print(f" Num batches each epoch = {len(train_data_loader)}")
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +443,9 @@ def main():
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
# path="../models/pytorch_model.bin"
#TODO change path
# path=None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +471,11 @@ def main():
# caluate the elapsed time
elapsed_time = []
start = time.time()
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# for step, batch in enumerate(train_dataloader):
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +484,23 @@ def main():
continue
dataloader_time = time.time() - start
start = time.time()
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
"""
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
print("ref_image: ",ref_image.shape)
print("masks: ", masks.shape)
print("masked_image: ", masked_image.shape)
print("audio feature: ", audio_feature.shape)
print("image: ", image.shape)
"""
# """
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
# print("ref_image: ",ref_image.shape)
# print("masks: ", masks.shape)
# print("masked_image: ", masked_image.shape)
# print("audio feature: ", audio_feature.shape)
# print("image: ", image.shape)
# """
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
img_process_time = time.time() - start
start = time.time()
with accelerator.accumulate(unet):
vae = vae.half()
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
@@ -592,12 +605,23 @@ def main():
f"Running validation... epoch={epoch}, global_step={global_step}"
)
print("===========start validation==========")
# Use the helper function to check the data types for each model
vae_new = vae.float()
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
print(f"weight_dtype: {weight_dtype}")
print(f"epoch type: {type(epoch)}")
print(f"global_step type: {type(global_step)}")
validation(
vae=accelerator.unwrap_model(vae),
# vae=accelerator.unwrap_model(vae),
vae=accelerator.unwrap_model(vae_new),
vae_fp32=accelerator.unwrap_model(vae_fp32),
unet=accelerator.unwrap_model(unet),
unet_config=unet_config,
weight_dtype=weight_dtype,
# weight_dtype=weight_dtype,
weight_dtype=torch.float32,
epoch=epoch,
global_step=global_step,
val_data_loader=val_data_loader,

View File

@@ -1,27 +1,29 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export DATASET="..."
export UNET_CONFIG="./musetalk.json"
export VAE_MODEL="../models/sd-vae-ft-mse/"
export DATASET="../data"
export UNET_CONFIG="../models/musetalk/musetalk.json"
accelerate launch --multi_gpu train.py \
accelerate launch train.py \
--mixed_precision="fp16" \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
--data_root=$DATASET \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--train_batch_size=256 \
--gradient_accumulation_steps=16 \
--gradient_checkpointing \
--max_train_steps=200000 \
--max_train_steps=100000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="..." \
--val_out_dir='...' \
--output_dir="output" \
--val_out_dir='val' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \
--checkpointing_steps=2000 \
--validation_steps=2000 \
--reconstruction \
--resume_from_checkpoint="latest" \
--use_audio_length_left=2 \
--use_audio_length_right=2 \
--whisper_model_type="tiny" \
--train_json="../train.json" \
--val_json="../val.json" \
--lr_scheduler="cosine" \

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union