Mirror of https://github.com/TMElyralab/MuseTalk.git (synced 2026-02-05 18:09:19 +08:00)
data_new.sh (new executable file, 77 lines)
@@ -0,0 +1,77 @@
#!/bin/bash

# Function to extract video and audio sections
extract_sections() {
    input_video=$1
    base_name=$(basename "$input_video" .mp4)
    output_dir=$2
    split=$3   # accepted for symmetry with the main script; not used inside this function
    duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
    IFS=: read -r hours minutes seconds <<< "$duration"
    total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
    chunk_size=180  # 3 minutes in seconds
    index=0

    mkdir -p "$output_dir"

    while [ $((index * chunk_size)) -lt $total_seconds ]; do
        start_time=$((index * chunk_size))
        section_video="${output_dir}/${base_name}_part${index}.mp4"
        section_audio="${output_dir}/${base_name}_part${index}.mp3"

        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"

        # Create and update the config.yaml file for this chunk
        echo "task_0:" > config.yaml
        echo "  video_path: \"$section_video\"" >> config.yaml
        echo "  audio_path: \"$section_audio\"" >> config.yaml

        # Run the preprocessing script with the current config.yaml
        python -m scripts.data --inference_config config.yaml --folder_name "$base_name"

        index=$((index + 1))
    done

    # Clean up the chunk folder
    rm -rf "$output_dir"
}

# Main script
if [ $# -lt 3 ]; then
    echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
    exit 1
fi

split=$1
output_dir=$2
shift 2
input_videos=("$@")

# Initialize JSON array
json_array="["

for input_video in "${input_videos[@]}"; do
    base_name=$(basename "$input_video" .mp4)

    # Extract sections and run the Python script for each section
    extract_sections "$input_video" "$output_dir" "$split"

    # Add entry to JSON array
    json_array+="\"../data/images/$base_name\","
done

# Remove trailing comma and close JSON array
json_array="${json_array%,}]"

# Write the JSON array to the correct file
if [ "$split" == "train" ]; then
    echo "$json_array" > train.json
elif [ "$split" == "test" ]; then
    echo "$json_array" > test.json
else
    echo "Invalid split: $split. Must be 'train' or 'test'."
    exit 1
fi

echo "Processing complete."
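The loop above rewrites config.yaml once per 3-minute chunk. A minimal sketch of how scripts/data.py (shown next) picks that file up with OmegaConf; the chunk paths below are illustrative placeholders, not files from the repo:

```python
# Sketch only: what scripts/data.py sees after data_new.sh writes config.yaml
# for one chunk. "output/clip_part0.*" are assumed example paths.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
# cfg corresponds to:
#   task_0:
#     video_path: "output/clip_part0.mp4"
#     audio_path: "output/clip_part0.mp3"
for task_id in cfg:
    print(task_id, cfg[task_id]["video_path"], cfg[task_id]["audio_path"])
```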
scripts/data.py (new file, 256 lines)
@@ -0,0 +1,256 @@
"""Preprocess the (video, audio) tasks listed in an inference-style config into
paired training data: 256x256 face crops under data/images/<folder_name>/ and
concatenated whisper feature chunks under data/audios/<folder_name>/."""
import argparse
import copy
import gc
import glob
import os
import pickle
import shutil
import uuid

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
# import dlib  # only needed by the unused mask_face() helper at the bottom

from musetalk.utils.utils import get_file_type, get_video_fps
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model

# load model weights
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)


def get_largest_integer_filename(folder_path):
    """Return the largest integer file stem in folder_path, or -1 if none."""
    # Check if the folder exists
    if not os.path.isdir(folder_path):
        return -1

    # Get the list of files in the folder
    files = os.listdir(folder_path)

    # Check if the folder is empty
    if not files:
        return -1

    # Extract the integer part of the filenames and find the largest
    largest_integer = -1
    for file in files:
        try:
            file_int = int(os.path.splitext(file)[0])
            if file_int > largest_integer:
                largest_integer = file_int
        except ValueError:
            # Skip files that don't have an integer filename
            continue

    return largest_integer


def datagen(whisper_chunks,
            crop_images,
            batch_size=8,
            delay_frame=0):
    """Yield (whisper_batch, crop_batch) pairs, cycling over the face crops."""
    whisper_batch, crop_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % len(crop_images)
        crop_image = crop_images[idx]
        whisper_batch.append(w)
        crop_batch.append(crop_image)

        if len(crop_batch) >= batch_size:
            whisper_batch = np.stack(whisper_batch)
            # latent_batch = torch.cat(latent_batch, dim=0)
            yield whisper_batch, crop_batch
            whisper_batch, crop_batch = [], []

    # the last batch may be smaller than batch_size
    if len(crop_batch) > 0:
        whisper_batch = np.stack(whisper_batch)
        # latent_batch = torch.cat(latent_batch, dim=0)
        yield whisper_batch, crop_batch


@torch.no_grad()
def main(args):
    global pe
    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    # Resume numbering from whatever is already in the output folders
    total_audio_index = get_largest_integer_filename(f"data/audios/{args.folder_name}")
    total_image_index = get_largest_integer_filename(f"data/images/{args.folder_name}")
    temp_audio_index = total_audio_index
    temp_image_index = total_image_index
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
        folder_name = args.folder_name
        if not os.path.exists(f"data/images/{folder_name}/"):
            os.makedirs(f"data/images/{folder_name}")
        if not os.path.exists(f"data/audios/{folder_name}/"):
            os.makedirs(f"data/audios/{folder_name}")
        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################## extract frames from source video ##############################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        ############################## extract audio features ##############################
        whisper_feature = audio_processor.audio2feat(audio_path)
        for __ in range(0, len(whisper_feature) - 1, 2):  # -1 to avoid an index error if the list has an odd number of elements
            # Combine two consecutive chunks and save the pair to a .npy file
            concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__ + 1]], axis=0)
            np.save(f'data/audios/{folder_name}/{total_audio_index + (__ // 2) + 1}.npy', concatenated_chunks)
            temp_audio_index = (__ // 2) + total_audio_index + 1
        total_audio_index = temp_audio_index
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

        ############################## preprocess input images ##############################
        gc.collect()
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        i = 0
        input_latent_list = []
        crop_i = 0
        crop_data = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox

            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = max(0, x2)
            y2 = max(0, y2)

            if ((y2 - y1) <= 0) or ((x2 - x1) <= 0):
                continue
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            latents = vae.get_latents_for_unet(crop_frame)
            crop_data.append(crop_frame)
            input_latent_list.append(latents)
            crop_i += 1

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        crop_data = crop_data + crop_data[::-1]
        ############################## write face crops batch by batch ##############################
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, crop_data, batch_size)
        for i, (whisper_batch, crop_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
            crop_index = 0
            for image, audio in zip(crop_batch, whisper_batch):
                cv2.imwrite(f"data/images/{folder_name}/{str(i + crop_index + total_image_index + 1)}.png", image)
                crop_index += 1
                temp_image_index = i + crop_index + total_image_index + 1
                # np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
            total_image_index = temp_image_index
        gc.collect()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="subfolder name under data/images and data/audios")

    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
                        action="store_true",
                        help='use saved coordinates to save time')
    parser.add_argument("--use_float16",
                        action="store_true",
                        help="whether to use float16 to speed up inference")

    args = parser.parse_args()
    main(args)


# The helpers below are left over from experimentation and are not called by main().
def process_audio(audio_path):
    whisper_feature = audio_processor.audio2feat(audio_path)
    np.save('audio/your_filename.npy', whisper_feature)


def mask_face(image):
    # Requires dlib (the import at the top is commented out).
    # Load dlib's face detector and the landmark predictor
    detector = dlib.get_frontal_face_detector()
    predictor_path = "/content/shape_predictor_68_face_landmarks.dat"  # Set the path to your downloaded predictor file
    predictor = dlib.shape_predictor(predictor_path)

    # Load your input image
    # image_path = "/content/ori_frame_00000077.png"  # Replace with the path to your input image
    # image = cv2.imread(image_path)
    if image is None:
        raise ValueError("Image not found or unable to load.")

    # Convert to grayscale for detection
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Detect faces in the image
    faces = detector(gray)

    # Process each detected face
    for face in faces:
        # Predict landmarks
        landmarks = predictor(gray, face)

        # The indices of the nose landmarks are 27 to 35
        nose_tip = landmarks.part(33).y

        # Blacken the region below the nose tip
        blacken_area = image[nose_tip:, :]
        blacken_area[:] = (0, 0, 0)

    # Save the final image or display it
    # cv2.imwrite("output_image.jpg", image)
    return image
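As a rough sanity check (not part of the repo), the crops and audio chunks written by scripts/data.py can be paired up by their integer file stems; "my_video" below is a hypothetical --folder_name value:

```python
# Sketch only: verify that data/images/<folder>/<n>.png and
# data/audios/<folder>/<n>.npy produced above line up and have sane shapes.
import glob
import os
import cv2
import numpy as np

folder = "my_video"  # hypothetical --folder_name
num_key = lambda p: int(os.path.splitext(os.path.basename(p))[0])
imgs = sorted(glob.glob(f"data/images/{folder}/*.png"), key=num_key)
feats = sorted(glob.glob(f"data/audios/{folder}/*.npy"), key=num_key)
print(len(imgs), "face crops,", len(feats), "audio feature chunks")

crop = cv2.imread(imgs[0])   # 256x256x3 face crop written by cv2.imwrite above
feat = np.load(feats[0])     # two consecutive whisper chunks concatenated along axis 0
print(crop.shape, feat.shape)
```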
scripts/finetuned_inference.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import argparse
import copy
import glob
import os
import pickle
import shutil

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from musetalk.utils.utils import get_file_type, get_video_fps, datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model

from accelerate import Accelerator

# load model weights
audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(
    mixed_precision="fp16",
)
unet = accelerator.prepare(
    unet,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)


@torch.no_grad()
def main(args):
    global pe
    if args.unet_checkpoint is not None:
        print("loading unet checkpoint")
        accelerator.load_state(args.unet_checkpoint)

    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################## extract frames from source video ##############################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        ############################## extract audio features ##############################
        whisper_feature = audio_processor.audio2feat(audio_path)
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
        ############################## preprocess input images ##############################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        i = 0
        input_latent_list = []
        crop_i = 0
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png", crop_frame)
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)
            crop_i += 1

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        ############################## inference batch by batch ##############################
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
        res_frame_list = []
        for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype)  # tensor of shape (B, 5*N, 384)
            audio_feature_batch = pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            recon = vae.decode_latents(pred_latents)
            for res_frame in recon:
                res_frame_list.append(res_frame)

        ############################## pad to full image ##############################
        print("pad talking image to original video")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            bbox = coord_list_cycle[i % (len(coord_list_cycle))]
            ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                continue

            combine_frame = get_image(ori_frame, res_frame, bbox)
            cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png", res_frame)
            cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png", ori_frame)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
        os.system(cmd_img2video)

        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
        os.system(cmd_combine_audio)

        os.remove("temp.mp4")

        # Re-render the original frames to temp.mp4 (the audio-mux step below is commented out)
        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
        os.system(cmd_img2video)

        # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
        # print(cmd_combine_audio)
        # os.system(cmd_combine_audio)

        # shutil.rmtree(result_img_save_path)
        print(f"result is saved to {output_vid_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")

    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
                        action="store_true",
                        help='use saved coordinates to save time')
    parser.add_argument("--use_float16",
                        action="store_true",
                        help="whether to use float16 to speed up inference")
    parser.add_argument("--unet_checkpoint", type=str, default=None)

    args = parser.parse_args()
    main(args)
@@ -48,22 +48,22 @@ def get_image_list(data_root, split):
 class Dataset(object):
     def __init__(self,
                  data_root,
-                 split,
+                 json_path,
                  use_audio_length_left=1,
                  use_audio_length_right=1,
                  whisper_model_type = "tiny"
                  ):
-        self.all_videos, self.all_imgNum = get_image_list(data_root, split)
+        # self.all_videos, self.all_imgNum = get_image_list(data_root, split)
         self.audio_feature = [use_audio_length_left,use_audio_length_right]
         self.all_img_names = []
-        self.split = split
-        self.img_names_path = '...'
+        # self.split = split
+        self.img_names_path = '../data'
         self.whisper_model_type = whisper_model_type
         self.use_audio_length_left = use_audio_length_left
         self.use_audio_length_right = use_audio_length_right

         if self.whisper_model_type == "tiny":
-            self.whisper_path = '...'
+            self.whisper_path = '../data/audios'
             self.whisper_feature_W = 5
             self.whisper_feature_H = 384
         elif self.whisper_model_type == "largeV2":

@@ -71,6 +71,8 @@ class Dataset(object):
             self.whisper_feature_W = 33
             self.whisper_feature_H = 1280
         self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1)  # 5*2*(2+2+1) = 50
+        with open(json_path, 'r') as file:
+            self.all_videos = json.load(file)

         for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
             json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"

@@ -79,7 +81,6 @@ class Dataset(object):
                 img_names.sort(key=lambda x:int(x.split("/")[-1].split('.')[0]))
                 with open(json_path_names, "w") as f:
                     json.dump(img_names,f)
-                print(f"save to {json_path_names}")
             else:
                 with open(json_path_names, "r") as f:
                     img_names = json.load(f)

@@ -135,7 +136,6 @@ class Dataset(object):
             vidname = self.all_videos[idx].split('/')[-1]
             video_imgs = self.all_img_names[idx]
             if len(video_imgs) == 0:
-                # print("video_imgs = 0:",vidname)
                 continue
             img_name = random.choice(video_imgs)
             img_idx = int(basename(img_name).split(".")[0])

@@ -193,7 +193,6 @@ class Dataset(object):
                 for feat_idx in range(window_index-self.use_audio_length_left,window_index+self.use_audio_length_right+1):
                     # 判定是否越界 (check whether the index is out of range)
                     audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")
-
                     if not os.path.exists(audio_feat_path):
                         is_index_out_of_range = True
                         break

@@ -214,8 +213,6 @@ class Dataset(object):
                     print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
                     continue
                 audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))
-
-
                 return ref_image, image, masked_image, mask, audio_feature

@@ -231,10 +228,8 @@ if __name__ == "__main__":
     val_data_loader = data_utils.DataLoader(
         val_data, batch_size=4, shuffle=True,
         num_workers=1)
-    print("val_dataset:",val_data_loader.__len__())

     for i, data in enumerate(val_data_loader):
         ref_image, image, masked_image, mask, audio_feature = data
-        print("ref_image: ", ref_image.shape)
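To make the signature change above concrete, here is a hedged sketch of how the reworked Dataset is now constructed: the split name is replaced by an explicit json file listing the image folders produced by data_new.sh. It mirrors the call made in train.py (run from train_codes/); the data_root and audio-length values are taken from train.sh, while the batch size is illustrative:

```python
# Sketch only, assuming it is run from the train_codes directory.
from DataLoader import Dataset
from torch.utils import data as data_utils

train_dataset = Dataset("../data",        # data_root, as in train.sh
                        "../train.json",  # json_path, written by data_new.sh
                        use_audio_length_left=2,
                        use_audio_length_right=2,
                        whisper_model_type="tiny")
train_loader = data_utils.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)
```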
@@ -1,15 +1,17 @@
-# Draft training codes
+# Data preprocessing

-We provde the draft training codes here. Unfortunately, data preprocessing code is still being reorganized.
+Create two config yaml files, one for training and one for testing (both in the same format as configs/inference/test.yaml).
+The train yaml file should contain the training video paths and the corresponding audio paths.
+The test yaml file should contain the validation video paths and the corresponding audio paths.

-## Setup
+Run:
+```
+./data_new.sh train output train_video1.mp4 train_video2.mp4
+./data_new.sh test output test_video1.mp4 test_video2.mp4
+```
+This creates folders containing the image frames and .npy audio features, and also writes train.json and test.json, which are used during training.

-We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 20w+ steps. Using multiple GPUs should accelerate the training.
+## Data organization

-## Data preprocessing
-You could refer the inference codes which [crop the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extract audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
-
-Finally, the data should be organized as follows:
 ```
 ./data/
 ├── images

@@ -35,9 +37,16 @@ Finally, the data should be organized as follows:
 ## Training
 Simply run after preparing the preprocessed data
 ```
-sh train.sh
+cd train_codes
+sh train.sh   # --train_json="../train.json" (generated in the data preprocessing step)
+              # --val_json="../test.json"
+```
+
+## Inference with the trained checkpoint
+After training, the model checkpoints are usually saved under train_codes/output. Simply run
+```
+python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
 ```

 ## TODO
-- [ ] release data preprocessing codes
+- [x] release data preprocessing codes
 - [ ] release some novel designs in training (after technical report)
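For reference, the train.json / test.json files written in the preprocessing step are flat JSON lists of image folders, one entry per input video; the folder names below are illustrative only:

```python
# Sketch only: data_new.sh appends one "../data/images/<video_basename>" entry
# per input video to the list it writes out.
import json

with open("train.json") as f:
    folders = json.load(f)
print(folders)  # e.g. ["../data/images/train_video1", "../data/images/train_video2"]
```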
@@ -27,10 +27,13 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
+
+import sys
+sys.path.append("./")
+
 from DataLoader import Dataset
 from utils.utils import preprocess_img_tensor
 from torch.utils import data as data_utils
-from model_utils import validation,PositionalEncoding
+from utils.model_utils import validation,PositionalEncoding
 import time
 import pandas as pd
 from PIL import Image

@@ -137,6 +140,8 @@ def parse_args():
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
+    parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
     parser.add_argument(
         "--hub_model_id",
         type=str,

@@ -235,12 +240,16 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank

     return args

+
+def print_model_dtypes(model, model_name):
+    for name, param in model.named_parameters():
+        if(param.dtype!=torch.float32):
+            print(f"{name}: {param.dtype}")
+
+
 def main():
     args = parse_args()
-    print(args)
     args.output_dir = f"output/{args.output_dir}"
     args.val_out_dir = f"val/{args.val_out_dir}"
     os.makedirs(args.output_dir, exist_ok=True)

@@ -332,7 +341,7 @@ def main():
     optimizer_class = torch.optim.AdamW

     params_to_optimize = (
-        itertools.chain(unet.parameters())
+        itertools.chain(unet.parameters()))
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,

@@ -343,23 +352,21 @@ def main():

     print("loading train_dataset ...")
     train_dataset = Dataset(args.data_root,
-                            'train',
+                            args.train_json,
                             use_audio_length_left=args.use_audio_length_left,
                             use_audio_length_right=args.use_audio_length_right,
                             whisper_model_type=args.whisper_model_type
                             )
-    print("train_dataset:",train_dataset.__len__())
     train_data_loader = data_utils.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True,
         num_workers=8)
     print("loading val_dataset ...")
     val_dataset = Dataset(args.data_root,
-                          'val',
+                          args.val_json,
                           use_audio_length_left=args.use_audio_length_left,
                           use_audio_length_right=args.use_audio_length_right,
                           whisper_model_type=args.whisper_model_type
                           )
-    print("val_dataset:",val_dataset.__len__())
     val_data_loader = data_utils.DataLoader(
         val_dataset, batch_size=1, shuffle=False,
         num_workers=8)

@@ -388,6 +395,7 @@ def main():
     vae_fp32.requires_grad_(False)

     weight_dtype = torch.float32
+    # weight_dtype = torch.float16
     vae_fp32.to(accelerator.device, dtype=weight_dtype)
     vae_fp32.encoder = None
     if accelerator.mixed_precision == "fp16":

@@ -412,6 +420,8 @@ def main():
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

+    print(f" Num batches each epoch = {len(train_data_loader)}")
+
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(train_dataset)}")
     logger.info(f" Num batches each epoch = {len(train_data_loader)}")

@@ -433,6 +443,9 @@ def main():
         dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
         path = dirs[-1] if len(dirs) > 0 else None

+        # path="../models/pytorch_model.bin"
+        # TODO change path
+        # path=None
         if path is None:
             accelerator.print(
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."

@@ -459,9 +472,10 @@ def main():
     elapsed_time = []
     start = time.time()

+
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
-        # for step, batch in enumerate(train_dataloader):
         for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:

@@ -470,24 +484,23 @@ def main():
                 continue
             dataloader_time = time.time() - start
             start = time.time()

             masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
-            """
-            print("=============epoch:{0}=step:{1}=====".format(epoch,step))
-            print("ref_image: ",ref_image.shape)
-            print("masks: ", masks.shape)
-            print("masked_image: ", masked_image.shape)
-            print("audio feature: ", audio_feature.shape)
-            print("image: ", image.shape)
-            """
+            # """
+            # print("=============epoch:{0}=step:{1}=====".format(epoch,step))
+            # print("ref_image: ",ref_image.shape)
+            # print("masks: ", masks.shape)
+            # print("masked_image: ", masked_image.shape)
+            # print("audio feature: ", audio_feature.shape)
+            # print("image: ", image.shape)
+            # """
             ref_image = preprocess_img_tensor(ref_image).to(vae.device)
             image = preprocess_img_tensor(image).to(vae.device)
             masked_image = preprocess_img_tensor(masked_image).to(vae.device)

             img_process_time = time.time() - start
             start = time.time()

             with accelerator.accumulate(unet):
+                vae = vae.half()
                 # Convert images to latent space
                 latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample() # init image
                 latents = latents * vae.config.scaling_factor

@@ -592,12 +605,23 @@ def main():
                         f"Running validation... epoch={epoch}, global_step={global_step}"
                     )
                     print("===========start validation==========")
+                    # Use the helper function to check the data types for each model
+                    vae_new = vae.float()
+                    print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
+                    print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
+                    print_model_dtypes(accelerator.unwrap_model(unet), "UNET")
+
+                    print(f"weight_dtype: {weight_dtype}")
+                    print(f"epoch type: {type(epoch)}")
+                    print(f"global_step type: {type(global_step)}")
                     validation(
-                        vae=accelerator.unwrap_model(vae),
+                        # vae=accelerator.unwrap_model(vae),
+                        vae=accelerator.unwrap_model(vae_new),
                         vae_fp32=accelerator.unwrap_model(vae_fp32),
                         unet=accelerator.unwrap_model(unet),
                         unet_config=unet_config,
-                        weight_dtype=weight_dtype,
+                        # weight_dtype=weight_dtype,
+                        weight_dtype=torch.float32,
                         epoch=epoch,
                         global_step=global_step,
                         val_data_loader=val_data_loader,
@@ -1,27 +1,29 @@
-export VAE_MODEL="./sd-vae-ft-mse/"
-export DATASET="..."
-export UNET_CONFIG="./musetalk.json"
+export VAE_MODEL="../models/sd-vae-ft-mse/"
+export DATASET="../data"
+export UNET_CONFIG="../models/musetalk/musetalk.json"

-accelerate launch --multi_gpu train.py \
+accelerate launch train.py \
 --mixed_precision="fp16" \
 --unet_config_file=$UNET_CONFIG \
 --pretrained_model_name_or_path=$VAE_MODEL \
 --data_root=$DATASET \
---train_batch_size=8 \
+--train_batch_size=256 \
---gradient_accumulation_steps=4 \
+--gradient_accumulation_steps=16 \
 --gradient_checkpointing \
---max_train_steps=200000 \
+--max_train_steps=100000 \
 --learning_rate=5e-05 \
 --max_grad_norm=1 \
---lr_scheduler="cosine" \
 --lr_warmup_steps=0 \
---output_dir="..." \
+--output_dir="output" \
---val_out_dir='...' \
+--val_out_dir='val' \
 --testing_speed \
---checkpointing_steps=1000 \
+--checkpointing_steps=2000 \
---validation_steps=1000 \
+--validation_steps=2000 \
 --reconstruction \
 --resume_from_checkpoint="latest" \
 --use_audio_length_left=2 \
 --use_audio_length_right=2 \
 --whisper_model_type="tiny" \
+--train_json="../train.json" \
+--val_json="../test.json" \
+--lr_scheduler="cosine" \
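A quick back-of-envelope check on the new settings, using the same expression as total_batch_size in train.py and assuming a single process (the --multi_gpu flag was dropped in this change):

```python
# effective batch = train_batch_size * num_processes * gradient_accumulation_steps
train_batch_size = 256
num_processes = 1                  # assumed: single-process launch
gradient_accumulation_steps = 16
print(train_batch_size * num_processes * gradient_accumulation_steps)  # 4096
```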
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import time
 import math
-from utils import decode_latents, preprocess_img_tensor
+from utils.utils import decode_latents, preprocess_img_tensor
 import os
 from PIL import Image
 from typing import Any, Dict, List, Optional, Tuple, Union