Mirror of https://github.com/TMElyralab/MuseTalk.git (synced 2026-02-04 17:39:20 +08:00)
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
82 lines
3.3 KiB
Python
Executable File
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
import musetalk.loss.vgg_face as vgg_face

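# Thin nn.Module wrapper around F.interpolate so a fixed resize can be composed with other modules.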
class Interpolate(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolate, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, input):
        return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)

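# Enable or disable gradient tracking for every parameter of a network
# (used below to freeze the discriminator while computing the generator loss).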
def set_requires_grad(net, requires_grad=False):
    if net is not None:
        for param in net.parameters():
            param.requires_grad = requires_grad

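# Smoke test: load the experiment config, then compute the multi-scale VGG-face perceptual loss and the
# least-squares adversarial generator loss on random images. The config is expected to provide
# loss_params.pyramid_scale, loss_params.vgg_layer_weight, model_params.discriminator_params and
# discriminator_train_params.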
if __name__ == "__main__":
    cfg = OmegaConf.load("config/audio_adapter/E7.yaml")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pyramid_scale = [1, 0.5, 0.25, 0.125]
    vgg_IN = vgg_face.Vgg19().to(device)
    pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
    vgg_IN.eval()
    downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)

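    # Random 256x256 RGB batches stand in for ground-truth and generated frames; both are resized to
    # 224x224 for VGG and expanded into an image pyramid at the configured scales.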
    image = torch.rand(8, 3, 256, 256).to(device)
    image_pred = torch.rand(8, 3, 256, 256).to(device)
    pyramide_real = pyramid(downsampler(image))
    pyramide_generated = pyramid(downsampler(image_pred))

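    # Perceptual loss: L1 distance between VGG-face feature maps of generated and real images,
    # weighted per VGG layer and accumulated over all pyramid scales.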
    loss_IN = 0
    for scale in cfg.loss_params.pyramid_scale:
        x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
        y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
        for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
            loss_IN += weight * value
    loss_IN /= sum(cfg.loss_params.vgg_layer_weight)  # normalize by the VGG layer weights; the pyramid loss is summed across scales
    print(loss_IN)

    # print(cfg.model_params.discriminator_params)

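    # Multi-scale discriminator plus its AdamW optimizer and cosine learning-rate schedule,
    # all configured from the experiment YAML.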
    discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
    discriminator_full = DiscriminatorFullModel(discriminator)
    disc_scales = cfg.model_params.discriminator_params.scales
    # Prepare optimizer and loss function
    optimizer_D = optim.AdamW(discriminator.parameters(),
                              lr=cfg.discriminator_train_params.lr,
                              weight_decay=cfg.discriminator_train_params.weight_decay,
                              betas=cfg.discriminator_train_params.betas,
                              eps=cfg.discriminator_train_params.eps)
    scheduler_D = CosineAnnealingLR(optimizer_D,
                                    T_max=cfg.discriminator_train_params.epochs,
                                    eta_min=1e-6)

    discriminator.train()

    set_requires_grad(discriminator, False)

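    # Generator-side adversarial term (least-squares GAN): the discriminator is frozen and its
    # prediction maps on generated images are pushed toward 1, so in a real training step gradients
    # would flow only into the generator.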
    loss_G = 0.
    discriminator_maps_generated = discriminator(pyramide_generated)
    discriminator_maps_real = discriminator(pyramide_real)

    for scale in disc_scales:
        key = 'prediction_map_%s' % scale
        value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
        loss_G += value

    print(loss_G)
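
    # --- Hedged sketch (not part of the original script): a possible discriminator-side LSGAN
    # update using the optimizer and scheduler created above. It only assumes what the generator
    # loss already relies on, namely that the discriminator returns a 'prediction_map_<scale>'
    # entry per scale; treat it as an illustration, not the project's actual training step.
    set_requires_grad(discriminator, True)               # unfreeze D for its own update
    loss_D = 0.
    maps_real = discriminator(pyramide_real)
    maps_generated = discriminator(pyramide_generated)   # in real training, detach the generator output first
    for scale in disc_scales:
        key = 'prediction_map_%s' % scale
        # real maps pushed toward 1, generated maps pushed toward 0
        loss_D += ((1 - maps_real[key]) ** 2).mean() + (maps_generated[key] ** 2).mean()

    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()
    scheduler_D.step()  # with T_max set to the epoch count, this would normally be stepped once per epoch
    print(loss_D)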