mirror of https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-04 09:29:20 +08:00
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
327 lines
12 KiB
Python
Executable File
import os
import os.path as osp
import shutil
from typing import Union, List

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange

ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
    print("Please download ffmpeg-static and export FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
elif ffmpeg_path not in os.getenv('PATH'):
    print("add ffmpeg to path")
    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"

from musetalk.models.vae import VAE
from musetalk.models.unet import UNet, PositionalEncoding

def load_all_model(
    unet_model_path="./models/musetalk/pytorch_model.bin",
    vae_type="sd-vae-ft-mse",
    unet_config="./models/musetalk/musetalk.json",
    device=None,
):
    """Load the VAE, the UNet and the audio positional encoding in one call."""
    vae = VAE(
        model_path=f"./models/{vae_type}/",
    )
    print(f"load unet model from {unet_model_path}")
    unet = UNet(
        unet_config=unet_config,
        model_path=unet_model_path,
        device=device,
    )
    pe = PositionalEncoding(d_model=384)
    return vae, unet, pe

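# Usage sketch (illustrative; assumes the default MuseTalk checkpoint layout under ./models):
#   vae, unet, pe = load_all_model(device="cuda:0")
#   # `vae` encodes/decodes frames to/from latents, `pe` adds positional information
#   # to the 384-d audio features before they condition the UNet.
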
def get_file_type(video_path):
    _, ext = os.path.splitext(video_path)

    if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
        return 'image'
    elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
        return 'video'
    else:
        return 'unsupported'

def get_video_fps(video_path):
    video = cv2.VideoCapture(video_path)
    fps = video.get(cv2.CAP_PROP_FPS)
    video.release()
    return fps

def datagen(
    whisper_chunks,
    vae_encode_latents,
    batch_size=8,
    delay_frame=0,
    device="cuda:0",
):
    whisper_batch, latent_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % len(vae_encode_latents)
        latent = vae_encode_latents[idx]
        whisper_batch.append(w)
        latent_batch.append(latent)

        if len(latent_batch) >= batch_size:
            whisper_batch = torch.stack(whisper_batch)
            latent_batch = torch.cat(latent_batch, dim=0)
            yield whisper_batch.to(device), latent_batch.to(device)
            whisper_batch, latent_batch = [], []

    # the last batch may be smaller than batch_size
    if len(latent_batch) > 0:
        whisper_batch = torch.stack(whisper_batch)
        latent_batch = torch.cat(latent_batch, dim=0)
        yield whisper_batch.to(device), latent_batch.to(device)

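# Usage sketch (illustrative): iterate over paired audio-feature / latent batches, e.g.
#   for whisper_batch, latent_batch in datagen(whisper_chunks, vae_encode_latents, batch_size=8):
#       ...  # run the UNet on latent_batch conditioned on whisper_batch
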
def cast_training_params(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    dtype=torch.float32,
):
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)

def rand_log_normal(
    shape,
    loc=0.,
    scale=1.,
    device='cpu',
    dtype=torch.float32,
    generator=None,
):
    """Draw samples from a log-normal distribution."""
    rnd_normal = torch.randn(
        shape, device=device, dtype=dtype, generator=generator)  # N(0, I)
    sigma = (rnd_normal * scale + loc).exp()
    return sigma

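# Equivalently: sigma = exp(loc + scale * z) with z ~ N(0, I), i.e. log(sigma) ~ N(loc, scale^2).
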
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
    # Initialize lists to store the results for each image in the batch
    mouth_real_list = []
    mouth_generated_list = []

    # Process each image in the batch
    for b in range(frames.shape[0]):
        # Find the non-zero area in the face mask
        non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
        # If there are no non-zero indices, skip this image
        if non_zero_indices.numel() == 0:
            continue

        min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(non_zero_indices[:, 1])
        min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(non_zero_indices[:, 2])

        # Crop the frames and image_pred according to the non-zero area
        frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
        image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
        # Resize the cropped images to 256*256
        frames_resized = F.interpolate(frames_cropped.unsqueeze(0),
                                       size=(256, 256), mode='bilinear', align_corners=False)
        image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(0),
                                           size=(256, 256), mode='bilinear', align_corners=False)

        # Append the resized images to the result lists
        mouth_real_list.append(frames_resized)
        mouth_generated_list.append(image_pred_resized)

    # Convert the lists to tensors if they are not empty
    mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
    mouth_generated = torch.cat(mouth_generated_list, dim=0) if mouth_generated_list else None

    return mouth_real, mouth_generated

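# Note: both outputs are (N, C, 256, 256) tensors covering the masked mouth region, or None
# when no face-mask pixels are found; N can be smaller than the batch size because images
# with an empty mask are skipped.
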
def get_image_pred(pixel_values,
                   ref_pixel_values,
                   audio_prompts,
                   vae,
                   net,
                   weight_dtype):
    """Run a single forward pass of the network and decode the predicted latents to images."""
    with torch.no_grad():
        bsz, num_frames, c, h, w = pixel_values.shape

        # Mask out the lower half of every frame (the region to be regenerated)
        masked_pixel_values = pixel_values.clone()
        masked_pixel_values[:, :, :, h//2:, :] = -1

        masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
        masked_latents = vae.encode(masked_frames).latent_dist.mode()
        masked_latents = masked_latents * vae.config.scaling_factor
        masked_latents = masked_latents.float()

        ref_frames = rearrange(ref_pixel_values, 'b f c h w -> (b f) c h w')
        ref_latents = vae.encode(ref_frames).latent_dist.mode()
        ref_latents = ref_latents * vae.config.scaling_factor
        ref_latents = ref_latents.float()

        # Concatenate masked and reference latents along the channel dimension
        input_latents = torch.cat([masked_latents, ref_latents], dim=1)
        input_latents = input_latents.to(weight_dtype)
        timesteps = torch.tensor([0], device=input_latents.device)
        latents_pred = net(
            input_latents,
            timesteps,
            audio_prompts,
        )
        latents_pred = (1 / vae.config.scaling_factor) * latents_pred
        image_pred = vae.decode(latents_pred).sample
        image_pred = image_pred.float()

    return image_pred

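# Note: unlike an iterative diffusion sampler, this is a single forward pass at timestep 0;
# the network directly predicts the latent of the lip-synced frame from the masked and
# reference latents plus the audio prompts.
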
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
    with torch.no_grad():
        audio_feature_length_per_frame = 2 * (
            cfg.data.audio_padding_length_left +
            cfg.data.audio_padding_length_right + 1)
        audio_feats = batch['audio_feature'].to(weight_dtype)
        audio_feats = wav2vec.encoder(
            audio_feats, output_hidden_states=True).hidden_states
        audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)  # [B, T, 10, 5, 384]

        start_ts = batch['audio_offset']
        step_ts = batch['audio_step']
        # Zero-pad the audio features on both sides so every frame has full left/right context
        audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
                                 audio_feats,
                                 torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
        audio_prompts = []
        for bb in range(bsz):
            audio_feats_list = []
            for f in range(num_frames):
                cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
                audio_clip = audio_feats[bb:bb+1,
                                         cur_t: cur_t+audio_feature_length_per_frame]

                audio_feats_list.append(audio_clip)
            audio_feats_list = torch.stack(audio_feats_list, 1)
            audio_prompts.append(audio_feats_list)
        audio_prompts = torch.cat(audio_prompts)  # B, T, 10, 5, 384
    return audio_prompts

def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
    save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")

    if total_limit is not None:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.endswith(".pth")]
        checkpoints = [d for d in checkpoints if name in d]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(
                f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(
                    save_dir, removing_checkpoint)
                os.remove(removing_checkpoint)

    state_dict = model.state_dict()
    torch.save(state_dict, save_path)

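# Example (illustrative, assuming `logger` is a configured logging object):
#   save_checkpoint(net.unet, "./exp", 2000, name="unet", total_limit=3, logger=logger)
# writes ./exp/unet-2000.pth and removes the oldest unet-*.pth files beyond the limit.
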
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
    unwrapped_net = accelerator.unwrap_model(net)
    save_checkpoint(
        unwrapped_net.unet,
        save_dir,
        global_step,
        name="unet",
        total_limit=cfg.total_limit,
        logger=logger,
    )

def delete_additional_ckpt(base_path, num_keep):
    dirs = []
    for d in os.listdir(base_path):
        if d.startswith("checkpoint-"):
            dirs.append(d)
    num_tot = len(dirs)
    if num_tot <= num_keep:
        return
    # sort checkpoints by step and delete the earliest ones
    del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
    for d in del_dirs:
        path_to_dir = osp.join(base_path, d)
        if osp.exists(path_to_dir):
            shutil.rmtree(path_to_dir)

def seed_everything(seed):
    import random

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))
    random.seed(seed)

def process_and_save_images(
    batch,
    image_pred,
    image_pred_infer,
    save_dir,
    global_step,
    accelerator,
    num_images_to_keep=10,
    syncnet_score=1,
):
    # Rearrange the tensors
    print("image_pred.shape: ", image_pred.shape)
    pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
    pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')

    # Create masked pixel values (lower half of each frame blanked out)
    masked_pixel_values = batch["pixel_values_vid"].clone()
    _, _, _, h, _ = batch["pixel_values_vid"].shape
    masked_pixel_values[:, :, :, h//2:, :] = -1
    masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')

    # Keep only the specified number of images
    pixel_values = pixel_values[:num_images_to_keep, :, :, :]
    masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
    pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
    image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
    image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]

    # Concatenate images (stack vertically, mapping [-1, 1] back to [0, 1])
    concat = torch.cat([
        masked_pixel_values * 0.5 + 0.5,
        pixel_values_ref_img * 0.5 + 0.5,
        image_pred * 0.5 + 0.5,
        pixel_values * 0.5 + 0.5,
        image_pred_infer * 0.5 + 0.5,
    ], dim=2)
    print("concat.shape: ", concat.shape)

    # Create the save directory if it doesn't exist
    os.makedirs(f'{save_dir}/samples/', exist_ok=True)

    # Try to save the concatenated image
    try:
        # Tile the samples horizontally, convert RGB -> BGR and scale to [0, 255]
        final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
        # Save the image
        cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
        print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
    except Exception as e:
        print(f"Failed to save image: {e}")