data_new.sh (new executable file, 77 lines)
@@ -0,0 +1,77 @@
#!/bin/bash

# Function to extract video and audio sections
extract_sections() {
    input_video=$1
    base_name=$(basename "$input_video" .mp4)
    output_dir=$2
    split=$3
    duration=$(ffmpeg -i "$input_video" 2>&1 | grep Duration | awk '{print $2}' | tr -d ,)
    IFS=: read -r hours minutes seconds <<< "$duration"
    total_seconds=$((10#${hours}*3600 + 10#${minutes}*60 + 10#${seconds%.*}))
    chunk_size=180  # 3 minutes in seconds
    index=0

    mkdir -p "$output_dir"

    while [ $((index * chunk_size)) -lt $total_seconds ]; do
        start_time=$((index * chunk_size))
        section_video="${output_dir}/${base_name}_part${index}.mp4"
        section_audio="${output_dir}/${base_name}_part${index}.mp3"

        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -c copy "$section_video"
        ffmpeg -i "$input_video" -ss "$start_time" -t "$chunk_size" -q:a 0 -map a "$section_audio"

        # Create and update the config.yaml file
        echo "task_0:" > config.yaml
        echo " video_path: \"$section_video\"" >> config.yaml
        echo " audio_path: \"$section_audio\"" >> config.yaml

        # Run the Python script with the current config.yaml
        python -m scripts.data --inference_config config.yaml --folder_name "$base_name"

        index=$((index + 1))
    done

    # Clean up the temporary sections folder
    rm -rf "$output_dir"
}

# Main script
if [ $# -lt 3 ]; then
    echo "Usage: $0 <train/test> <output_directory> <input_videos...>"
    exit 1
fi

split=$1
output_dir=$2
shift 2
input_videos=("$@")

# Initialize JSON array
json_array="["

for input_video in "${input_videos[@]}"; do
    base_name=$(basename "$input_video" .mp4)

    # Extract sections and run the Python script for each section
    extract_sections "$input_video" "$output_dir" "$split"

    # Add entry to JSON array
    json_array+="\"../data/images/$base_name\","
done

# Remove trailing comma and close JSON array
json_array="${json_array%,}]"

# Write JSON array to the correct file
if [ "$split" == "train" ]; then
    echo "$json_array" > train.json
elif [ "$split" == "test" ]; then
    echo "$json_array" > test.json
else
    echo "Invalid split: $split. Must be 'train' or 'test'."
    exit 1
fi

echo "Processing complete."
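For reference, each pass through the `while` loop above overwrites `config.yaml` with a single task entry before invoking `scripts.data`. For a hypothetical input `talk.mp4` and output directory `output`, the first chunk would generate roughly:

```
task_0:
 video_path: "output/talk_part0.mp4"
 audio_path: "output/talk_part0.mp3"
```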
scripts/data.py (new file, 256 lines)
@@ -0,0 +1,256 @@
import cv2
import os
# import dlib  # only needed by the unused mask_face() helper at the bottom of this file
import argparse
from omegaconf import OmegaConf
import numpy as np
import torch
import glob
import pickle
from tqdm import tqdm
import copy
import uuid

from musetalk.utils.utils import get_file_type,get_video_fps
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil
import gc

# load model weights
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)

def get_largest_integer_filename(folder_path):
    # Check if the folder exists
    if not os.path.isdir(folder_path):
        return -1

    # Get the list of files in the folder
    files = os.listdir(folder_path)

    # Check if the folder is empty
    if not files:
        return -1

    # Extract the integer part of the filenames and find the largest
    largest_integer = -1
    for file in files:
        try:
            # Get the integer part of the filename
            file_int = int(os.path.splitext(file)[0])
            if file_int > largest_integer:
                largest_integer = file_int
        except ValueError:
            # Skip files that don't have an integer filename
            continue

    return largest_integer


def datagen(whisper_chunks,
            crop_images,
            batch_size=8,
            delay_frame=0):
    whisper_batch, crop_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i+delay_frame) % len(crop_images)
        crop_image = crop_images[idx]
        whisper_batch.append(w)
        crop_batch.append(crop_image)

        if len(crop_batch) >= batch_size:
            whisper_batch = np.stack(whisper_batch)
            # latent_batch = torch.cat(latent_batch, dim=0)
            yield whisper_batch, crop_batch
            whisper_batch, crop_batch = [], []

    # the last batch may be smaller than the batch size
    if len(crop_batch) > 0:
        whisper_batch = np.stack(whisper_batch)
        # latent_batch = torch.cat(latent_batch, dim=0)

        yield whisper_batch, crop_batch

@torch.no_grad()
def main(args):
    global pe
    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    total_audio_index = get_largest_integer_filename(f"data/audios/{args.folder_name}")
    total_image_index = get_largest_integer_filename(f"data/images/{args.folder_name}")
    temp_audio_index = total_audio_index
    temp_image_index = total_image_index
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
        folder_name = args.folder_name
        if not os.path.exists(f"data/images/{folder_name}/"):
            os.makedirs(f"data/images/{folder_name}")
        if not os.path.exists(f"data/audios/{folder_name}/"):
            os.makedirs(f"data/audios/{folder_name}")
        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################################## extract frames from source video ##############################################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        for __ in range(0, len(whisper_feature) - 1, 2):  # -1 to avoid an index error if the list has an odd number of elements
            # Combine two consecutive chunks
            # pair_of_chunks = np.array([whisper_feature[__], whisper_feature[__+1]])
            concatenated_chunks = np.concatenate([whisper_feature[__], whisper_feature[__+1]], axis=0)
            # Save the pair to a .npy file
            np.save(f'data/audios/{folder_name}/{total_audio_index+(__//2)+1}.npy', concatenated_chunks)
            temp_audio_index = (__//2)+total_audio_index+1
        total_audio_index = temp_audio_index
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

        ############################################## preprocess input image ##############################################
        gc.collect()
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        i = 0
        input_latent_list = []
        crop_i = 0
        crop_data = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox

            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = max(0, x2)
            y2 = max(0, y2)

            if ((y2-y1) <= 0) or ((x2-x1) <= 0):
                continue
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            latents = vae.get_latents_for_unet(crop_frame)
            crop_data.append(crop_frame)
            input_latent_list.append(latents)
            crop_i += 1

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        crop_data = crop_data + crop_data[::-1]
        ############################################## inference batch by batch ##############################################
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, crop_data, batch_size)
        for i, (whisper_batch, crop_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num)/batch_size)))):
            crop_index = 0
            for image, audio in zip(crop_batch, whisper_batch):
                # write each face crop as <global index>.png under data/images/<folder_name>
                cv2.imwrite(f"data/images/{folder_name}/{str(i+crop_index+total_image_index+1)}.png", image)
                crop_index += 1
                temp_image_index = i+crop_index+total_image_index+1
            # np.save(f'data/audios/{folder_name}/{str(i+crop_index)}.npy', audio)
            total_image_index = temp_image_index
            gc.collect()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--folder_name", default=f'{uuid.uuid4()}', help="name of the output folder under data/images and data/audios")

    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
                        action="store_true",
                        help='use saved coordinates to save time')
    parser.add_argument("--use_float16",
                        action="store_true",
                        help="Whether to use float16 to speed up inference",
                        )

    args = parser.parse_args()
    main(args)


def process_audio(audio_path):
    whisper_feature = audio_processor.audio2feat(audio_path)
    np.save('audio/your_filename.npy', whisper_feature)


def mask_face(image):
    # Requires dlib; see the commented-out import at the top of this file
    # Load dlib's face detector and the landmark predictor
    detector = dlib.get_frontal_face_detector()
    predictor_path = "/content/shape_predictor_68_face_landmarks.dat"  # Set path to your downloaded predictor file
    predictor = dlib.shape_predictor(predictor_path)

    # Load your input image
    # image_path = "/content/ori_frame_00000077.png"  # Replace with the path to your input image
    # image = cv2.imread(image_path)
    if image is None:
        raise ValueError("Image not found or unable to load.")

    # Convert to grayscale for detection
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Detect faces in the image
    faces = detector(gray)

    # Process each detected face
    for face in faces:
        # Predict landmarks
        landmarks = predictor(gray, face)

        # The indices of the nose landmarks are 27 to 35
        nose_tip = landmarks.part(33).y

        # Blacken the region below the nose tip
        blacken_area = image[nose_tip:, :]
        blacken_area[:] = (0, 0, 0)

    # Save the final image or display it
    # cv2.imwrite("output_image.jpg", image)
    return image
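The script above writes face crops and paired audio features under data/images/<folder_name>/ and data/audios/<folder_name>/, keyed by integer filenames. A minimal, hypothetical sketch of reading one such pair back (the folder name and index are placeholders; this helper is not part of the repository):

```
import cv2
import numpy as np

folder, idx = "my_speaker", 7                            # placeholders; real folders come from --folder_name
image = cv2.imread(f"data/images/{folder}/{idx}.png")    # 256x256 BGR face crop written by the loop above
audio = np.load(f"data/audios/{folder}/{idx}.npy")       # two concatenated whisper feature chunks
print(image.shape, audio.shape)
```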
scripts/finetuned_inference.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import argparse
import os
from omegaconf import OmegaConf
import numpy as np
import cv2
import torch
import glob
import pickle
from tqdm import tqdm
import copy

from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model
import shutil

from accelerate import Accelerator

# load model weights
audio_processor, vae, unet, pe = load_all_model()
accelerator = Accelerator(
    mixed_precision="fp16",
)
unet = accelerator.prepare(
    unet,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)

@torch.no_grad()
def main(args):
    global pe
    if args.unet_checkpoint is not None:
        print("unet ckpt loaded")
        accelerator.load_state(args.unet_checkpoint)

    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename)  # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl")  # only related to video input
        os.makedirs(result_img_save_path, exist_ok=True)

        if args.output_vid_name is None:
            output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################################## extract frames from source video ##############################################
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
        ############################################## extract audio feature ##############################################
        whisper_feature = audio_processor.audio2feat(audio_path)
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
        ############################################## preprocess input image ##############################################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        i = 0
        input_latent_list = []
        crop_i = 0
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            cv2.imwrite(f"{result_img_save_path}/crop_frame_{str(crop_i).zfill(8)}.png", crop_frame)
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)
            crop_i += 1

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        ############################################## inference batch by batch ##############################################
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
        res_frame_list = []
        for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num)/batch_size)))):
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype)  # torch, B, 5*N, 384
            audio_feature_batch = pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            recon = vae.decode_latents(pred_latents)
            for res_frame in recon:
                res_frame_list.append(res_frame)

        ############################################## pad to full image ##############################################
        print("pad talking image to original video")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            bbox = coord_list_cycle[i % (len(coord_list_cycle))]
            ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
            except:
                continue

            combine_frame = get_image(ori_frame, res_frame, bbox)
            cv2.imwrite(f"{result_img_save_path}/res_frame_{str(i).zfill(8)}.png", res_frame)
            cv2.imwrite(f"{result_img_save_path}/ori_frame_{str(i).zfill(8)}.png", ori_frame)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
        os.system(cmd_img2video)

        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
        os.system(cmd_combine_audio)

        os.remove("temp.mp4")

        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/ori_frame_%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
        os.system(cmd_img2video)

        # cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
        # print(cmd_combine_audio)
        # os.system(cmd_combine_audio)

        # shutil.rmtree(result_img_save_path)
        print(f"result is saved to {output_vid_name}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
    parser.add_argument("--bbox_shift", type=int, default=0)
    parser.add_argument("--result_dir", default='./results', help="path to output")

    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
                        action="store_true",
                        help='use saved coordinates to save time')
    parser.add_argument("--use_float16",
                        action="store_true",
                        help="Whether to use float16 to speed up inference",
                        )
    parser.add_argument("--unet_checkpoint", type=str, default=None)

    args = parser.parse_args()
    main(args)
@@ -158,4 +158,4 @@ if __name__ == "__main__":
    )

    args = parser.parse_args()
    main(args)
    main(args)
@@ -48,22 +48,22 @@ def get_image_list(data_root, split):
class Dataset(object):
    def __init__(self,
                 data_root,
                 split,
                 json_path,
                 use_audio_length_left=1,
                 use_audio_length_right=1,
                 whisper_model_type = "tiny"
                 ):
        self.all_videos, self.all_imgNum = get_image_list(data_root, split)
        # self.all_videos, self.all_imgNum = get_image_list(data_root, split)
        self.audio_feature = [use_audio_length_left,use_audio_length_right]
        self.all_img_names = []
        self.split = split
        self.img_names_path = '...'
        # self.split = split
        self.img_names_path = '../data'
        self.whisper_model_type = whisper_model_type
        self.use_audio_length_left = use_audio_length_left
        self.use_audio_length_right = use_audio_length_right

        if self.whisper_model_type =="tiny":
            self.whisper_path = '...'
            self.whisper_path = '../data/audios'
            self.whisper_feature_W = 5
            self.whisper_feature_H = 384
        elif self.whisper_model_type =="largeV2":
@@ -71,6 +71,8 @@ class Dataset(object):
            self.whisper_feature_W = 33
            self.whisper_feature_H = 1280
        self.whisper_feature_concateW = self.whisper_feature_W*2*(self.use_audio_length_left+self.use_audio_length_right+1)  # 5*2*(2+2+1) = 50
        with open(json_path, 'r') as file:
            self.all_videos = json.load(file)

        for vidname in tqdm(self.all_videos, desc="Preparing dataset"):
            json_path_names = f"{self.img_names_path}/{vidname.split('/')[-1].split('.')[0]}.json"
@@ -79,7 +81,6 @@ class Dataset(object):
                img_names.sort(key=lambda x: int(x.split("/")[-1].split('.')[0]))
                with open(json_path_names, "w") as f:
                    json.dump(img_names, f)
                print(f"save to {json_path_names}")
            else:
                with open(json_path_names, "r") as f:
                    img_names = json.load(f)
@@ -135,7 +136,6 @@ class Dataset(object):
            vidname = self.all_videos[idx].split('/')[-1]
            video_imgs = self.all_img_names[idx]
            if len(video_imgs) == 0:
                # print("video_imgs = 0:", vidname)
                continue
            img_name = random.choice(video_imgs)
            img_idx = int(basename(img_name).split(".")[0])
@@ -193,7 +193,6 @@ class Dataset(object):
            for feat_idx in range(window_index-self.use_audio_length_left, window_index+self.use_audio_length_right+1):
                # check whether the index is out of range
                audio_feat_path = os.path.join(self.whisper_path, sub_folder_name, str(feat_idx) + ".npy")

                if not os.path.exists(audio_feat_path):
                    is_index_out_of_range = True
                    break
@@ -214,8 +213,6 @@ class Dataset(object):
                print(f"shape error!! {vidname} {window_index}, audio_feature.shape: {audio_feature.shape}")
                continue
            audio_feature = torch.squeeze(torch.FloatTensor(audio_feature))


            return ref_image, image, masked_image, mask, audio_feature


@@ -231,10 +228,8 @@ if __name__ == "__main__":
    val_data_loader = data_utils.DataLoader(
        val_data, batch_size=4, shuffle=True,
        num_workers=1)
    print("val_dataset:", val_data_loader.__len__())

    for i, data in enumerate(val_data_loader):
        ref_image, image, masked_image, mask, audio_feature = data
        print("ref_image: ", ref_image.shape)
@@ -1,15 +1,17 @@
# Draft training codes
# Data preprocessing

We provide the draft training code here. Unfortunately, the data preprocessing code is still being reorganized.
Create two config yaml files, one for training and one for testing (both in the same format as configs/inference/test.yaml).
The train yaml file should contain the training video paths and the corresponding audio paths.
The test yaml file should contain the validation video paths and the corresponding audio paths.
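For example, a training config in that format could look like the following (the paths are placeholders):

```
task_0:
  video_path: "data/video/train_video1.mp4"
  audio_path: "data/audio/train_video1.mp3"
task_1:
  video_path: "data/video/train_video2.mp4"
  audio_path: "data/audio/train_video2.mp3"
```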
## Setup
Run:
```
./data_new.sh train output train_video1.mp4 train_video2.mp4
./data_new.sh test output test_video1.mp4 test_video2.mp4
```
This creates folders containing the image frames and the .npy audio features. It also writes train.json and test.json, which can be used during training.
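For reference, train.json (and likewise test.json) is simply a JSON array listing the per-video image folders, e.g. for the two training videos above:

```
["../data/images/train_video1","../data/images/train_video2"]
```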
We trained our model on an NVIDIA A100 with `batch size=8, gradient_accumulation_steps=4` for 200k+ steps. Using multiple GPUs should accelerate training.

## Data preprocessing
You can refer to the inference code, which [crops the face images](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L79) and [extracts audio features](https://github.com/TMElyralab/MuseTalk/blob/main/scripts/inference.py#L69).
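In outline, these two steps look roughly like the sketch below, which reuses the same helpers as scripts/data.py above; the frame pattern `frames/*.png` and the audio file `audio.mp3` are placeholders:

```
import glob
import cv2
from musetalk.utils.preprocessing import get_landmark_and_bbox, coord_placeholder
from musetalk.utils.utils import load_all_model

audio_processor, vae, unet, pe = load_all_model()

# 1) detect a face box in each extracted frame and crop/resize it to 256x256
coord_list, frame_list = get_landmark_and_bbox(sorted(glob.glob("frames/*.png")), bbox_shift=0)
crops = []
for bbox, frame in zip(coord_list, frame_list):
    if bbox == coord_placeholder:  # no face detected in this frame
        continue
    x1, y1, x2, y2 = bbox
    crops.append(cv2.resize(frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4))

# 2) extract whisper audio features for the paired audio track
whisper_feature = audio_processor.audio2feat("audio.mp3")
```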
Finally, the data should be organized as follows:
## Data organization
```
./data/
├── images
```
@@ -35,9 +37,16 @@ Finally, the data should be organized as follows:
## Training
Simply run the following after preparing the preprocessed data:
```
sh train.sh
cd train_codes
sh train.sh   # --train_json="../train.json" (generated in the data preprocessing step)
              # --val_json="../val.json"
```
## Inference with trained checkpoint
After training, simply run the command below; the model checkpoints are usually saved under train_codes/output.
```
python -m scripts.finetuned_inference --inference_config configs/inference/test.yaml --unet_checkpoint path_to_trained_checkpoint_folder
```

## TODO
- [ ] release data preprocessing codes
- [x] release data preprocessing codes
- [ ] release some novel designs in training (after technical report)
@@ -27,10 +27,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version

import sys
sys.path.append("./")

from DataLoader import Dataset
from utils.utils import preprocess_img_tensor
from torch.utils import data as data_utils
from model_utils import validation,PositionalEncoding
from utils.model_utils import validation,PositionalEncoding
import time
import pandas as pd
from PIL import Image
@@ -137,6 +140,8 @@ def parse_args():
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument("--train_json", type=str, default="train.json", help="The json file containing train image folders")
    parser.add_argument("--val_json", type=str, default="test.json", help="The json file containing validation image folders")
    parser.add_argument(
        "--hub_model_id",
        type=str,
@@ -234,13 +239,17 @@ def parse_args():
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def print_model_dtypes(model, model_name):
    # Print any parameter that is not float32 (helps spot unintended half-precision weights)
    for name, param in model.named_parameters():
        if (param.dtype != torch.float32):
            print(f"{name}: {param.dtype}")


def main():
    args = parse_args()
    print(args)
    args.output_dir = f"output/{args.output_dir}"
    args.val_out_dir = f"val/{args.val_out_dir}"
    os.makedirs(args.output_dir, exist_ok=True)
@@ -332,7 +341,7 @@ def main():
    optimizer_class = torch.optim.AdamW

    params_to_optimize = (
        itertools.chain(unet.parameters())
        itertools.chain(unet.parameters()))
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
@@ -343,23 +352,21 @@

    print("loading train_dataset ...")
    train_dataset = Dataset(args.data_root,
                            'train',
                            args.train_json,
                            use_audio_length_left=args.use_audio_length_left,
                            use_audio_length_right=args.use_audio_length_right,
                            whisper_model_type=args.whisper_model_type
                            )
    print("train_dataset:", train_dataset.__len__())
    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True,
        num_workers=8)
    print("loading val_dataset ...")
    val_dataset = Dataset(args.data_root,
                          'val',
                          args.val_json,
                          use_audio_length_left=args.use_audio_length_left,
                          use_audio_length_right=args.use_audio_length_right,
                          whisper_model_type=args.whisper_model_type
                          )
    print("val_dataset:", val_dataset.__len__())
    val_data_loader = data_utils.DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        num_workers=8)
@@ -388,6 +395,7 @@ def main():
    vae_fp32.requires_grad_(False)

    weight_dtype = torch.float32
    # weight_dtype = torch.float16
    vae_fp32.to(accelerator.device, dtype=weight_dtype)
    vae_fp32.encoder = None
    if accelerator.mixed_precision == "fp16":
@@ -412,6 +420,8 @@ def main():
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    print(f" Num batches each epoch = {len(train_data_loader)}")

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_data_loader)}")
@@ -433,6 +443,9 @@ def main():
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

        # path = "../models/pytorch_model.bin"
        # TODO: change path
        # path = None
        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -458,10 +471,11 @@ def main():
    # calculate the elapsed time
    elapsed_time = []
    start = time.time()




    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        # for step, batch in enumerate(train_dataloader):
        for step, (ref_image, image, masked_image, masks, audio_feature) in enumerate(train_data_loader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -470,24 +484,23 @@ def main():
                continue
            dataloader_time = time.time() - start
            start = time.time()

            masks = masks.unsqueeze(1).unsqueeze(1).to(vae.device)
            """
            print("=============epoch:{0}=step:{1}=====".format(epoch,step))
            print("ref_image: ",ref_image.shape)
            print("masks: ", masks.shape)
            print("masked_image: ", masked_image.shape)
            print("audio feature: ", audio_feature.shape)
            print("image: ", image.shape)
            """
            # """
            # print("=============epoch:{0}=step:{1}=====".format(epoch,step))
            # print("ref_image: ",ref_image.shape)
            # print("masks: ", masks.shape)
            # print("masked_image: ", masked_image.shape)
            # print("audio feature: ", audio_feature.shape)
            # print("image: ", image.shape)
            # """
            ref_image = preprocess_img_tensor(ref_image).to(vae.device)
            image = preprocess_img_tensor(image).to(vae.device)
            masked_image = preprocess_img_tensor(masked_image).to(vae.device)

            img_process_time = time.time() - start
            start = time.time()

            with accelerator.accumulate(unet):
                vae = vae.half()
                # Convert images to latent space
                latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample()  # init image
                latents = latents * vae.config.scaling_factor
@@ -592,12 +605,23 @@ def main():
                        f"Running validation... epoch={epoch}, global_step={global_step}"
                    )
                    print("===========start validation==========")
                    # Use the helper function to check the data types for each model
                    vae_new = vae.float()
                    print_model_dtypes(accelerator.unwrap_model(vae_new), "VAE")
                    print_model_dtypes(accelerator.unwrap_model(vae_fp32), "VAE_FP32")
                    print_model_dtypes(accelerator.unwrap_model(unet), "UNET")

                    print(f"weight_dtype: {weight_dtype}")
                    print(f"epoch type: {type(epoch)}")
                    print(f"global_step type: {type(global_step)}")
                    validation(
                        vae=accelerator.unwrap_model(vae),
                        # vae=accelerator.unwrap_model(vae),
                        vae=accelerator.unwrap_model(vae_new),
                        vae_fp32=accelerator.unwrap_model(vae_fp32),
                        unet=accelerator.unwrap_model(unet),
                        unet_config=unet_config,
                        weight_dtype=weight_dtype,
                        # weight_dtype=weight_dtype,
                        weight_dtype=torch.float32,
                        epoch=epoch,
                        global_step=global_step,
                        val_data_loader=val_data_loader,
@@ -1,27 +1,29 @@
export VAE_MODEL="./sd-vae-ft-mse/"
export DATASET="..."
export UNET_CONFIG="./musetalk.json"
export VAE_MODEL="../models/sd-vae-ft-mse/"
export DATASET="../data"
export UNET_CONFIG="../models/musetalk/musetalk.json"

accelerate launch --multi_gpu train.py \
accelerate launch train.py \
--mixed_precision="fp16" \
--unet_config_file=$UNET_CONFIG \
--pretrained_model_name_or_path=$VAE_MODEL \
--data_root=$DATASET \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--train_batch_size=256 \
--gradient_accumulation_steps=16 \
--gradient_checkpointing \
--max_train_steps=200000 \
--max_train_steps=100000 \
--learning_rate=5e-05 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir="..." \
--val_out_dir='...' \
--output_dir="output" \
--val_out_dir='val' \
--testing_speed \
--checkpointing_steps=1000 \
--validation_steps=1000 \
--checkpointing_steps=2000 \
--validation_steps=2000 \
--reconstruction \
--resume_from_checkpoint="latest" \
--use_audio_length_left=2 \
--use_audio_length_right=2 \
--whisper_model_type="tiny" \
--train_json="../train.json" \
--val_json="../val.json" \
--lr_scheduler="cosine" \
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import time
import math
from utils import decode_latents, preprocess_img_tensor
from utils.utils import decode_latents, preprocess_img_tensor
import os
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union