# MuseTalk/musetalk/loss/vgg_face.py
'''
This module provides a pretrained VGG-Face model, plus the VGG-19 and
image-pyramid utilities used for perceptual loss.
ref link: https://github.com/prlz77/vgg-face.pytorch
'''
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.model_zoo
from torch import nn
from torchvision import models

from musetalk.loss import resnet as ResNet

MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
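# NOTE: VGG_FACE_PATH points to the VGGFace2 ResNet-50 weights pickle
# (resnet50_ft_weight.pkl). The absolute path above is environment-specific
# and will need to be adjusted for your own setup.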
# It was 93.5940, 104.7624, 129.1863 before dividing by 255
MEAN_RGB = [
0.367035294117647,
0.41083294117647057,
0.5066129411764705
]
def load_state_dict(model, fname):
"""
Set parameters converted from Caffe models authors of VGGFace2 provide.
See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
Arguments:
model: model
fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
"""
with open(fname, 'rb') as f:
weights = pickle.load(f, encoding='latin1')
own_state = model.state_dict()
for name, param in weights.items():
if name in own_state:
            try:
                own_state[name].copy_(torch.from_numpy(param))
            except Exception as e:
                raise RuntimeError(
                    'Error copying parameter {}: model shape is {}, '
                    'checkpoint shape is {}.'.format(
                        name, tuple(own_state[name].size()), tuple(param.shape))) from e
else:
raise KeyError('unexpected key "{}" in state_dict'.format(name))
def vggface2(pretrained=True):
    """VGGFace2 face recognition model (ResNet-50 backbone, 8631 identities)."""
    vggface = ResNet.resnet50(num_classes=8631, include_top=True)
    if pretrained:
        load_state_dict(vggface, VGG_FACE_PATH)
    return vggface
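# Example usage of vggface2 (a minimal sketch; `face` is a hypothetical
# Bx3x224x224 RGB batch already preprocessed the way the VGGFace2
# checkpoint expects):
#   id_net = vggface2(pretrained=True).eval()
#   with torch.no_grad():
#       id_logits = id_net(face)   # Bx8631 identity scores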
def vggface(pretrained=False, **kwargs):
"""VGGFace model.
Args:
pretrained (bool): If True, returns pre-trained model
"""
model = VggFace(**kwargs)
if pretrained:
state = torch.utils.model_zoo.load_url(MODEL_URL)
model.load_state_dict(state)
return model
class VggFace(torch.nn.Module):
def __init__(self, classes=2622):
"""VGGFace model.
Face recognition network. It takes as input a Bx3x224x224
batch of face images and gives as output a BxC score vector
(C is the number of identities).
Input images need to be scaled in the 0-1 range and then
normalized with respect to the mean RGB used during training.
Args:
classes (int): number of identities recognized by the
network
"""
super().__init__()
self.conv1 = _ConvBlock(3, 64, 64)
self.conv2 = _ConvBlock(64, 128, 128)
self.conv3 = _ConvBlock(128, 256, 256, 256)
self.conv4 = _ConvBlock(256, 512, 512, 512)
self.conv5 = _ConvBlock(512, 512, 512, 512)
self.dropout = torch.nn.Dropout(0.5)
self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
self.fc2 = torch.nn.Linear(4096, 4096)
self.fc3 = torch.nn.Linear(4096, classes)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = x.view(x.size(0), -1)
x = self.dropout(F.relu(self.fc1(x)))
x = self.dropout(F.relu(self.fc2(x)))
x = self.fc3(x)
return x
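# Example usage of VggFace (a minimal sketch following the docstring above;
# `imgs` is a hypothetical Bx3x224x224 float tensor scaled to [0, 1]):
#   model = vggface(pretrained=True).eval()
#   mean = torch.tensor(MEAN_RGB).view(1, 3, 1, 1)
#   with torch.no_grad():
#       scores = model(imgs - mean)   # BxC identity scores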
class _ConvBlock(torch.nn.Module):
"""A Convolutional block."""
def __init__(self, *units):
"""Create a block with len(units) - 1 convolutions.
convolution number i transforms the number of channels from
units[i - 1] to units[i] channels.
"""
super().__init__()
self.convs = torch.nn.ModuleList([
torch.nn.Conv2d(in_, out, 3, 1, 1)
for in_, out in zip(units[:-1], units[1:])
])
def forward(self, x):
# Each convolution is followed by a ReLU, then the block is
# concluded by a max pooling.
for c in self.convs:
x = F.relu(c(x))
return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
class Vgg19(torch.nn.Module):
"""
Vgg19 network for perceptual loss.
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
requires_grad=False)
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
requires_grad=False)
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
X = (X - self.mean) / self.std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
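# Example perceptual-loss usage of Vgg19 (a minimal sketch, not necessarily
# the exact training code; `pred` and `target` are hypothetical Bx3xHxW
# tensors in [0, 1]):
#   vgg = Vgg19().eval()
#   feats_pred, feats_gt = vgg(pred), vgg(target)
#   loss = sum(torch.abs(p - g).mean() for p, g in zip(feats_pred, feats_gt))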
class AntiAliasInterpolation2d(nn.Module):
"""
Band-limited downsampling, for better preservation of the input signal.
"""
def __init__(self, channels, scale):
super(AntiAliasInterpolation2d, self).__init__()
sigma = (1 / scale - 1) / 2
kernel_size = 2 * round(sigma * 4) + 1
self.ka = kernel_size // 2
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
kernel_size = [kernel_size, kernel_size]
sigma = [sigma, sigma]
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
self.scale = scale
inv_scale = 1 / scale
self.int_inv_scale = int(inv_scale)
def forward(self, input):
if self.scale == 1.0:
return input
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
out = F.conv2d(out, weight=self.weight, groups=self.groups)
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
return out
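# Worked example (a sketch): for scale = 0.25, sigma = (1 / 0.25 - 1) / 2 = 1.5,
# kernel_size = 2 * round(1.5 * 4) + 1 = 13, and the blurred input is then
# subsampled every int(1 / 0.25) = 4 pixels:
#   down = AntiAliasInterpolation2d(channels=3, scale=0.25)
#   small = down(img)   # img: Bx3xHxW -> small: Bx3x(H/4)x(W/4) for H, W divisible by 4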
class ImagePyramide(torch.nn.Module):
"""
    Create an image pyramid for computing the pyramid perceptual loss.
"""
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
downs = {}
for scale in scales:
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
self.downs = nn.ModuleDict(downs)
def forward(self, x):
out_dict = {}
for scale, down_module in self.downs.items():
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
return out_dict
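# Example pyramid perceptual loss combining ImagePyramide and Vgg19 (a minimal
# sketch; the scale list is illustrative, not the repository's training config,
# and `pred`/`target` are hypothetical Bx3xHxW tensors in [0, 1] as above):
#   scales = [1, 0.5, 0.25]
#   pyramid = ImagePyramide(scales, num_channels=3)
#   vgg = Vgg19().eval()
#   pyr_pred, pyr_gt = pyramid(pred), pyramid(target)
#   loss = 0.0
#   for scale in scales:
#       key = 'prediction_' + str(scale)
#       feats_p, feats_g = vgg(pyr_pred[key]), vgg(pyr_gt[key])
#       loss += sum(torch.abs(p - g).mean() for p, g in zip(feats_p, feats_g))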