'''This module contains a pretrained VGG-Face model.
ref link: https://github.com/prlz77/vgg-face.pytorch
'''
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.model_zoo
from torch import nn
from torchvision import models

from musetalk.loss import resnet as ResNet

MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'

# It was 93.5940, 104.7624, 129.1863 before dividing by 255
MEAN_RGB = [
    0.367035294117647, 0.41083294117647057, 0.5066129411764705
]


def load_state_dict(model, fname):
    """
    Set parameters converted from the Caffe models the authors of VGGFace2 provide.
    See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.

    Arguments:
        model: the model to load the parameters into.
        fname: file name of the parameters converted from a Caffe model,
            assuming the file format is Pickle.
    """
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')

    own_state = model.state_dict()
    for name, param in weights.items():
        if name in own_state:
            try:
                own_state[name].copy_(torch.from_numpy(param))
            except Exception:
                raise RuntimeError(
                    'While copying the parameter named {}, the dimensions in the model are {} '
                    'and the dimensions in the checkpoint are {}.'.format(
                        name, own_state[name].size(), param.shape))
        else:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))


def vggface2(pretrained=True):
    """VGGFace2 model (ResNet-50 trained on VGGFace2, 8631 identities).

    Args:
        pretrained (bool): If True, loads the converted Caffe weights from VGG_FACE_PATH.
    """
    vggface = ResNet.resnet50(num_classes=8631, include_top=True)
    if pretrained:
        load_state_dict(vggface, VGG_FACE_PATH)
    return vggface


def vggface(pretrained=False, **kwargs):
    """VGGFace model.

    Args:
        pretrained (bool): If True, returns a pre-trained model.
    """
    model = VggFace(**kwargs)
    if pretrained:
        state = torch.utils.model_zoo.load_url(MODEL_URL)
        model.load_state_dict(state)
    return model


class VggFace(torch.nn.Module):
    def __init__(self, classes=2622):
        """VGGFace model.

        Face recognition network. It takes as input a Bx3x224x224
        batch of face images and gives as output a BxC score vector
        (C is the number of identities). Input images need to be
        scaled in the 0-1 range and then normalized with respect
        to the mean RGB used during training.

        Args:
            classes (int): number of identities recognized by the network.
        """
        super().__init__()
        self.conv1 = _ConvBlock(3, 64, 64)
        self.conv2 = _ConvBlock(64, 128, 128)
        self.conv3 = _ConvBlock(128, 256, 256, 256)
        self.conv4 = _ConvBlock(256, 512, 512, 512)
        self.conv5 = _ConvBlock(512, 512, 512, 512)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
        self.fc2 = torch.nn.Linear(4096, 4096)
        self.fc3 = torch.nn.Linear(4096, classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


class _ConvBlock(torch.nn.Module):
    """A convolutional block."""

    def __init__(self, *units):
        """Create a block with len(units) - 1 convolutions.

        Convolution number i transforms the number of channels from
        units[i - 1] to units[i] channels.
        """
        super().__init__()
        self.convs = torch.nn.ModuleList([
            torch.nn.Conv2d(in_, out, 3, 1, 1)
            for in_, out in zip(units[:-1], units[1:])
        ])

    def forward(self, x):
        # Each convolution is followed by a ReLU, then the block is
        # concluded by a max pooling.
        for c in self.convs:
            x = F.relu(c(x))
        return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
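# Usage note (an illustrative sketch, commented out so the module stays
# import-only): the VggFace docstring above asks for inputs in [0, 1],
# normalized with the training mean. Assuming `face_batch` is a Bx3x224x224
# RGB tensor already scaled to [0, 1] (the name is a placeholder, not part of
# this file), the preprocessing and inference could look like:
#
#     model = vggface(pretrained=True).eval()         # downloads MODEL_URL
#     mean = torch.tensor(MEAN_RGB).view(1, 3, 1, 1)
#     with torch.no_grad():
#         identity_scores = model(face_batch - mean)  # Bx2622 identity logits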
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        # The five slices end right after relu1_1, relu2_1, relu3_1, relu4_1
        # and relu5_1 of VGG-19; their activations are returned as the
        # perceptual features.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        # ImageNet normalization statistics, stored as frozen parameters.
        self.mean = torch.nn.Parameter(
            data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
            requires_grad=False)
        self.std = torch.nn.Parameter(
            data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
            requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """

    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight.
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss.
    """

    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
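# Illustrative smoke test, a minimal sketch rather than part of the original
# training code. It only exercises the randomly initialised networks defined
# above and checks tensor shapes; batch sizes, resolutions and scales are
# assumptions chosen for the example, and nothing here needs the pretrained
# checkpoints referenced by MODEL_URL / VGG_FACE_PATH.
if __name__ == "__main__":
    with torch.no_grad():
        # Identity logits from a randomly initialised VGG-Face.
        face_model = vggface(pretrained=False).eval()
        face_batch = torch.rand(2, 3, 224, 224)         # RGB images in [0, 1]
        mean = torch.tensor(MEAN_RGB).view(1, 3, 1, 1)
        scores = face_model(face_batch - mean)
        print('vggface scores:', tuple(scores.shape))   # (2, 2622)

        # Band-limited image pyramid at three scales.
        pyramid = ImagePyramide(scales=[1, 0.5, 0.25], num_channels=3)
        frames = torch.rand(2, 3, 256, 256)
        for key, value in pyramid(frames).items():
            print(key, tuple(value.shape))              # e.g. prediction_0.5 (2, 3, 128, 128)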