mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 01:49:20 +08:00
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
145 lines
5.0 KiB
Python
Executable File
145 lines
5.0 KiB
Python
Executable File
from torch import nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
from musetalk.loss.vgg_face import ImagePyramide
|
|
|
|
class DownBlock2d(nn.Module):
    """
    Encoder-side block: convolution, optional instance normalisation,
    leaky ReLU (negative slope 0.2) and optional 2x2 average pooling.

    Args:
        in_features: number of input channels of the convolution.
        out_features: number of output channels of the convolution.
        norm: when true, apply an affine ``InstanceNorm2d`` after the convolution.
        kernel_size: kernel size of the convolution (no padding is configured).
        pool: when true, finish with a 2x2 average pooling.
        sn: when true, wrap the convolution in spectral normalisation.
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super().__init__()
        conv_layer = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
        # spectral_norm modifies the layer in place and hands it back, so the
        # same module ends up registered under ``conv`` either way.
        self.conv = nn.utils.spectral_norm(conv_layer) if sn else conv_layer
        self.norm = nn.InstanceNorm2d(out_features, affine=True) if norm else None
        self.pool = pool

    def forward(self, x):
        """Run the block on ``x`` and return the resulting feature map."""
        features = self.conv(x)
        if self.norm is not None:
            features = self.norm(features)
        activated = F.leaky_relu(features, 0.2)
        if not self.pool:
            return activated
        return F.avg_pool2d(activated, (2, 2))
|
|
|
|
|
|
class Discriminator(nn.Module):
    """
    Pix2Pix-like discriminator: a stack of ``DownBlock2d`` layers followed by
    a 1x1 convolution that produces a single-channel prediction map.

    Args:
        num_channels: channel count fed to the first block.
        block_expansion: base width; block ``i`` outputs
            ``min(max_features, block_expansion * 2 ** (i + 1))`` channels.
        num_blocks: number of down blocks in the stack.
        max_features: upper bound on the channel width of any block.
        sn: when true, spectral normalisation is applied to every convolution.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, **kwargs):
        super().__init__()

        def width(level):
            # Channel width at a given depth, capped at max_features.
            return min(max_features, block_expansion * (2 ** level))

        final_index = num_blocks - 1
        # The first block is not normalised; the last block is not pooled.
        self.down_blocks = nn.ModuleList([
            DownBlock2d(
                num_channels if idx == 0 else width(idx),
                width(idx + 1),
                norm=(idx != 0),
                kernel_size=4,
                pool=(idx != final_index),
                sn=sn,
            )
            for idx in range(num_blocks)
        ])

        head = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        self.conv = nn.utils.spectral_norm(head) if sn else head

    def forward(self, x):
        """Return ``(feature_maps, prediction_map)``; ``feature_maps`` lists every block's output in order."""
        feature_maps = []
        current = x
        for block in self.down_blocks:
            current = block(current)
            feature_maps.append(current)
        prediction_map = self.conv(current)
        return feature_maps, prediction_map
|
|
|
|
|
|
class MultiScaleDiscriminator(nn.Module):
    """
    Holds one ``Discriminator`` per requested scale.

    Args:
        scales: iterable of scales; one discriminator is built for each.
        **kwargs: forwarded unchanged to every ``Discriminator``.
    """

    def __init__(self, scales=(), **kwargs):
        super().__init__()
        self.scales = scales
        # Dots in the scale text are replaced by dashes for the ModuleDict
        # keys; forward() reverses the substitution.
        self.discs = nn.ModuleDict({
            str(scale).replace('.', '-'): Discriminator(**kwargs)
            for scale in scales
        })

    def forward(self, x):
        """
        Apply each discriminator to ``x['prediction_<scale>']``.

        Returns a dict with ``feature_maps_<scale>`` and
        ``prediction_map_<scale>`` entries for every scale.
        """
        out_dict = {}
        for name, disc in self.discs.items():
            suffix = str(name).replace('-', '.')
            feature_maps, prediction_map = disc(x['prediction_' + suffix])
            out_dict['feature_maps_' + suffix] = feature_maps
            out_dict['prediction_map_' + suffix] = prediction_map
        return out_dict
|
|
|
|
|
|
|
|
class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator related updates into single model for better multi-gpu usage

    Wraps a multi-scale discriminator together with an image pyramid and
    computes the discriminator loss for a real/generated pair.

    Args:
        discriminator: module exposing a ``scales`` attribute; it is called
            with the pyramid output and must return a mapping that contains a
            ``prediction_map_<scale>`` entry for every scale.
    """

    def __init__(self, discriminator):
        super(DiscriminatorFullModel, self).__init__()
        self.discriminator = discriminator
        self.scales = self.discriminator.scales
        print("scales",self.scales)
        # NOTE(review): the second argument is presumably the image channel
        # count expected by ImagePyramide -- confirm against its definition.
        self.pyramid = ImagePyramide(self.scales, 3)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        # Lazily created by get_zero_tensor().
        self.zero_tensor = None

    def get_zero_tensor(self, input):
        """
        Return a zero tensor broadcast to the shape of ``input``.

        The one-element zero is cached, and rebuilt whenever the requested
        device or dtype differs from the cached one, so the result always
        lives next to ``input`` (CPU-only hosts included).
        """
        zero = self.zero_tensor
        if zero is None or zero.device != input.device or zero.dtype != input.dtype:
            # torch.zeros yields a tensor that does not require grad.
            zero = torch.zeros(1, device=input.device, dtype=input.dtype)
            self.zero_tensor = zero
        return zero.expand_as(input)

    def forward(self, x, generated, gan_mode='ls'):
        """
        Compute the discriminator loss summed over all scales.

        Args:
            x: real images, passed through the pyramid.
            generated: generated images; detached before use so that no
                gradient flows back through them.
            gan_mode: ``'ls'`` for the least-squares loss or ``'hinge'`` for
                the hinge loss.

        Returns:
            Sum of the per-scale loss values (the integer ``0`` when
            ``self.scales`` is empty).

        Raises:
            ValueError: if ``gan_mode`` is neither ``'ls'`` nor ``'hinge'``.
        """
        pyramide_real = self.pyramid(x)
        pyramide_generated = self.pyramid(generated.detach())

        # Generated maps are evaluated before the real ones, as in the
        # original implementation.
        discriminator_maps_generated = self.discriminator(pyramide_generated)
        discriminator_maps_real = self.discriminator(pyramide_real)

        value_total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            real_map = discriminator_maps_real[key]
            fake_map = discriminator_maps_generated[key]
            if gan_mode == 'hinge':
                real_term = torch.mean(torch.min(real_map - 1, self.get_zero_tensor(real_map)))
                fake_term = torch.mean(torch.min(-fake_map - 1, self.get_zero_tensor(fake_map)))
                value = -real_term - fake_term
            elif gan_mode == 'ls':
                value = ((1 - real_map) ** 2 + fake_map ** 2).mean()
            else:
                # Report the offending argument itself; the model has no
                # ``train_params`` attribute to read it from.
                raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

            value_total += value

        return value_total
|
|
|
|
def main():
|
|
discriminator = MultiScaleDiscriminator(scales=[1],
|
|
block_expansion=32,
|
|
max_features=512,
|
|
num_blocks=4,
|
|
sn=True,
|
|
image_channel=3,
|
|
estimate_jacobian=False) |