Merge pull request #85 from shounakb1/train_codes

initial data script
This commit is contained in:
czk32611
2024-08-06 18:49:07 +08:00
committed by GitHub
9 changed files with 603 additions and 58 deletions

data_new.sh Executable file

@@ -0,0 +1,77 @@
#!/bin/bash
# Function to extract video and audio sections
extract_sections() {
input_video=$1
base_name=$(basename "$input_video" .mp4)
output_dir=$2
split=$3
duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
IFS=: read -r hours minutes seconds <<< "$duration"
total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
chunk_size=180 # 3 minutes in seconds
index=0
mkdir -p "$output_dir"
while [ $((index * chunk_size)) -lt $total_seconds ]; do
start_time=$((index * chunk_size))
section_video="${output_dir}/${base_name}_part${index}.mp4"
section_audio="${output_dir}/${base_name}_part${index}.mp3"
ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"
# Create and update the config.yaml file
echo "task_0:" > config.yaml
echo " video_path: \"$section_video\"" >> config.yaml
echo " audio_path: \"$section_audio\"" >> config.yaml
# Run the Python script with the current config.yaml
python -m scripts.data --inference_config config.yaml --folder_name "$base_name"
index=$((index + 1))
done
# Clean up save folder
rm -rf "$output_dir"
}
# Main script
if [ $# -lt 3 ]; then
echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
exit 1
fi
split=$1
output_dir=$2
shift 2
input_videos=("$@")
# Initialize JSON array
json_array="["
for input_video in "${input_videos[@]}"; do
base_name=$(basename "$input_video" .mp4)
# Extract sections and run the Python script for each section
extract_sections "$input_video" "$output_dir" "$split"
# Add entry to JSON array
json_array+="\"../data/images/$base_name\","
done
# Remove trailing comma and close JSON array
json_array="${json_array%,}]"
# Write JSON array to the correct file
if [ "$split" == "train" ]; then
echo "$json_array" > train.json
elif [ "$split" == "test" ]; then
echo "$json_array" > test.json
else
echo "Invalid split: $split. Must be 'train' or 'test'."
exit 1
fi
echo "Processing complete."

scripts/data.py Normal file

@@ -0,0 +1,256 @@
import argparse
import copy
import gc
import glob
import os
import pickle
import shutil
import uuid

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

# import dlib
from musetalk.utils.utils import get_file_type, get_video_fps, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.blending import get_image
# load model weights
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
def get_largest_integer_filename(folder_path):
# Check if the folder exists
if not os.path.isdir(folder_path):
return -1
# Get the list of files in the folder
files = os.listdir(folder_path)
# Check if the folder is empty
if not files:
return -1
# Extract the integer part of filenames and find the largest
largest_integer = -1
for file in files:
try:
# Get the integer part of the filename
file_int = int(os.path.splitext(file)[0])
if file_int > largest_integer:
largest_integer = file_int
except ValueError:
# Skip files that don't have an integer filename
continue
return largest_integer
def datagen(whisper_chunks,
crop_images,
batch_size=8,
delay_frame=0):
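# Yields (whisper_batch, crop_batch) pairs of up to batch_size items; crop images
# are reused cyclically when there are more whisper chunks than cropped frames.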
whisper_batch, crop_batch = [], []
for i, w in enumerate(whisper_chunks):
idx = (i+delay_frame)%len(crop_images)
crop_image = crop_images[idx]
whisper_batch.append(w)
crop_batch.append(crop_image)
if len(crop_batch) >= batch_size:
whisper_batch = np.stack(whisper_batch)
# latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, crop_batch
whisper_batch, crop_batch = [], []
# the last batch may be smaller than batch_size
if len(crop_batch) > 0:
whisper_batch = np.stack(whisper_batch)
# latent_batch = torch.cat(latent_batch, dim=0)
yield whisper_batch, crop_batch
@torch.no_grad()
def main(args):
global pe
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
inference_config = OmegaConf.load(args.inference_config)
total_audio_index=get_largest_integer_filename(f"data/audios/{args.folder_name}")
total_image_index=get_largest_integer_filename(f"data/images/{args.folder_name}")
temp_audio_index=total_audio_index
temp_image_index=total_image_index
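# Numbering resumes from the largest existing integer filename (-1 for a missing or
# empty folder), so repeated runs append new frames/features rather than overwrite.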
for task_id in inference_config:
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
folder_name = args.folder_name
if not os.path.exists(f"data/images/{folder_name}/"):
os.makedirs(f"data/images/{folder_name}")
if not os.path.exists(f"data/audios/{folder_name}/"):
os.makedirs(f"data/audios/{folder_name}")
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path,exist_ok =True)
if args.output_vid_name is None:
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
else:
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
############################################## extract frames from source video ##############################################
if get_file_type(video_path)=="video":
save_dir_full = os.path.join(args.result_dir, input_basename)
os.makedirs(save_dir_full,exist_ok = True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path)=="image":
input_img_list = [video_path, ]
fps = args.fps
elif os.path.isdir(video_path): # input img folder
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
for __ in range(0, len(whisper_feature) - 1, 2): # -1 to avoid index error if the list has an odd number of elements
# Combine two consecutive chunks
# pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
# Save the pair to a .npy file
np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
temp_audio_index=(__//2)+total_audio_index+1
total_audio_index=temp_audio_index
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
############################################## preprocess input image ##############################################
gc.collect()
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("using extracted coordinates")
with open(crop_coord_save_path,'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("extracting landmarks...time consuming")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
crop_i=0
crop_data=[]
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
x1=max(0,x1)
y1=max(0,y1)
x2=max(0,x2)
y2=max(0,y2)
if ((y2-y1)<=0) or ((x2-x1)<=0):
continue
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(crop_frame)
crop_data.append(crop_frame)
input_latent_list.append(latents)
crop_i+=1
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
crop_data = crop_data + crop_data[::-1]
############################################## inference batch by batch ##############################################
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,crop_data,batch_size)
for i, (whisper_batch,crop_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
crop_index=0
for image,audio in zip(crop_batch,whisper_batch):
cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png",image)
crop_index+=1
temp_image_index=i+crop_index+total_image_index+1
# np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
total_image_index=temp_image_index
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
parser.add_argument("--bbox_shift", type=int, default=0)
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="path to output")
parser.add_argument("--fps", type=int, default=25)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--output_vid_name", type=str, default=None)
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
args = parser.parse_args()
main(args)
def process_audio(audio_path):
whisper_feature = audio_processor.audio2feat(audio_path)
np.save('audio/your_filename.npy', whisper_feature)
def mask_face(image):
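# Note: this helper requires dlib (its import is commented out above) and the
# 68-point shape predictor file; it blackens everything below the nose tip.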
# Load dlib's face detector and the landmark predictor
detector = dlib.get_frontal_face_detector()
predictor_path = "/content/shape_predictor_68_face_landmarks.dat" # Set path to your downloaded predictor file
predictor = dlib.shape_predictor(predictor_path)
# Load your input image
# image_path = "/content/ori_frame_00000077.png" # Replace with the path to your input image
# image = cv2.imread(image_path)
if image is None:
raise ValueError("Image not found or unable to load.")
# Convert to grayscale for detection
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Detect faces in the image
faces = detector(gray)
# Process each detected face
for face in faces:
# Predict landmarks
landmarks = predictor(gray, face)
# The indices of nose landmarks are 27 to 35
nose_tip = landmarks.part(33).y
# Blacken the region below the nose tip
blacken_area = image[nose_tip:, :]
blacken_area[:] = (0, 0, 0)
# Save the final image or display it
# cv2.imwrite("output_image.jpg", image)
return image


@@ -0,0 +1,182 @@
import argparse
import os
from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
import glob
import pickle
from tqdm import tqdm
import copy
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
from accelerate import Accelerator
# load model weights
audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(
mixed_precision="fp16",
)
unet = accelerator.prepare(
unet,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
@torch.no_grad()
def main(args):
global pe
if args.unet_checkpoint is not None:
print("unet ckpt loaded")
accelerator.load_state(args.unet_checkpoint)
if args.use_float16 is True:
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
inference_config = OmegaConf.load(args.inference_config)
print(inference_config)
for task_id in inference_config:
video_path = inference_config[task_id]["video_path"]
audio_path = inference_config[task_id]["audio_path"]
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
input_basename = os.path.basename(video_path).split('.')[0]
audio_basename = os.path.basename(audio_path).split('.')[0]
output_basename = f"{input_basename}_{audio_basename}"
result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
os.makedirs(result_img_save_path,exist_ok =True)
if args.output_vid_name is None:
output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
else:
output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
############################################## extract frames from source video ##############################################
if get_file_type(video_path)=="video":
save_dir_full = os.path.join(args.result_dir, input_basename)
os.makedirs(save_dir_full,exist_ok = True)
cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
os.system(cmd)
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
fps = get_video_fps(video_path)
elif get_file_type(video_path)=="image":
input_img_list = [video_path, ]
fps = args.fps
elif os.path.isdir(video_path): # input img folder
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
fps = args.fps
else:
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
############################################## extract audio feature ##############################################
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
############################################## preprocess input image ##############################################
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
print("using extracted coordinates")
with open(crop_coord_save_path,'rb') as f:
coord_list = pickle.load(f)
frame_list = read_imgs(input_img_list)
else:
print("extracting landmarks...time consuming")
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
with open(crop_coord_save_path, 'wb') as f:
pickle.dump(coord_list, f)
i = 0
input_latent_list = []
crop_i=0
for bbox, frame in zip(coord_list, frame_list):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
crop_frame = frame[y1:y2, x1:x2]
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
latents = vae.get_latents_for_unet(crop_frame)
input_latent_list.append(latents)
crop_i+=1
# to smooth the first and the last frame
frame_list_cycle = frame_list + frame_list[::-1]
coord_list_cycle = coord_list + coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
############################################## inference batch by batch ##############################################
video_num = len(whisper_chunks)
batch_size = args.batch_size
gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
res_frame_list = []
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype) # torch, B, 5*N,384
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
recon = vae.decode_latents(pred_latents)
for res_frame in recon:
res_frame_list.append(res_frame)
############################################## pad to full image ##############################################
print("pad talking image to original video")
for i, res_frame in enumerate(tqdm(res_frame_list)):
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
x1, y1, x2, y2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
except:
continue
combine_frame = get_image(ori_frame,res_frame,bbox)
cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
os.system(cmd_img2video)
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
os.system(cmd_combine_audio)
os.remove("temp.mp4")
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
os.system(cmd_img2video)
# cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
# print(cmd_combine_audio)
# os.system(cmd_combine_audio)
# shutil.rmtree(result_img_save_path)
print(f"result is save to {output_vid_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
parser.add_argument("--bbox_shift", type=int, default=0)
parser.add_argument("--result_dir", default='./results', help="path to output")
parser.add_argument("--fps", type=int, default=25)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--output_vid_name", type=str, default=None)
parser.add_argument("--use_saved_coord",
action="store_true",
help='use saved coordinate to save time')
parser.add_argument("--use_float16",
action="store_true",
help="Whether use float16 to speed up inference",
)
parser.add_argument("--unet_checkpoint", type=str, default=None)
args = parser.parse_args()
main(args)


@@ -158,4 +158,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
main(args)
main(args)


@@ -48,22 +48,22 @@ def get_image_list(data_root, split):
class Dataset(object):
def __init__(self,
data_root,
split,
json_path,
use_audio_length_left=1,
use_audio_length_right=1,
whisper_model_type = "tiny"
):
self.all_videos, self.all_imgNum = get_image_list(data_root, split)
# self.all_videos, self.all_imgNum = get_image_list(data_root, split)
self.audio_feature = [use_audio_length_left,use_audio_length_right]
self.all_img_names = []
self.split = split
self.img_names_path = '...'
# self.split = split
self.img_names_path = '../data'
self.whisper_model_type = whisper_model_type
self.use_audio_length_left = use_audio_length_left
self.use_audio_length_right = use_audio_length_right
if self.whisper_model_type =="tiny":
self.whisper_path = '...'
self.whisper_path = '../data/audios'
self.whisper_feature_W = 5
self.whisper_feature_H = 384
elif self.whisper_model_type =="largeV2":
@@ -71,6 +71,8 @@ class Dataset(object):
self.whisper_feature_W = 33
self.whisper_feature_H = 1280
self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) # e.g. 5*2*(2+2+1) = 50 with left=right=2
with open(json_path, 'r') as file:
self.all_videos = json.load(file)
for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
@@ -79,7 +81,6 @@ class Dataset(object):
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
with open(json_path_names, "w") as f:
json.dump(img_names,f)
print(f"save to {json_path_names}")
else:
with open(json_path_names, "r") as f:
img_names = json.load(f)
@@ -135,7 +136,6 @@ class Dataset(object):
vidname = self.all_videos[idx].split('/')[-1]
video_imgs = self.all_img_names[idx]
if len(video_imgs) == 0:
# print("video_imgs = 0:",vidname)
continue
img_name = random.choice(video_imgs)
img_idx = int(basename(img_name).split(".")[0])
@@ -193,7 +193,6 @@ class Dataset(object):
for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
# check whether the index is out of range
audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
if not os.path.exists(audio_feat_path):
is_index_out_of_range = True
break
@@ -214,8 +213,6 @@ class Dataset(object):
print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
continue
audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
return ref_image, image, masked_image, mask, audio_feature
@@ -231,10 +228,8 @@ if __name__ == "__main__":
val_data_loader = data_utils.DataLoader(
val_data, batch_size=4, shuffle=True,
num_workers=1)
print("val_dataset:",val_data_loader.__len__())
for i, data in enumerate(val_data_loader):
ref_image, image, masked_image, mask, audio_feature = data
print("ref_image: ", ref_image.shape)


@@ -1,15 +1,17 @@
# Draft training codes
# Data preprocessing
We provide the draft training codes here. Unfortunately, the data preprocessing code is still being reorganized.
Create two config YAML files, one for training and the other for testing (both in the same format as configs/inference/test.yaml; see the sketch below).
The train YAML file should list the training video paths and the corresponding audio paths.
The test YAML file should list the validation video paths and the corresponding audio paths.
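Each task entry uses the keys that `data_new.sh` writes into its per-chunk `config.yaml`; the paths below are placeholders:
```
task_0:
  video_path: "path/to/train_video1.mp4"
  audio_path: "path/to/train_video1.mp3"
task_1:
  video_path: "path/to/train_video2.mp4"
  audio_path: "path/to/train_video2.mp3"
```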
## Setup
Run:
```
./data_new.sh train output train_video1.mp4 train_video2.mp4
./data_new.sh test output test_video1.mp4 test_video2.mp4
```
This creates folders under `data/images` and `data/audios` containing the cropped image frames and the audio-feature `.npy` files. It also writes `train.json` or `test.json` (depending on the split), which are passed to training via `--train_json` / `--val_json` (see the example below).
We trained our model on an NVIDIA A100 with `batch_size=8, gradient_accumulation_steps=4` for 200k+ steps. Using multiple GPUs should accelerate training.
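For reference, the JSON written by `data_new.sh` is simply a list of per-video image folders (the video names here are placeholders):
```
["../data/images/train_video1", "../data/images/train_video2"]
```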
## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
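For orientation, here is a condensed, unofficial sketch of what `scripts/data.py` in this PR does for each extracted section (simplified: no batching and no index resumption; `preprocess_section` and its arguments are illustrative only):
```
# Sketch only: crop detected faces to 256x256 and save whisper features
# as pairs of consecutive chunks, mirroring scripts/data.py in this PR.
import cv2
import numpy as np
from musetalk.utils.preprocessing import get_landmark_and_bbox, coord_placeholder

def preprocess_section(img_paths, whisper_feature, img_dir, audio_dir, bbox_shift=0):
    coord_list, frame_list = get_landmark_and_bbox(img_paths, bbox_shift)
    idx = 0
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue  # no face detected in this frame
        x1, y1, x2, y2 = bbox
        crop = cv2.resize(frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4)
        cv2.imwrite(f"{img_dir}/{idx}.png", crop)
        idx += 1
    # audio features: two consecutive whisper chunks are concatenated per .npy file
    for j in range(0, len(whisper_feature) - 1, 2):
        pair = np.concatenate([whisper_feature[j], whisper_feature[j + 1]], axis=0)
        np.save(f"{audio_dir}/{j // 2}.npy", pair)
```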
Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
@@ -35,9 +37,16 @@ Finally, the data should be organized as follows:
## Training
After preparing the data, simply run:
```
sh train.sh
cd train_codes
sh train.sh  # --train_json="../train.json" (generated in the data preprocessing step)
             # --val_json="../val.json"
```
## Inference with a trained checkpoint
After training, run the command below; model checkpoints are usually saved under `train_codes/output`:
```
python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```
## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)


@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
import sys
sys.path.append("./")
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -137,6 +140,8 @@ def parse_args():
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
parser.add_argument(
"--hub_model_id",
type=str,
@@ -234,13 +239,17 @@ def parse_args():
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def print_model_dtypes(model, model_name):
for name, param in model.named_parameters():
if(param.dtype!=torch.float32):
print(f"{name}: {param.dtype}")
def main():
args = parse_args()
print(args)
args.output_dir = f"output/{args.output_dir}"
args.val_out_dir = f"val/{args.val_out_dir}"
os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +341,7 @@ def main():
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters())
itertools.chain(unet.parameters()))
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -343,23 +352,21 @@ def main():
print("loading train_dataset ...")
train_dataset = Dataset(args.data_root,
'train',
args.train_json,
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("train_dataset:",train_dataset.__len__())
train_data_loader = data_utils.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True,
num_workers=8)
print("loading val_dataset ...")
val_dataset = Dataset(args.data_root,
'val',
args.val_json,
use_audio_length_left=args.use_audio_length_left,
use_audio_length_right=args.use_audio_length_right,
whisper_model_type=args.whisper_model_type
)
print("val_dataset:",val_dataset.__len__())
val_data_loader = data_utils.DataLoader(
val_dataset, batch_size=1, shuffle=False,
num_workers=8)
@@ -388,6 +395,7 @@ def main():
vae_fp32.requires_grad_(False)
weight_dtype = torch.float32
# weight_dtype = torch.float16
vae_fp32.to(accelerator.device, dtype=weight_dtype)
vae_fp32.encoder = None
if accelerator.mixed_precision == "fp16":
@@ -412,6 +420,8 @@ def main():
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print(f" Num batches each epoch = {len(train_data_loader)}")
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +443,9 @@ def main():
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
# path="../models/pytorch_model.bin"
#TODO change path
# path=None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +471,11 @@ def main():
# caluate the elapsed time
elapsed_time = []
start = time.time()
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
# for step, batch in enumerate(train_dataloader):
for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +484,23 @@ def main():
continue
dataloader_time = time.time() - start
start = time.time()
masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
"""
print("=============epoch:{0}=step:{1}=====".format(epoch,step))
print("ref_image: ",ref_image.shape)
print("masks: ", masks.shape)
print("masked_image: ", masked_image.shape)
print("audio feature: ", audio_feature.shape)
print("image: ", image.shape)
"""
# """
# print("=============epoch:{0}=step:{1}=====".format(epoch,step))
# print("ref_image: ",ref_image.shape)
# print("masks: ", masks.shape)
# print("masked_image: ", masked_image.shape)
# print("audio feature: ", audio_feature.shape)
# print("image: ", image.shape)
# """
ref_image = preprocess_img_tensor(ref_image).to(vae.device)
image = preprocess_img_tensor(image).to(vae.device)
masked_image = preprocess_img_tensor(masked_image).to(vae.device)
img_process_time = time.time() - start
start = time.time()
with accelerator.accumulate(unet):
vae = vae.half()
# Convert images to latent space
latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
latents = latents * vae.config.scaling_factor
@@ -592,12 +605,23 @@ def main():
f"Running validation... epoch={epoch}, global_step={global_step}"
)
print("===========start validation==========")
# Use the helper function to check the data types for each model
vae_new = vae.float()
print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
print(f"weight_dtype: {weight_dtype}")
print(f"epoch type: {type(epoch)}")
print(f"global_step type: {type(global_step)}")
validation(
vae=accelerator.unwrap_model(vae),
# vae=accelerator.unwrap_model(vae),
vae=accelerator.unwrap_model(vae_new),
vae_fp32=accelerator.unwrap_model(vae_fp32),
unet=accelerator.unwrap_model(unet),
unet_config=unet_config,
weight_dtype=weight_dtype,
# weight_dtype=weight_dtype,
weight_dtype=torch.float32,
epoch=epoch,
global_step=global_step,
val_data_loader=val_data_loader,


@@ -1,27 +1,29 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export DATASET="..."
export UNET_CONFIG="./musetalk.json"
export VAE_MODEL="../models/sd-vae-ft-mse/"
export DATASET="../data"
export UNET_CONFIG="../models/musetalk/musetalk.json"
accelerate launch --multi_gpu train.py \
accelerate launch train.py \
--mixed_precision="fp16" \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
--data_root=$DATASET \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--train_batch_size=256 \
--gradient_accumulation_steps=16 \
--gradient_checkpointing \
--max_train_steps=200000 \
--max_train_steps=100000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="..." \
--val_out_dir='...' \
--output_dir="output" \
--val_out_dir='val' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \
--checkpointing_steps=2000 \
--validation_steps=2000 \
--reconstruction \
--resume_from_checkpoint="latest" \
--use_audio_length_left=2 \
--use_audio_length_right=2 \
--whisper_model_type="tiny" \
--train_json="../train.json" \
--val_json="../val.json" \
--lr_scheduler="cosine" \


@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union