mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 17:39:20 +08:00
modified dataloader.py and inference.py for training and inference
@@ -15,14 +15,27 @@ from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
+
+from accelerate import Accelerator

# load model weights
audio_processor, vae, unet, pe = load_all_model()
+accelerator = Accelerator(
+    mixed_precision="fp16",
+)
+unet = accelerator.prepare(
+    unet,
+)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)

@torch.no_grad()
def main(args):
    global pe
+    if args.unet_checkpoint is not None:
+        print("unet ckpt loaded")
+        accelerator.load_state(args.unet_checkpoint)
+
    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
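For context, `accelerator.load_state` restores a checkpoint directory previously written by `accelerator.save_state` on the training side; a minimal sketch of that pairing (paths hypothetical):

```
import torch
from accelerate import Accelerator

model = torch.nn.Linear(8, 8)              # stand-in for the MuseTalk UNet
accelerator = Accelerator(mixed_precision="fp16")
model = accelerator.prepare(model)

# training side: snapshot model/optimizer/RNG state into a folder
accelerator.save_state("output/checkpoint-1000")

# inference side: --unet_checkpoint points at that same folder
accelerator.load_state("output/checkpoint-1000")
```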
@@ -63,8 +76,6 @@ def main(args):
        fps = args.fps
    else:
        raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

    # print(input_img_list)
    ############################################## extract audio feature ##############################################
    whisper_feature = audio_processor.audio2feat(audio_path)
    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
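`feature2chunks` groups the Whisper encoder features into one chunk per video frame; a toy sketch of the idea (the real slicing lives in the library — the 50 features-per-second rate and window size here are assumptions):

```
import numpy as np

def toy_feature2chunks(feature_array, fps, feats_per_sec=50, window=10):
    # map each video frame to its centre index in the audio feature array
    chunks = []
    n_frames = int(len(feature_array) / feats_per_sec * fps)
    for i in range(n_frames):
        centre = int(i / fps * feats_per_sec)
        lo = max(0, centre - window // 2)
        chunks.append(feature_array[lo:lo + window])
    return chunks

chunks = toy_feature2chunks(np.zeros((500, 384)), fps=25)
print(len(chunks), chunks[0].shape)  # 250 chunks, the first of shape (10, 384)
```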
@@ -79,24 +90,27 @@ def main(args):
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
        with open(crop_coord_save_path, 'wb') as f:
            pickle.dump(coord_list, f)

    i = 0
    input_latent_list = []
+    crop_i = 0
    for bbox, frame in zip(coord_list, frame_list):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+        cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png", crop_frame)
        latents = vae.get_latents_for_unet(crop_frame)
        input_latent_list.append(latents)
+        crop_i += 1

    # to smooth the first and the last frame
    frame_list_cycle = frame_list + frame_list[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    ############################################## inference batch by batch ##############################################
    print("start inference")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
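The palindromic cycle exists so frames can be indexed modulo the cycle length: when the audio is longer than the source clip, playback sweeps forward and backward with no jump cut at either end. A minimal sketch of the indexing (assumed from the elided part of this hunk):

```
frames = ["f0", "f1", "f2", "f3"]
frame_list_cycle = frames + frames[::-1]    # f0 f1 f2 f3 f3 f2 f1 f0

# however many audio chunks arrive, indexing wraps smoothly
for i in range(12):
    print(i, frame_list_cycle[i % len(frame_list_cycle)])
```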
@@ -107,7 +121,6 @@ def main(args):
                                             dtype=unet.model.dtype)  # torch, B, 5*N, 384
        audio_feature_batch = pe(audio_feature_batch)
        latent_batch = latent_batch.to(dtype=unet.model.dtype)

        pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
        recon = vae.decode_latents(pred_latents)
        for res_frame in recon:
@@ -122,22 +135,29 @@ def main(args):
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                # print(bbox)
                continue

            combine_frame = get_image(ori_frame, res_frame, bbox)
            cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png", res_frame)
            cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png", ori_frame)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

    cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
    print(cmd_img2video)
    os.system(cmd_img2video)

    cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
    print(cmd_combine_audio)
    os.system(cmd_combine_audio)

    os.remove("temp.mp4")

    # encode the original frames as well; this has to run before the frame folder is removed
    cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
    os.system(cmd_img2video)
    shutil.rmtree(result_img_save_path)

    # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
    # print(cmd_combine_audio)
    # os.system(cmd_combine_audio)

    # shutil.rmtree(result_img_save_path)
    print(f"result is saved to {output_vid_name}")

if __name__ == "__main__":
@@ -156,6 +176,7 @@ if __name__ == "__main__":
        action="store_true",
        help="Whether to use float16 to speed up inference",
    )
+    parser.add_argument("--unet_checkpoint", type=str, default=None)

    args = parser.parse_args()
    main(args)
@@ -57,13 +57,13 @@ class Dataset(object):
        self.audio_feature = [use_audio_length_left, use_audio_length_right]
        self.all_img_names = []
        self.split = split
-        self.img_names_path = '...'
+        self.img_names_path = '../data'
        self.whisper_model_type = whisper_model_type
        self.use_audio_length_left = use_audio_length_left
        self.use_audio_length_right = use_audio_length_right

        if self.whisper_model_type == "tiny":
-            self.whisper_path = '...'
+            self.whisper_path = '../data/audios'
            self.whisper_feature_W = 5
            self.whisper_feature_H = 384
        elif self.whisper_model_type == "largeV2":
@@ -72,6 +72,10 @@ class Dataset(object):
            self.whisper_feature_H = 1280
        self.whisper_feature_concateW = self.whisper_feature_W * 2 * (self.use_audio_length_left + self.use_audio_length_right + 1)  # 5*2*(2+2+1) = 50

+        if self.split == "train":
+            self.all_videos = ["../data/images/train"]
+        if self.split == "val":
+            self.all_videos = ["../data/images/test"]
        for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
            json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
            if not os.path.exists(json_path_names):
@@ -79,7 +83,6 @@ class Dataset(object):
|
||||
img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
|
||||
with open(json_path_names, "w") as f:
|
||||
json.dump(img_names,f)
|
||||
print(f"save to {json_path_names}")
|
||||
else:
|
||||
with open(json_path_names, "r") as f:
|
||||
img_names = json.load(f)
|
||||
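The pattern above is a simple build-once cache: the sorted frame list is globbed on the first run and read back from JSON afterwards. A self-contained sketch of the same idea (paths hypothetical):

```
import glob
import json
import os

def load_img_names(video_dir, cache_path):
    # reuse the cached ordering when it exists; otherwise build and save it
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            return json.load(f)
    names = glob.glob(f"{video_dir}/*.png")
    names.sort(key=lambda x: int(x.split("/")[-1].split(".")[0]))
    with open(cache_path, "w") as f:
        json.dump(names, f)
    return names
```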
@@ -147,7 +150,6 @@ class Dataset(object):
        vidname = self.all_videos[idx].split('/')[-1]
        video_imgs = self.all_img_names[idx]
        if len(video_imgs) == 0:
            # print("video_imgs = 0:", vidname)
            continue
        img_name = random.choice(video_imgs)
        img_idx = int(basename(img_name).split(".")[0])
@@ -205,7 +207,6 @@ class Dataset(object):
        for feat_idx in range(window_index - self.use_audio_length_left, window_index + self.use_audio_length_right + 1):
            # check whether the index goes out of range
            audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")

            if not os.path.exists(audio_feat_path):
                is_index_out_of_range = True
                break
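The loop walks a context window of `use_audio_length_left + use_audio_length_right + 1` chunks around the current frame and bails out when any chunk file is missing; a hedged sketch of how the window might then be assembled (the per-chunk shape `(10, 384)` is an assumption consistent with `whisper_feature_concateW = 5*2*(2+2+1) = 50`):

```
import os
import numpy as np

def load_audio_window(folder, window_index, left=2, right=2):
    feats = []
    for feat_idx in range(window_index - left, window_index + right + 1):
        path = os.path.join(folder, f"{feat_idx}.npy")
        if not os.path.exists(path):
            return None                    # index out of range: skip this sample
        feats.append(np.load(path))        # assumed shape (10, 384) per chunk
    return np.concatenate(feats, axis=0)   # (50, 384) for left=right=2
```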
@@ -226,8 +227,6 @@ class Dataset(object):
                print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
                continue
            audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))

            return ref_image, image, masked_image, mask, audio_feature

@@ -243,10 +242,8 @@ if __name__ == "__main__":
    val_data_loader = data_utils.DataLoader(
        val_data, batch_size=4, shuffle=True,
        num_workers=1)
    print("val_dataset:", len(val_data_loader))

    for i, data in enumerate(val_data_loader):
        ref_image, image, masked_image, mask, audio_feature = data
        print("ref_image: ", ref_image.shape)

@@ -1,32 +1,35 @@
-# Draft training codes
+# Data preprocessing

-We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.
+Create two config yaml files, one for training and the other for testing (both in the same format as configs/inference/test.yaml).
+The train yaml file should contain the training video paths and the corresponding audio paths.
+The test yaml file should contain the validation video paths and the corresponding audio paths.

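A hedged sketch of building such a file, assuming it mirrors the per-task entries of configs/inference/test.yaml (file names hypothetical), written out with PyYAML:

```
import yaml

# hypothetical training config, mirroring configs/inference/test.yaml entries
config = {
    "task_0": {"video_path": "data/video/person1.mp4",
               "audio_path": "data/audio/person1.wav"},
    "task_1": {"video_path": "data/video/person2.mp4",
               "audio_path": "data/audio/person2.wav"},
}
with open("path_to_train.yaml", "w") as f:
    yaml.safe_dump(config, f)
```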
## Setup
Run:
```
python -m scripts.data --inference_config path_to_train.yaml --folder_name train
python -m scripts.data --inference_config path_to_test.yaml --folder_name test
```
This creates folders containing the image frames and the npy audio-feature files.

We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` (an effective batch size of 32) for 200k+ steps. Using multiple GPUs should accelerate training.

## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).

Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
│   └── RD_Radio10_000
│       └── train
│           └── 0.png
│           └── 1.png
│           └── xxx.png
│   └── RD_Radio11_000
│       └── test
│           └── 0.png
│           └── 1.png
│           └── xxx.png
├── audios
│   └── RD_Radio10_000
│       └── train
│           └── 0.npy
│           └── 1.npy
│           └── xxx.npy
│   └── RD_Radio11_000
│       └── test
│           └── 0.npy
│           └── 1.npy
│           └── xxx.npy
```
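Frame and audio-feature indices must line up one-to-one across the two trees; a small hedged check under that assumption (paths hypothetical):

```
import glob
import os

def check_pairs(root, vid, split):
    # every frame i.png should have a matching audio feature i.npy
    imgs = glob.glob(f"{root}/images/{vid}/{split}/*.png")
    for img in imgs:
        idx = os.path.basename(img).split(".")[0]
        npy = f"{root}/audios/{vid}/{split}/{idx}.npy"
        assert os.path.exists(npy), f"missing audio feature for frame {idx}"

check_pairs("./data", "RD_Radio10_000", "train")
```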
@@ -37,7 +40,12 @@ Simply run after preparing the preprocessed data
```
sh train.sh
```
+## Inference with trained checkpoint
+Simply run after training the model; the model checkpoints are usually saved at train_codes/output.
+```
+python -m scripts.inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
+```

## TODO
-- [ ] release data preprocessing codes
+- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)
@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version

+import sys
+sys.path.append("./")
+
from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
-from model_utils import validation, PositionalEncoding
+from utils.model_utils import validation, PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -234,13 +237,17 @@ def parse_args():
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


+def print_model_dtypes(model, model_name):
+    for name, param in model.named_parameters():
+        if param.dtype != torch.float32:
+            print(f"{name}: {param.dtype}")


def main():
    args = parse_args()
    print(args)
    args.output_dir = f"output/{args.output_dir}"
    args.val_out_dir = f"val/{args.val_out_dir}"
    os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +339,7 @@ def main():
    optimizer_class = torch.optim.AdamW

    params_to_optimize = (
-        itertools.chain(unet.parameters())
+        itertools.chain(unet.parameters()))
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
@@ -348,7 +355,6 @@ def main():
        use_audio_length_right=args.use_audio_length_right,
        whisper_model_type=args.whisper_model_type
    )
    print("train_dataset:", len(train_dataset))
    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True,
        num_workers=8)
@@ -359,7 +365,6 @@ def main():
        use_audio_length_right=args.use_audio_length_right,
        whisper_model_type=args.whisper_model_type
    )
    print("val_dataset:", len(val_dataset))
    val_data_loader = data_utils.DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        num_workers=8)
@@ -388,6 +393,7 @@ def main():
    vae_fp32.requires_grad_(False)

    weight_dtype = torch.float32
+    # weight_dtype = torch.float16
    vae_fp32.to(accelerator.device, dtype=weight_dtype)
    vae_fp32.encoder = None
    if accelerator.mixed_precision == "fp16":
@@ -412,6 +418,8 @@ def main():
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

+    print(f" Num batches each epoch = {len(train_data_loader)}")
+
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +441,9 @@ def main():
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

+        # path = "../models/pytorch_model.bin"
+        # TODO: change path
+        # path = None
        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +469,11 @@ def main():
    # calculate the elapsed time
    elapsed_time = []
    start = time.time()

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        # for step, batch in enumerate(train_dataloader):
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +482,23 @@ def main():
                continue
            dataloader_time = time.time() - start
            start = time.time()

            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
-            """
-            print("=============epoch:{0}=step:{1}=====".format(epoch,step))
-            print("ref_image: ",ref_image.shape)
-            print("masks: ", masks.shape)
-            print("masked_image: ", masked_image.shape)
-            print("audio feature: ", audio_feature.shape)
-            print("image: ", image.shape)
-            """
+            # """
+            # print("=============epoch:{0}=step:{1}=====".format(epoch,step))
+            # print("ref_image: ",ref_image.shape)
+            # print("masks: ", masks.shape)
+            # print("masked_image: ", masked_image.shape)
+            # print("audio feature: ", audio_feature.shape)
+            # print("image: ", image.shape)
+            # """
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            img_process_time = time.time() - start
            start = time.time()

            with accelerator.accumulate(unet):
                vae = vae.half()
                # Convert images to latent space
                latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()  # init image
                latents = latents * vae.config.scaling_factor
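A note on the last two lines: Stable-Diffusion-style VAEs return a latent distribution, and samples are multiplied by `vae.config.scaling_factor` (0.18215 for sd-vae-ft-mse) so latents have roughly unit variance before entering the UNet. A hedged sketch of the round trip:

```
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
image = torch.randn(1, 3, 256, 256)   # stand-in for a preprocessed frame

# encode: pixel space -> scaled latent space
latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor

# decode: undo the scaling before reconstructing pixels
recon = vae.decode(latents / vae.config.scaling_factor).sample
```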
@@ -592,12 +603,23 @@ def main():
                        f"Running validation... epoch={epoch}, global_step={global_step}"
                    )
                    print("===========start validation==========")
                    # Use the helper function to check the data types for each model
                    vae_new = vae.float()
                    print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
                    print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
                    print_model_dtypes(accelerator.unwrap_model(unet), "UNET")

                    print(f"weight_dtype: {weight_dtype}")
                    print(f"epoch type: {type(epoch)}")
                    print(f"global_step type: {type(global_step)}")
                    validation(
-                        vae=accelerator.unwrap_model(vae),
+                        # vae=accelerator.unwrap_model(vae),
+                        vae=accelerator.unwrap_model(vae_new),
                        vae_fp32=accelerator.unwrap_model(vae_fp32),
                        unet=accelerator.unwrap_model(unet),
                        unet_config=unet_config,
-                        weight_dtype=weight_dtype,
+                        # weight_dtype=weight_dtype,
+                        weight_dtype=torch.float32,
                        epoch=epoch,
                        global_step=global_step,
                        val_data_loader=val_data_loader,

@@ -1,8 +1,8 @@
-export VAE_MODEL="./sd-vae-ft-mse/"
-export DATASET="..."
-export UNET_CONFIG="./musetalk.json"
+export VAE_MODEL="../models/sd-vae-ft-mse/"
+export DATASET="../data"
+export UNET_CONFIG="../models/musetalk/musetalk.json"

-accelerate launch --multi_gpu train.py \
+accelerate launch train.py \
    --mixed_precision="fp16" \
    --unet_config_file=$UNET_CONFIG \
    --pretrained_model_name_or_path=$VAE_MODEL \
@@ -10,13 +10,13 @@ accelerate launch --multi_gpu train.py \
    --train_batch_size=8 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
-    --max_train_steps=200000 \
+    --max_train_steps=50000 \
    --learning_rate=5e-05 \
    --max_grad_norm=1 \
    --lr_scheduler="cosine" \
    --lr_warmup_steps=0 \
-    --output_dir="..." \
-    --val_out_dir='...' \
+    --output_dir="output" \
+    --val_out_dir='val' \
    --testing_speed \
    --checkpointing_steps=1000 \
    --validation_steps=1000 \

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import time
import math
-from utils import decode_latents, preprocess_img_tensor
+from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union