mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
initial_commit
This commit is contained in:
148
musetalk/models/vae.py
Executable file
148
musetalk/models/vae.py
Executable file
@@ -0,0 +1,148 @@
|
||||
from diffusers import AutoencoderKL
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn.functional as F
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
class VAE():
|
||||
"""
|
||||
VAE (Variational Autoencoder) class for image processing.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
|
||||
"""
|
||||
Initialize the VAE instance.
|
||||
|
||||
:param model_path: Path to the trained model.
|
||||
:param resized_img: The size to which images are resized.
|
||||
:param use_float16: Whether to use float16 precision.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.vae = AutoencoderKL.from_pretrained(self.model_path)
|
||||
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.vae.to(self.device)
|
||||
|
||||
if use_float16:
|
||||
self.vae = self.vae.half()
|
||||
self._use_float16 = True
|
||||
else:
|
||||
self._use_float16 = False
|
||||
|
||||
self.scaling_factor = self.vae.config.scaling_factor
|
||||
self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
self._resized_img = resized_img
|
||||
self._mask_tensor = self.get_mask_tensor()
|
||||
|
||||
def get_mask_tensor(self):
|
||||
"""
|
||||
Creates a mask tensor for image processing.
|
||||
:return: A mask tensor.
|
||||
"""
|
||||
mask_tensor = torch.zeros((self._resized_img,self._resized_img))
|
||||
mask_tensor[:self._resized_img//2,:] = 1
|
||||
mask_tensor[mask_tensor< 0.5] = 0
|
||||
mask_tensor[mask_tensor>= 0.5] = 1
|
||||
return mask_tensor
|
||||
|
||||
def preprocess_img(self,img_name,half_mask=False):
|
||||
"""
|
||||
Preprocess an image for the VAE.
|
||||
|
||||
:param img_name: The image file path or a list of image file paths.
|
||||
:param half_mask: Whether to apply a half mask to the image.
|
||||
:return: A preprocessed image tensor.
|
||||
"""
|
||||
window = []
|
||||
if isinstance(img_name, str):
|
||||
window_fnames = [img_name]
|
||||
for fname in window_fnames:
|
||||
img = cv2.imread(fname)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (self._resized_img, self._resized_img),
|
||||
interpolation=cv2.INTER_LANCZOS4)
|
||||
window.append(img)
|
||||
else:
|
||||
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
|
||||
window.append(img)
|
||||
|
||||
x = np.asarray(window) / 255.
|
||||
x = np.transpose(x, (3, 0, 1, 2))
|
||||
x = torch.squeeze(torch.FloatTensor(x))
|
||||
if half_mask:
|
||||
x = x * (self._mask_tensor>0.5)
|
||||
x = self.transform(x)
|
||||
|
||||
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
|
||||
x = x.to(self.vae.device)
|
||||
|
||||
return x
|
||||
|
||||
def encode_latents(self,image):
|
||||
"""
|
||||
Encode an image into latent variables.
|
||||
|
||||
:param image: The image tensor to encode.
|
||||
:return: The encoded latent variables.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
|
||||
init_latents = self.scaling_factor * init_latent_dist.sample()
|
||||
return init_latents
|
||||
|
||||
def decode_latents(self, latents):
|
||||
"""
|
||||
Decode latent variables back into an image.
|
||||
:param latents: The latent variables to decode.
|
||||
:return: A NumPy array representing the decoded image.
|
||||
"""
|
||||
latents = (1/ self.scaling_factor) * latents
|
||||
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
image = (image * 255).round().astype("uint8")
|
||||
image = image[...,::-1] # RGB to BGR
|
||||
return image
|
||||
|
||||
def get_latents_for_unet(self,img):
|
||||
"""
|
||||
Prepare latent variables for a U-Net model.
|
||||
:param img: The image to process.
|
||||
:return: A concatenated tensor of latents for U-Net input.
|
||||
"""
|
||||
|
||||
ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
|
||||
masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||
ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
|
||||
ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
return latent_model_input
|
||||
|
||||
if __name__ == "__main__":
|
||||
vae_mode_path = "./models/sd-vae-ft-mse/"
|
||||
vae = VAE(model_path = vae_mode_path,use_float16=False)
|
||||
img_path = "./results/sun001_crop/00000.png"
|
||||
|
||||
crop_imgs_path = "./results/sun001_crop/"
|
||||
latents_out_path = "./results/latents/"
|
||||
if not os.path.exists(latents_out_path):
|
||||
os.mkdir(latents_out_path)
|
||||
|
||||
files = os.listdir(crop_imgs_path)
|
||||
files.sort()
|
||||
files = [file for file in files if file.split(".")[-1] == "png"]
|
||||
|
||||
for file in files:
|
||||
index = file.split(".")[0]
|
||||
img_path = crop_imgs_path + file
|
||||
latents = vae.get_latents_for_unet(img_path)
|
||||
print(img_path,"latents",latents.size())
|
||||
#torch.save(latents,os.path.join(latents_out_path,index+".pt"))
|
||||
#reload_tensor = torch.load('tensor.pt')
|
||||
#print(reload_tensor.size())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user