Mirror of https://github.com/TMElyralab/MuseTalk.git (synced 2026-02-05 18:09:19 +08:00)
feat: data preprocessing and training (#294)
* docs: update readme
* docs: update readme
* feat: training codes
* feat: data preprocess
* docs: release training
musetalk/loss/vgg_face.py (executable file, 237 lines added)
@@ -0,0 +1,237 @@
'''
This part of the code contains a pretrained vgg_face model.
ref link: https://github.com/prlz77/vgg-face.pytorch
'''
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.model_zoo
from torch import nn
from torchvision import models

from musetalk.loss import resnet as ResNet

MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'

# It was 93.5940, 104.7624, 129.1863 before dividing by 255
MEAN_RGB = [
    0.367035294117647,
    0.41083294117647057,
    0.5066129411764705
]

def load_state_dict(model, fname):
    """
    Set parameters converted from the Caffe models the authors of VGGFace2 provide.
    See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.

    Arguments:
        model: model
        fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
    """
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')

    own_state = model.state_dict()
    for name, param in weights.items():
        if name in own_state:
            try:
                own_state[name].copy_(torch.from_numpy(param))
            except Exception:
                raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '
                                   'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
        else:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))


def vggface2(pretrained=True):
    vggface = ResNet.resnet50(num_classes=8631, include_top=True)
    if pretrained:
        load_state_dict(vggface, VGG_FACE_PATH)
    return vggface


def vggface(pretrained=False, **kwargs):
    """VGGFace model.

    Args:
        pretrained (bool): If True, returns pre-trained model
    """
    model = VggFace(**kwargs)
    if pretrained:
        state = torch.utils.model_zoo.load_url(MODEL_URL)
        model.load_state_dict(state)
    return model

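As a usage sketch (not part of the diff): both constructors build feature extractors that are typically frozen and used only inside a loss term. The snippet below assumes the resnet50_ft_weight.pkl checkpoint pointed to by VGG_FACE_PATH exists locally.

# Usage sketch: a frozen VGGFace2 backbone for an identity/feature loss.
# Assumes the checkpoint at VGG_FACE_PATH has been downloaded locally.
face_net = vggface2(pretrained=True)
face_net.eval()
for p in face_net.parameters():
    p.requires_grad = False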

class VggFace(torch.nn.Module):
    def __init__(self, classes=2622):
        """VGGFace model.

        Face recognition network. It takes as input a Bx3x224x224
        batch of face images and gives as output a BxC score vector
        (C is the number of identities).
        Input images need to be scaled in the 0-1 range and then
        normalized with respect to the mean RGB used during training.

        Args:
            classes (int): number of identities recognized by the
                network
        """
        super().__init__()
        self.conv1 = _ConvBlock(3, 64, 64)
        self.conv2 = _ConvBlock(64, 128, 128)
        self.conv3 = _ConvBlock(128, 256, 256, 256)
        self.conv4 = _ConvBlock(256, 512, 512, 512)
        self.conv5 = _ConvBlock(512, 512, 512, 512)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
        self.fc2 = torch.nn.Linear(4096, 4096)
        self.fc3 = torch.nn.Linear(4096, classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

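A minimal sketch of the input contract described in the docstring, with a random tensor standing in for real 224x224 face crops:

# Sketch: scale images to 0-1, subtract the training-set RGB mean,
# then run the network to get BxC per-identity scores.
imgs = torch.rand(4, 3, 224, 224)  # stand-in for a batch of face crops
mean = torch.tensor(MEAN_RGB).view(1, 3, 1, 1)
scores = VggFace(classes=2622)(imgs - mean)
print(scores.shape)  # torch.Size([4, 2622])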

class _ConvBlock(torch.nn.Module):
    """A Convolutional block."""

    def __init__(self, *units):
        """Create a block with len(units) - 1 convolutions.

        Convolution number i transforms the number of channels from
        units[i - 1] to units[i] channels.
        """
        super().__init__()
        self.convs = torch.nn.ModuleList([
            torch.nn.Conv2d(in_, out, 3, 1, 1)
            for in_, out in zip(units[:-1], units[1:])
        ])

    def forward(self, x):
        # Each convolution is followed by a ReLU, then the block is
        # concluded by a max pooling.
        for c in self.convs:
            x = F.relu(c(x))
        return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)

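To make the channel bookkeeping concrete, a quick sketch: _ConvBlock(3, 64, 64) holds two 3x3 convolutions and halves the spatial resolution.

# Sketch: two convolutions (3->64, 64->64), each with ReLU, then 2x2 max pool.
block = _ConvBlock(3, 64, 64)
print(block(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])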

class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

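For reference, a sketch of how the five returned feature maps typically feed a perceptual loss; the L1 distance and equal per-level weights here are illustrative assumptions, not taken from this file.

# Sketch: single-scale perceptual loss over the five VGG19 slices.
# Equal weights and L1 distance are assumptions for illustration.
vgg = Vgg19().eval()
def perceptual_loss(pred, target):
    return sum(torch.mean(torch.abs(fp - ft.detach()))
               for fp, ft in zip(vgg(pred), vgg(target)))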

class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]

        return out

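A quick sketch of the module in action: with scale=0.25 the depthwise kernel is a 13x13 Gaussian (sigma = 1.5), and the blurred output is subsampled by a factor of 4.

# Sketch: blur with a 13x13 Gaussian, then take every 4th pixel.
down = AntiAliasInterpolation2d(channels=3, scale=0.25)
print(down(torch.rand(2, 3, 256, 256)).shape)  # torch.Size([2, 3, 64, 64])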

class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
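Putting the pieces together, a sketch of the pyramid perceptual loss these modules are built for; the scale list and the equal weighting are assumptions for illustration.

# Sketch: pyramid perceptual loss built from ImagePyramide and Vgg19.
# The scales and the equal per-scale weighting are illustrative assumptions.
scales = [1, 0.5, 0.25]
pyramid = ImagePyramide(scales, num_channels=3)
vgg = Vgg19().eval()

def pyramid_perceptual_loss(pred, target):
    pyr_p, pyr_t = pyramid(pred), pyramid(target)
    loss = 0
    for scale in scales:
        key = 'prediction_' + str(scale)
        for fp, ft in zip(vgg(pyr_p[key]), vgg(pyr_t[key])):
            loss = loss + torch.mean(torch.abs(fp - ft.detach()))
    return loss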