This commit is contained in:
aidenyzhang
2025-03-28 16:03:02 +08:00
parent 058f7ddc7f
commit db204311a5
46 changed files with 729 additions and 204 deletions

View File

@@ -8,9 +8,53 @@ from .model import BiSeNet
import torchvision.transforms as transforms
class FaceParsing():
def __init__(self):
def __init__(self, left_cheek_width=80, right_cheek_width=80):
self.net = self.model_init()
self.preprocess = self.image_preprocess()
# Ensure all size parameters are integers
cone_height = 21
tail_height = 12
total_size = cone_height + tail_height
# Create kernel with explicit integer dimensions
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
center_x = total_size // 2 # Ensure center coordinates are integers
# Cone part
for row in range(cone_height):
if row < cone_height//2:
continue
width = int(2 * (row - cone_height//2) + 1)
start = int(center_x - (width // 2))
end = int(center_x + (width // 2) + 1)
kernel[row, start:end] = 1
# Vertical extension part
if cone_height > 0:
base_width = int(kernel[cone_height-1].sum())
else:
base_width = 1
for row in range(cone_height, total_size):
start = max(0, int(center_x - (base_width//2)))
end = min(total_size, int(center_x + (base_width//2) + 1))
kernel[row, start:end] = 1
self.kernel = kernel
# Modify cheek erosion kernel to be flatter ellipse
self.cheek_kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (35, 3))
# Add cheek area mask (protect chin area)
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
"""Create cheek area mask (1/4 area on both sides)"""
mask = np.zeros((512, 512), dtype=np.uint8)
center = 512 // 2
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
return mask
def model_init(self,
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
@@ -30,7 +74,7 @@ class FaceParsing():
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def __call__(self, image, size=(512, 512)):
def __call__(self, image, size=(512, 512), mode="jaw"):
if isinstance(image, str):
image = Image.open(image)
@@ -44,8 +88,25 @@ class FaceParsing():
img = torch.unsqueeze(img, 0)
out = self.net(img)[0]
parsing = out.squeeze(0).cpu().numpy().argmax(0)
parsing[np.where(parsing>13)] = 0
parsing[np.where(parsing>=1)] = 255
# Add 14:neck, remove 10:nose and 7:8:9
if mode == "neck":
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
parsing[np.where(parsing!=255)] = 0
elif mode == "jaw":
face_region = np.isin(parsing, [1])*255
face_region = face_region.astype(np.uint8)
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
parsing[np.isin(parsing, [11, 12, 13])] = 255
parsing[np.where(parsing!=255)] = 0
else:
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
parsing[np.where(parsing!=255)] = 0
parsing = Image.fromarray(parsing.astype(np.uint8))
return parsing