diff --git a/scripts/inference.py b/scripts/inference.py
index fe2a234..dddce31 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -15,14 +15,26 @@ from musetalk.utils.blending import get_image
 from musetalk.utils.utils import load_all_model
 import shutil
 
+from accelerate import Accelerator
+
 # load model weights
 audio_processor, vae, unet, pe = load_all_model()
+accelerator = Accelerator(
+    mixed_precision="fp16",
+)
+unet = accelerator.prepare(
+    unet,
+)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
 
 @torch.no_grad()
 def main(args):
     global pe
+    if args.unet_checkpoint is not None:
+        accelerator.load_state(args.unet_checkpoint)
+        print("unet ckpt loaded")
+
     if args.use_float16 is True:
         pe = pe.half()
         vae.vae = vae.vae.half()
@@ -63,8 +76,6 @@ def main(args):
             fps = args.fps
         else:
             raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
-
-    #print(input_img_list)
     ############################################## extract audio feature ##############################################
     whisper_feature = audio_processor.audio2feat(audio_path)
     whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
@@ -79,24 +90,27 @@ def main(args):
         coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
         with open(crop_coord_save_path, 'wb') as f:
             pickle.dump(coord_list, f)
+
    i = 0
    input_latent_list = []
+    crop_i = 0
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
+        cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png",crop_frame)
        latents = vae.get_latents_for_unet(crop_frame)
        input_latent_list.append(latents)
+        crop_i += 1
 
    # to smooth the first and the last frame
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    ############################################## inference batch by batch ##############################################
-    print("start inference")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
@@ -107,7 +121,6 @@ def main(args):
                                             dtype=unet.model.dtype) # torch, B, 5*N,384
        audio_feature_batch = pe(audio_feature_batch)
        latent_batch = latent_batch.to(dtype=unet.model.dtype)
-
        pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
        recon = vae.decode_latents(pred_latents)
        for res_frame in recon:
@@ -122,22 +135,29 @@
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
        except:
-#            print(bbox)
            continue
        combine_frame = get_image(ori_frame,res_frame,bbox)
+        cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png",res_frame)
+        cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png",ori_frame)
        cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
 
    cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
-    print(cmd_img2video)
    os.system(cmd_img2video)
 
    cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
-    print(cmd_combine_audio)
    os.system(cmd_combine_audio)
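Note on the inference changes above: `--unet_checkpoint` is meant to point at a checkpoint *folder* written by Accelerate's `save_state`, not a bare weights file, since `accelerator.load_state` restores the whole prepared state. A minimal sketch of that load flow, with a hypothetical checkpoint path and a stand-in module in place of the MuseTalk UNet:

```python
# Minimal sketch of the Accelerate checkpoint flow used in inference.py.
# Assumptions: "train_codes/output/checkpoint-1000" is a folder previously
# written by accelerator.save_state(); nn.Linear stands in for the UNet.
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
model = accelerator.prepare(torch.nn.Linear(8, 8))  # wrap the model, as inference.py now does
accelerator.load_state("train_codes/output/checkpoint-1000")  # restore the trained weights
```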
    os.remove("temp.mp4")
-    shutil.rmtree(result_img_save_path)
+
+    cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
+    os.system(cmd_img2video)
+
+    # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
+    # print(cmd_combine_audio)
+    # os.system(cmd_combine_audio)
+
+    # shutil.rmtree(result_img_save_path)
    print(f"result is save to {output_vid_name}")
 
 if __name__ == "__main__":
@@ -156,6 +176,7 @@ if __name__ == "__main__":
                         action="store_true",
                         help="Whether use float16 to speed up inference",
     )
+    parser.add_argument("--unet_checkpoint", type=str, default=None, help="path to a trained unet checkpoint folder (accelerate save_state output)")
     args = parser.parse_args()
-    main(args)
+    main(args)
\ No newline at end of file
diff --git a/train_codes/DataLoader.py b/train_codes/DataLoader.py
index 8dacbde..f0652e3 100644
--- a/train_codes/DataLoader.py
+++ b/train_codes/DataLoader.py
@@ -57,13 +57,13 @@ class Dataset(object):
         self.audio_feature = [use_audio_length_left,use_audio_length_right]
         self.all_img_names = []
         self.split = split
-        self.img_names_path = '...'
+        self.img_names_path = '../data'
         self.whisper_model_type = whisper_model_type
         self.use_audio_length_left = use_audio_length_left
         self.use_audio_length_right = use_audio_length_right
 
         if self.whisper_model_type =="tiny":
-            self.whisper_path = '...'
+            self.whisper_path = '../data/audios'
             self.whisper_feature_W = 5
             self.whisper_feature_H = 384
         elif self.whisper_model_type =="largeV2":
@@ -72,6 +72,10 @@ class Dataset(object):
             self.whisper_feature_H = 1280
         self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1) #5*2*(2+2+1)= 50
 
+        if self.split == "train":
+            self.all_videos = ["../data/images/train"]
+        if self.split == "val":
+            self.all_videos = ["../data/images/test"]
         for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
             json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
             if not os.path.exists(json_path_names):
@@ -79,7 +83,6 @@ class Dataset(object):
                 img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
                 with open(json_path_names, "w") as f:
                     json.dump(img_names,f)
-                print(f"save to {json_path_names}")
             else:
                 with open(json_path_names, "r") as f:
                     img_names = json.load(f)
@@ -147,7 +150,6 @@ class Dataset(object):
            vidname = self.all_videos[idx].split('/')[-1]
            video_imgs = self.all_img_names[idx]
            if len(video_imgs) == 0:
-#                print("video_imgs = 0:",vidname)
                continue
            img_name = random.choice(video_imgs)
            img_idx = int(basename(img_name).split(".")[0])
@@ -205,7 +207,6 @@ class Dataset(object):
            for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
                # check whether the feature index is out of range
                audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
-
                if not os.path.exists(audio_feat_path):
                    is_index_out_of_range = True
                    break
@@ -226,8 +227,6 @@ class Dataset(object):
                print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
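For reference, the audio window the dataset builds concatenates `use_audio_length_left + use_audio_length_right + 1` whisper frames; a quick check of the width arithmetic from the `#5*2*(2+2+1)= 50` comment, assuming the "tiny" model values:

```python
# Width of the concatenated whisper feature window for the "tiny" model,
# mirroring the comment in DataLoader.py: 5*2*(2+2+1) = 50.
use_audio_length_left, use_audio_length_right = 2, 2
whisper_feature_W, whisper_feature_H = 5, 384
concate_W = whisper_feature_W * 2 * (use_audio_length_left + use_audio_length_right + 1)
assert (concate_W, whisper_feature_H) == (50, 384)
```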
                continue
            audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
-
-
            return ref_image, image, masked_image, mask, audio_feature
 
@@ -243,10 +242,8 @@ if __name__ == "__main__":
     val_data_loader = data_utils.DataLoader(
         val_data, batch_size=4, shuffle=True,
         num_workers=1)
-    print("val_dataset:",val_data_loader.__len__())
 
     for i, data in enumerate(val_data_loader):
         ref_image, image, masked_image, mask, audio_feature = data
-        print("ref_image: ", ref_image.shape)
\ No newline at end of file
diff --git a/train_codes/README.md b/train_codes/README.md
index db9848c..ba0f370 100644
--- a/train_codes/README.md
+++ b/train_codes/README.md
@@ -1,32 +1,35 @@
-# Draft training codes
+# Data preprocessing
 
-We provde the draft training codes here. Unfortunately, data preprocessing code is still being reorganized.
+Create two config yaml files, one for training and the other for testing (both in the same format as configs/inference/test.yaml).
+The train yaml file should contain the training video paths and the corresponding audio paths.
+The test yaml file should contain the validation video paths and the corresponding audio paths.
 
-## Setup
+Run:
+```
+python -m scripts.data --inference_config path_to_train.yaml --folder_name train
+python -m scripts.data --inference_config path_to_test.yaml --folder_name test
+```
+This creates folders which contain the extracted image frames and the .npy audio-feature files.
 
-We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 20w+ steps. Using multiple GPUs should accelerate the training.
 
-## Data preprocessing
-
 You could refer the inference codes which [crop the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extract audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
 
-Finally, the data should be organized as follows:
+## Data organization
 ```
 ./data/
 ├── images
-│     └──RD_Radio10_000
+│     └──train
 │         └── 0.png
 │         └── 1.png
 │         └── xxx.png
-│     └──RD_Radio11_000
+│     └──test
 │         └── 0.png
 │         └── 1.png
 │         └── xxx.png
 ├── audios
-│     └──RD_Radio10_000
+│     └──train
 │         └── 0.npy
 │         └── 1.npy
 │         └── xxx.npy
-│     └──RD_Radio11_000
+│     └──test
 │         └── 0.npy
 │         └── 1.npy
 │         └── xxx.npy
 ```
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
 ```
 sh train.sh
 ```
+## Inference with a trained checkpoint
+After training, the model checkpoints are saved under train_codes/output by default. Run:
+```
+python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
+```
 
 ## TODO
-- [ ] release data preprocessing codes
+- [x] release data preprocessing codes
 - [ ] release some novel designs in training (after technical report)
\ No newline at end of file
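A small sanity check can help after preprocessing; this is a sketch assuming the folder names from the tree above and that it is run from the repo root (adjust `root` if your data lives elsewhere):

```python
# Count files in each folder of the ./data tree described in the README.
# Assumption: run from the repo root; "data" matches the layout above.
import os

root = "data"
for sub in ("images/train", "images/test", "audios/train", "audios/test"):
    path = os.path.join(root, sub)
    n = len(os.listdir(path)) if os.path.isdir(path) else 0
    print(f"{path}: {n} files")
```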
diff --git a/train_codes/train.py b/train_codes/train.py
index cb50c1c..37a9447 100755
--- a/train_codes/train.py
+++ b/train_codes/train.py
@@ -27,10 +27,13 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 
+import sys
+sys.path.append("./")
+
 from DataLoader import Dataset
 from utils.utils import preprocess_img_tensor
 from torch.utils import data as data_utils
-from model_utils import validation,PositionalEncoding
+from utils.model_utils import validation,PositionalEncoding
 import time
 import pandas as pd
 from PIL import Image
@@ -234,13 +237,17 @@ def parse_args():
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
+
+
     return args
 
-
+def print_model_dtypes(model, model_name):
+    for name, param in model.named_parameters():
+        if param.dtype != torch.float32:
+            print(f"{model_name}.{name}: {param.dtype}")
 
 def main():
     args = parse_args()
-    print(args)
     args.output_dir = f"output/{args.output_dir}"
     args.val_out_dir = f"val/{args.val_out_dir}"
     os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +339,7 @@ def main():
     optimizer_class = torch.optim.AdamW
 
     params_to_optimize = (
-        itertools.chain(unet.parameters())
+        itertools.chain(unet.parameters()))
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -348,7 +355,6 @@ def main():
         use_audio_length_right=args.use_audio_length_right,
         whisper_model_type=args.whisper_model_type
     )
-    print("train_dataset:",train_dataset.__len__())
     train_data_loader = data_utils.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True,
         num_workers=8)
@@ -359,7 +365,6 @@ def main():
         use_audio_length_right=args.use_audio_length_right,
         whisper_model_type=args.whisper_model_type
     )
-    print("val_dataset:",val_dataset.__len__())
     val_data_loader = data_utils.DataLoader(
         val_dataset, batch_size=1, shuffle=False,
         num_workers=8)
@@ -388,6 +393,7 @@ def main():
     vae_fp32.requires_grad_(False)
 
     weight_dtype = torch.float32
+    # weight_dtype = torch.float16
     vae_fp32.to(accelerator.device, dtype=weight_dtype)
     vae_fp32.encoder = None
     if accelerator.mixed_precision == "fp16":
@@ -412,6 +418,8 @@ def main():
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    print(f"  Num batches each epoch = {len(train_data_loader)}")
+
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len(train_dataset)}")
     logger.info(f"  Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +441,9 @@ def main():
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1] if len(dirs) > 0 else None
 
+        # path = "../models/pytorch_model.bin"
+        # TODO change path
+        # path = None
         if path is None:
             accelerator.print(
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +469,11 @@ def main():
     # caluate the elapsed time
     elapsed_time = []
     start = time.time()
+
+
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
 
-#         for step, batch in enumerate(train_dataloader):
         for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +482,23 @@
                 continue
             dataloader_time = time.time() - start
             start = time.time()
-
             masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
-            """
-            print("=============epoch:{0}=step:{1}=====".format(epoch,step))
-            print("ref_image: ",ref_image.shape)
-            print("masks: ", masks.shape)
-            print("masked_image: ", masked_image.shape)
-            print("audio feature: ", audio_feature.shape)
-            print("image: ", image.shape)
-            """
+            # """
+            # print("=============epoch:{0}=step:{1}=====".format(epoch,step))
+            # print("ref_image: ",ref_image.shape)
+            # print("masks: ", masks.shape)
+            # print("masked_image: ", masked_image.shape)
+            # print("audio feature: ", audio_feature.shape)
+            # print("image: ", image.shape)
+            # """
             ref_image = preprocess_img_tensor(ref_image).to(vae.device)
             image = preprocess_img_tensor(image).to(vae.device)
             masked_image = preprocess_img_tensor(masked_image).to(vae.device)
             img_process_time = time.time() - start
             start = time.time()
-
             with accelerator.accumulate(unet):
+                vae = vae.half()
                 # Convert images to latent space
                 latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
                 latents = latents * vae.config.scaling_factor
@@ -592,12 +603,23 @@
                     f"Running validation... epoch={epoch}, global_step={global_step}"
                 )
                 print("===========start validation==========")
+                # Use the helper function to check the data types of each model
+                vae_new = vae.float()
+                print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
+                print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
+                print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
+
+                print(f"weight_dtype: {weight_dtype}")
+                print(f"epoch type: {type(epoch)}")
+                print(f"global_step type: {type(global_step)}")
                 validation(
-                    vae=accelerator.unwrap_model(vae),
+                    # vae=accelerator.unwrap_model(vae),
+                    vae=accelerator.unwrap_model(vae_new),
                     vae_fp32=accelerator.unwrap_model(vae_fp32),
                     unet=accelerator.unwrap_model(unet),
                     unet_config=unet_config,
-                    weight_dtype=weight_dtype,
+                    # weight_dtype=weight_dtype,
+                    weight_dtype=torch.float32,
                     epoch=epoch,
                     global_step=global_step,
                     val_data_loader=val_data_loader,
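The `print_model_dtypes` helper added above is a mixed-precision audit: it lists every parameter that is not fp32, which is how this patch verifies the VAE/UNet state before validation. A self-contained sketch of the same helper (the demo module is a stand-in; `model_name` is included in the output, as in the patched version):

```python
# Audit helper from train.py: flag any parameter that is not fp32.
import torch

def print_model_dtypes(model, model_name):
    for name, param in model.named_parameters():
        if param.dtype != torch.float32:
            print(f"{model_name}.{name}: {param.dtype}")

demo = torch.nn.Linear(4, 4).half()  # deliberately fp16 so the audit prints both params
print_model_dtypes(demo, "demo")     # -> demo.weight: torch.float16, demo.bias: torch.float16
```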
diff --git a/train_codes/train.sh b/train_codes/train.sh
index 908a676..8c60b48 100644
--- a/train_codes/train.sh
+++ b/train_codes/train.sh
@@ -1,8 +1,8 @@
-export VAE_MODEL="./sd-vae-ft-mse/"
-export DATASET="..."
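With these flags, the effective batch size follows the formula train.py logs; a quick check assuming a single process, since the `--multi_gpu` flag is dropped in this patch:

```python
# Effective batch size implied by train.sh, per train.py's logging formula.
# Assumption: one process, because accelerate launch runs without --multi_gpu.
train_batch_size = 8
num_processes = 1
gradient_accumulation_steps = 4
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
assert total_batch_size == 32
```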
-export UNET_CONFIG="./musetalk.json"
+export VAE_MODEL="../models/sd-vae-ft-mse/"
+export DATASET="../data"
+export UNET_CONFIG="../models/musetalk/musetalk.json"
 
-accelerate launch --multi_gpu train.py \
+accelerate launch train.py \
 --mixed_precision="fp16" \
 --unet_config_file=$UNET_CONFIG \
 --pretrained_model_name_or_path=$VAE_MODEL \
@@ -10,13 +10,13 @@ accelerate launch --multi_gpu train.py \
 --train_batch_size=8 \
 --gradient_accumulation_steps=4 \
 --gradient_checkpointing \
---max_train_steps=200000 \
+--max_train_steps=50000 \
 --learning_rate=5e-05 \
 --max_grad_norm=1 \
 --lr_scheduler="cosine" \
 --lr_warmup_steps=0 \
---output_dir="..." \
---val_out_dir='...' \
+--output_dir="output" \
+--val_out_dir='val' \
 --testing_speed \
 --checkpointing_steps=1000 \
 --validation_steps=1000 \
diff --git a/train_codes/utils/model_utils.py b/train_codes/utils/model_utils.py
index e0fd2e2..0534521 100644
--- a/train_codes/utils/model_utils.py
+++ b/train_codes/utils/model_utils.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import time
 import math
-from utils import decode_latents, preprocess_img_tensor
+from utils.utils import decode_latents, preprocess_img_tensor
 import os
 from PIL import Image
 from typing import Any, Dict, List, Optional, Tuple, Union
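Taken together, the `sys.path.append("./")` in train.py and the `utils.utils` / `utils.model_utils` import fixes assume the training scripts are launched from the `train_codes/` directory. A quick import smoke test under that assumption:

```python
# Run from the train_codes/ directory to verify the fixed import paths resolve.
import sys
sys.path.append("./")

from utils.model_utils import validation, PositionalEncoding  # noqa: F401
from utils.utils import decode_latents, preprocess_img_tensor  # noqa: F401
from DataLoader import Dataset  # noqa: F401
print("imports OK")
```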