mirror of https://github.com/TMElyralab/MuseTalk.git
'''
This part of the code contains a pretrained VGG-Face model.

ref link: https://github.com/prlz77/vgg-face.pytorch
'''
import pickle

import torch
import torch.nn.functional as F
import torch.utils.model_zoo

from musetalk.loss import resnet as ResNet

MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'

# Per-channel mean in the 0-1 range.
# It was 93.5940, 104.7624, 129.1863 before dividing by 255.
MEAN_RGB = [
    0.367035294117647,
    0.41083294117647057,
    0.5066129411764705
]


def load_state_dict(model, fname):
    """
    Set parameters converted from the Caffe models that the authors of VGGFace2 provide.
    See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.

    Arguments:
        model: model
        fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
    """
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')

    own_state = model.state_dict()
    for name, param in weights.items():
        if name in own_state:
            try:
                own_state[name].copy_(torch.from_numpy(param))
            except Exception:
                raise RuntimeError(
                    'While copying the parameter named {}: its dimensions in the model are {} '
                    'but its dimensions in the checkpoint are {}.'.format(
                        name, own_state[name].size(), param.shape))
        else:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))


def vggface2(pretrained=True):
    """ResNet-50 trained on VGGFace2 (8631 identities)."""
    vggface = ResNet.resnet50(num_classes=8631, include_top=True)
    if pretrained:
        load_state_dict(vggface, VGG_FACE_PATH)
    return vggface
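

# Usage sketch (illustrative, not part of the original repo): the VGGFace2
# backbone is typically used as a frozen identity-feature extractor.
# VGG_FACE_PATH above is a cluster-specific path; it must point at a local
# copy of the converted resnet50_ft_weight.pkl for `pretrained=True` to work.
def _demo_vggface2_logits(images):
    net = vggface2(pretrained=True)
    net.eval()
    with torch.no_grad():
        return net(images)  # Bx8631 identity logits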


def vggface(pretrained=False, **kwargs):
    """VGGFace model.

    Args:
        pretrained (bool): If True, returns a pre-trained model.
    """
    model = VggFace(**kwargs)
    if pretrained:
        state = torch.utils.model_zoo.load_url(MODEL_URL)
        model.load_state_dict(state)
    return model


class VggFace(torch.nn.Module):
    def __init__(self, classes=2622):
        """VGGFace model.

        Face recognition network. It takes as input a Bx3x224x224
        batch of face images and gives as output a BxC score vector
        (C is the number of identities).
        Input images need to be scaled to the 0-1 range and then
        normalized with respect to the mean RGB used during training.

        Args:
            classes (int): number of identities recognized by the
                network
        """
        super().__init__()
        self.conv1 = _ConvBlock(3, 64, 64)
        self.conv2 = _ConvBlock(64, 128, 128)
        self.conv3 = _ConvBlock(128, 256, 256, 256)
        self.conv4 = _ConvBlock(256, 512, 512, 512)
        self.conv5 = _ConvBlock(512, 512, 512, 512)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
        self.fc2 = torch.nn.Linear(4096, 4096)
        self.fc3 = torch.nn.Linear(4096, classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x
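

# Minimal usage sketch (illustrative): score a batch of faces after the 0-1
# scaling and MEAN_RGB normalization required by the docstring above. The
# function name and the random batch are hypothetical.
def _demo_vggface_scores():
    model = vggface(pretrained=False)
    model.eval()
    faces = torch.rand(4, 3, 224, 224)  # images already scaled to 0-1
    mean = torch.tensor(MEAN_RGB).view(1, 3, 1, 1)  # broadcast over B, H, W
    with torch.no_grad():
        scores = model(faces - mean)  # Bx2622 identity scores
    return scores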


class _ConvBlock(torch.nn.Module):
    """A convolutional block."""

    def __init__(self, *units):
        """Create a block with len(units) - 1 convolutions.

        Convolution number i transforms the number of channels from
        units[i - 1] to units[i] channels.
        """
        super().__init__()
        self.convs = torch.nn.ModuleList([
            torch.nn.Conv2d(in_, out, 3, 1, 1)
            for in_, out in zip(units[:-1], units[1:])
        ])

    def forward(self, x):
        # Each convolution is followed by a ReLU, then the block is
        # concluded by a max pooling.
        for c in self.convs:
            x = F.relu(c(x))
        return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
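

# Shape sketch (illustrative): _ConvBlock(3, 64, 64) builds two 3x3
# convolutions, 3->64 and 64->64, each followed by ReLU, and the closing 2x2
# max pooling halves the spatial size once per block.
def _demo_conv_block():
    block = _ConvBlock(3, 64, 64)
    x = torch.rand(1, 3, 224, 224)
    return block(x).shape  # torch.Size([1, 64, 112, 112])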


import numpy as np
from torchvision import models


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # Split the VGG-19 feature stack so that each slice ends at one of the
        # first five block activations (relu1_1, relu2_1, ..., relu5_1).
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        # ImageNet normalization constants, stored as frozen parameters.
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
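

# Perceptual-loss sketch (illustrative): compare the five ReLU feature maps of
# a prediction and a target. Equal layer weights are an assumption here, not a
# value taken from the MuseTalk training configuration.
def _demo_perceptual_loss(pred, target, vgg=None):
    vgg = vgg if vgg is not None else Vgg19()
    loss = 0.0
    for fp, ft in zip(vgg(pred), vgg(target)):
        loss = loss + F.l1_loss(fp, ft.detach())
    return loss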


from torch import nn


class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        # Gaussian sigma and kernel size matched to the downsampling factor.
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight.
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)

    def forward(self, input):
        if self.scale == 1.0:
            return input

        # Blur with the Gaussian kernel, then subsample by striding.
        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]

        return out
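

# Downsampling sketch (illustrative): a 2x band-limited reduction. With
# scale=0.5, the module blurs with a 5x5 Gaussian (sigma=0.5) before taking
# every second pixel, avoiding the aliasing of plain subsampling.
def _demo_antialias():
    down = AntiAliasInterpolation2d(channels=1, scale=0.5)
    x = torch.rand(1, 1, 64, 64)
    return down(x).shape  # torch.Size([1, 1, 32, 32])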


class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            # Module names cannot contain '.', so e.g. '0.5' is stored as '0-5'.
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
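

# End-to-end sketch (illustrative): build a pyramid at a few scales and feed
# each level to the perceptual loss above. The scale list is an example, not a
# value taken from the MuseTalk training configs.
def _demo_pyramid_perceptual_loss(pred, target, vgg=None):
    vgg = vgg if vgg is not None else Vgg19()
    pyramid = ImagePyramide(scales=[1, 0.5, 0.25], num_channels=3)
    pred_pyr, target_pyr = pyramid(pred), pyramid(target)
    loss = 0.0
    for key in pred_pyr:  # keys like 'prediction_0.5'
        loss = loss + _demo_perceptual_loss(pred_pyr[key], target_pyr[key], vgg)
    return loss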