# Mirror of https://github.com/TMElyralab/MuseTalk.git
# (synced 2026-02-05 01:49:20 +08:00) — 118 lines, 4.6 KiB, Python, executable file.
import torch
|
|
import time
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
from .model import BiSeNet
|
|
import torchvision.transforms as transforms
|
|
|
|
class FaceParsing:
    """BiSeNet-based face parser that turns a face image into a binary mask.

    Calling an instance returns a mode-"L" PIL image whose pixels are 255 on the
    selected parsing labels and 0 elsewhere. Which labels are selected depends
    on ``mode`` (see ``__call__``). In ``"jaw"`` mode the label-1 region is
    additionally dilated with ``self.kernel`` and eroded with
    ``self.cheek_kernel`` inside the cheek bands of the cheek mask.
    """

    def __init__(self, left_cheek_width=80, right_cheek_width=80):
        """Load the network and precompute morphology kernels and the cheek mask.

        Args:
            left_cheek_width: horizontal distance (pixels) from the mask centre
                to the inner edge of the left cheek band.
            right_cheek_width: same, for the right cheek band.
        """
        self.net = self.model_init()
        self.preprocess = self.image_preprocess()

        # Structuring element used to dilate the face region in "jaw" mode.
        self.kernel = self._build_jaw_kernel(cone_height=21, tail_height=12)

        # Flat ellipse (35 px wide, 3 px tall) used to erode the cheek bands.
        self.cheek_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (35, 3))

        # Keep the widths so the cheek mask can be rebuilt for other sizes.
        self.left_cheek_width = left_cheek_width
        self.right_cheek_width = right_cheek_width
        # Default 512x512 cheek mask: 255 on both side bands, 0 in the centre.
        self.cheek_mask = self._create_cheek_mask(
            left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)

    @staticmethod
    def _build_jaw_kernel(cone_height=21, tail_height=12):
        """Build the square uint8 structuring element used for "jaw"-mode dilation.

        The kernel is ``cone_height + tail_height`` pixels on a side:
        - rows above ``cone_height // 2`` are empty;
        - the remaining cone rows form a triangle centred horizontally, starting
          at one pixel and widening by two pixels per row;
        - the final ``tail_height`` rows repeat the widest cone row (or a
          single centre pixel when ``cone_height`` is 0).
        """
        total_size = cone_height + tail_height
        kernel = np.zeros((total_size, total_size), dtype=np.uint8)
        center_x = total_size // 2

        half_cone = cone_height // 2
        for row in range(half_cone, cone_height):
            half_width = row - half_cone  # row width is 2 * half_width + 1
            kernel[row, center_x - half_width:center_x + half_width + 1] = 1

        base_width = int(kernel[cone_height - 1].sum()) if cone_height > 0 else 1
        tail_start = max(0, center_x - base_width // 2)
        tail_end = min(total_size, center_x + base_width // 2 + 1)
        kernel[cone_height:, tail_start:tail_end] = 1
        return kernel

    def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80, size=(512, 512)):
        """Create a uint8 cheek-area mask: 255 on both side bands, 0 in between.

        Args:
            left_cheek_width: the left band spans x in [0, centre - left_cheek_width].
            right_cheek_width: the right band starts at x = centre + right_cheek_width.
            size: (width, height) of the mask; defaults to 512x512.
        """
        width, height = size
        mask = np.zeros((height, width), dtype=np.uint8)
        center = width // 2
        cv2.rectangle(mask, (0, 0), (center - left_cheek_width, height), 255, -1)  # Left cheek
        cv2.rectangle(mask, (center + right_cheek_width, 0), (width, height), 255, -1)  # Right cheek
        return mask

    def _cheek_mask_for(self, shape):
        """Return a cheek mask matching a (height, width) array shape.

        Reuses the precomputed mask when the shape matches. Otherwise builds one
        with the same pixel widths; they are not rescaled, just as the
        morphology kernels are not rescaled.
        """
        if self.cheek_mask.shape == tuple(shape):
            return self.cheek_mask
        return self._create_cheek_mask(
            left_cheek_width=self.left_cheek_width,
            right_cheek_width=self.right_cheek_width,
            size=(shape[1], shape[0]))

    def model_init(self,
                   resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
                   model_pth='./models/face-parse-bisent/79999_iter.pth'):
        """Build BiSeNet, load its weights, move it to GPU if available, set eval mode.

        NOTE: ``torch.load`` unpickles the checkpoint; only load trusted files.
        """
        net = BiSeNet(resnet_path)
        if torch.cuda.is_available():
            net.cuda()
            net.load_state_dict(torch.load(model_pth))
        else:
            net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
        net.eval()
        return net

    def image_preprocess(self):
        """Return the input transform: to tensor, then per-channel mean/std normalisation."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __call__(self, image, size=(512, 512), mode="raw"):
        """Parse a face image and return a binary (0/255) mask as a PIL image.

        Label ids presumably follow the 19-class scheme of this BiSeNet
        checkpoint (the neck-mode note names 14 = neck, 10 = nose); confirm the
        meaning of 1 and 11-13 against the model.

        Args:
            image: a PIL image, or a path (str / os.PathLike) to one. It is
                converted to RGB so grayscale, RGBA and palette inputs work
                with the 3-channel normalisation.
            size: (width, height) the image is resized to before parsing.
            mode: "neck" selects labels 1, 11-14; "jaw" selects a morphologically
                grown label-1 region (excluding label 10) plus 11-13; any other
                value selects labels 1, 11-13.

        Returns:
            A mode-"L" PIL image with 255 on selected pixels and 0 elsewhere.
        """
        if isinstance(image, (str, os.PathLike)):
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        else:
            image = image.convert("RGB")

        with torch.no_grad():
            image = image.resize(size, Image.BILINEAR)
            img = self.preprocess(image).unsqueeze(0)
            if torch.cuda.is_available():
                img = img.cuda()
            out = self.net(img)[0]
            # Assumes out is (1, num_classes, H, W); argmax over classes -> (H, W).
            parsing = out.squeeze(0).cpu().numpy().argmax(0)

        if mode == "neck":
            # Add 14:neck, remove 10:nose and 7:8:9
            parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
        elif mode == "jaw":
            face_region = (parsing == 1).astype(np.uint8) * 255
            original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
            eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
            # Build the cheek mask from the actual parsed shape, so sizes other
            # than 512x512 work.
            cheek_mask = self._cheek_mask_for(face_region.shape)
            # Eroded region inside the cheek bands, plain dilation elsewhere.
            face_region = cv2.bitwise_or(
                cv2.bitwise_and(eroded, cheek_mask),
                cv2.bitwise_and(original_dilated, ~cheek_mask))
            parsing[(face_region == 255) & (parsing != 10)] = 255
            parsing[np.isin(parsing, [11, 12, 13])] = 255
        else:
            parsing[np.isin(parsing, [1, 11, 12, 13])] = 255

        # argmax labels never equal 255, so only the selected pixels survive.
        parsing[parsing != 255] = 0
        return Image.fromarray(parsing.astype(np.uint8))
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: parse a sample image and write its mask to res.png.
    face_parser = FaceParsing()
    mask = face_parser('154_small.png')
    mask.save('res.png')
|
|
|