MuseTalk/musetalk/utils/utils.py
import os
import cv2
import numpy as np
import torch
from typing import Union, List
import torch.nn.functional as F
from einops import rearrange
import shutil
import os.path as osp

ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
elif ffmpeg_path not in os.getenv('PATH'):
    print("add ffmpeg to path")
    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"

from musetalk.models.vae import VAE
from musetalk.models.unet import UNet, PositionalEncoding

def load_all_model(
    unet_model_path="./models/musetalk/pytorch_model.bin",
    vae_type="sd-vae-ft-mse",
    unet_config="./models/musetalk/musetalk.json",
    device=None,
):
    """Load the VAE, the MuseTalk UNet and the positional encoding module."""
    vae = VAE(
        model_path=f"./models/{vae_type}/",
    )
    print(f"load unet model from {unet_model_path}")
    unet = UNet(
        unet_config=unet_config,
        model_path=unet_model_path,
        device=device,
    )
    pe = PositionalEncoding(d_model=384)
    return vae, unet, pe
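
# Usage sketch (illustrative only; the checkpoint paths above are the repo's
# defaults and may differ in your setup):
#   vae, unet, pe = load_all_model(device="cuda:0")
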
def get_file_type(video_path):
    _, ext = os.path.splitext(video_path)
    if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
        return 'image'
    elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
        return 'video'
    else:
        return 'unsupported'
def get_video_fps(video_path):
    video = cv2.VideoCapture(video_path)
    fps = video.get(cv2.CAP_PROP_FPS)
    video.release()
    return fps
def datagen(
    whisper_chunks,
    vae_encode_latents,
    batch_size=8,
    delay_frame=0,
    device="cuda:0",
):
    whisper_batch, latent_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % len(vae_encode_latents)
        latent = vae_encode_latents[idx]
        whisper_batch.append(w)
        latent_batch.append(latent)

        if len(latent_batch) >= batch_size:
            whisper_batch = torch.stack(whisper_batch)
            latent_batch = torch.cat(latent_batch, dim=0)
            # move full batches to the target device, consistent with the final partial batch below
            yield whisper_batch.to(device), latent_batch.to(device)
            whisper_batch, latent_batch = [], []

    # the last batch may be smaller than batch_size
    if len(latent_batch) > 0:
        whisper_batch = torch.stack(whisper_batch)
        latent_batch = torch.cat(latent_batch, dim=0)
        yield whisper_batch.to(device), latent_batch.to(device)
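
# Usage sketch (assumes `whisper_chunks` and `vae_encode_latents` were produced
# upstream by the audio feature extraction and VAE encoding steps; names are
# illustrative):
#   for whisper_batch, latent_batch in datagen(whisper_chunks, vae_encode_latents,
#                                              batch_size=8, device="cuda:0"):
#       ...  # run the UNet on each (audio, latent) pair of batches
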
def cast_training_params(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    dtype=torch.float32,
):
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)
def rand_log_normal(
    shape,
    loc=0.,
    scale=1.,
    device='cpu',
    dtype=torch.float32,
    generator=None,
):
    """Draws samples from a lognormal distribution."""
    rnd_normal = torch.randn(
        shape, device=device, dtype=dtype, generator=generator)  # N(0, I)
    sigma = (rnd_normal * scale + loc).exp()
    return sigma
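
# Example (sketch): draw one lognormal noise level per sample in a batch of 4;
# the loc/scale values here are purely illustrative.
#   sigmas = rand_log_normal((4,), loc=-1.2, scale=1.0, device="cuda:0")
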
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
    # Initialize lists to store the results for each image in the batch
    mouth_real_list = []
    mouth_generated_list = []

    # Process each image in the batch
    for b in range(frames.shape[0]):
        # Find the non-zero area in the face mask
        non_zero_indices = torch.nonzero(pixel_values_face_mask[b])

        # If there are no non-zero indices, skip this image
        if non_zero_indices.numel() == 0:
            continue

        min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(non_zero_indices[:, 1])
        min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(non_zero_indices[:, 2])

        # Crop the frames and image_pred according to the non-zero area
        frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
        image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]

        # Resize the cropped images to 256x256
        frames_resized = F.interpolate(frames_cropped.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
        image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)

        # Append the resized images to the result lists
        mouth_real_list.append(frames_resized)
        mouth_generated_list.append(image_pred_resized)

    # Convert the lists to tensors if they are not empty
    mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
    mouth_generated = torch.cat(mouth_generated_list, dim=0) if mouth_generated_list else None
    return mouth_real, mouth_generated
def get_image_pred(pixel_values,
                   ref_pixel_values,
                   audio_prompts,
                   vae,
                   net,
                   weight_dtype):
    """Run one inference pass: mask the lower half of the target frames,
    encode masked and reference frames with the VAE, predict latents with the
    network conditioned on the audio prompts, and decode back to images."""
    with torch.no_grad():
        bsz, num_frames, c, h, w = pixel_values.shape

        # Mask the lower half of every frame (the region to be generated)
        masked_pixel_values = pixel_values.clone()
        masked_pixel_values[:, :, :, h//2:, :] = -1
        masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
        masked_latents = vae.encode(masked_frames).latent_dist.mode()
        masked_latents = masked_latents * vae.config.scaling_factor
        masked_latents = masked_latents.float()

        ref_frames = rearrange(ref_pixel_values, 'b f c h w -> (b f) c h w')
        ref_latents = vae.encode(ref_frames).latent_dist.mode()
        ref_latents = ref_latents * vae.config.scaling_factor
        ref_latents = ref_latents.float()

        input_latents = torch.cat([masked_latents, ref_latents], dim=1)
        input_latents = input_latents.to(weight_dtype)
        timesteps = torch.tensor([0], device=input_latents.device)
        latents_pred = net(
            input_latents,
            timesteps,
            audio_prompts,
        )
        latents_pred = (1 / vae.config.scaling_factor) * latents_pred
        image_pred = vae.decode(latents_pred).sample
        image_pred = image_pred.float()
    return image_pred
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
    """Build per-frame audio prompts: encode the audio, stack the encoder's
    hidden states, zero-pad along time, then slice a fixed-length feature
    window for every video frame."""
    with torch.no_grad():
        audio_feature_length_per_frame = 2 * (
            cfg.data.audio_padding_length_left + cfg.data.audio_padding_length_right + 1)
        audio_feats = batch['audio_feature'].to(weight_dtype)
        audio_feats = wav2vec.encoder(audio_feats, output_hidden_states=True).hidden_states
        audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)  # [B, T, 10, 5, 384]

        start_ts = batch['audio_offset']
        step_ts = batch['audio_step']

        # zero-pad the feature sequence on both sides before slicing per-frame windows
        audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
                                 audio_feats,
                                 torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)

        audio_prompts = []
        for bb in range(bsz):
            audio_feats_list = []
            for f in range(num_frames):
                cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
                audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+audio_feature_length_per_frame]
                audio_feats_list.append(audio_clip)
            audio_feats_list = torch.stack(audio_feats_list, 1)
            audio_prompts.append(audio_feats_list)
        audio_prompts = torch.cat(audio_prompts)  # B, T, 10, 5, 384
    return audio_prompts
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
    save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")

    if total_limit is not None:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.endswith(".pth")]
        checkpoints = [d for d in checkpoints if name in d]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
                os.remove(removing_checkpoint)

    state_dict = model.state_dict()
    torch.save(state_dict, save_path)
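
# Usage sketch (names are illustrative): keep at most the 3 newest UNet checkpoints
#   save_checkpoint(unet, "./checkpoints", ckpt_num=global_step, name="unet",
#                   total_limit=3, logger=logger)
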
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
    unwrapped_net = accelerator.unwrap_model(net)
    save_checkpoint(
        unwrapped_net.unet,
        save_dir,
        global_step,
        name="unet",
        total_limit=cfg.total_limit,
        logger=logger,
    )
def delete_additional_ckpt(base_path, num_keep):
    dirs = []
    for d in os.listdir(base_path):
        if d.startswith("checkpoint-"):
            dirs.append(d)
    num_tot = len(dirs)
    if num_tot <= num_keep:
        return
    # ensure checkpoints are sorted by step and delete the earliest ones
    del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
    for d in del_dirs:
        path_to_dir = osp.join(base_path, d)
        if osp.exists(path_to_dir):
            shutil.rmtree(path_to_dir)
def seed_everything(seed):
    import random
    import numpy as np
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))
    random.seed(seed)
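
# Example (sketch): seed all RNGs once at the start of training for reproducibility
#   seed_everything(42)
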
def process_and_save_images(
    batch,
    image_pred,
    image_pred_infer,
    save_dir,
    global_step,
    accelerator,
    num_images_to_keep=10,
    syncnet_score=1,
):
    # Rearrange the tensors
    print("image_pred.shape: ", image_pred.shape)
    pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
    pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')

    # Create masked pixel values
    masked_pixel_values = batch["pixel_values_vid"].clone()
    _, _, _, h, _ = batch["pixel_values_vid"].shape
    masked_pixel_values[:, :, :, h//2:, :] = -1
    masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')

    # Keep only the specified number of images
    pixel_values = pixel_values[:num_images_to_keep, :, :, :]
    masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
    pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
    image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
    image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]

    # Stack the images vertically: masked input, reference, prediction, ground truth, inference prediction
    concat = torch.cat([
        masked_pixel_values * 0.5 + 0.5,
        pixel_values_ref_img * 0.5 + 0.5,
        image_pred * 0.5 + 0.5,
        pixel_values * 0.5 + 0.5,
        image_pred_infer * 0.5 + 0.5,
    ], dim=2)
    print("concat.shape: ", concat.shape)

    # Create the save directory if it doesn't exist
    os.makedirs(f'{save_dir}/samples/', exist_ok=True)

    # Try to save the concatenated image
    try:
        # Concatenate the samples horizontally and convert to a BGR numpy array in [0, 255]
        final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
        # Save the image
        cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
        print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
    except Exception as e:
        print(f"Failed to save image: {e}")