mirror of
https://github.com/TMElyralab/MuseTalk.git
95 lines · 4.1 KiB · Python · Executable File
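# Sync-loss utilities used by MuseTalk training: a BCE-on-cosine-similarity
# loss (cosine_loss), a helper that splices predicted frames into the
# ground-truth frame stack before scoring it with a pretrained SyncNet
# (get_sync_loss), and the Wav2Lip-style SyncNet_color audio/visual encoder.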
import torch
from torch import nn
from torch.nn import functional as F

from .conv import Conv2d

logloss = nn.BCELoss(reduction="none")


def cosine_loss(a, v, y):
    d = nn.functional.cosine_similarity(a, v)
    # cosine_similarity takes values in [-1, 1], and BCELoss raises
    # "RuntimeError: CUDA error: device-side assert triggered" if it is fed a
    # negative input, so clamp the similarity to [0, 1] first.
    d = d.clamp(0, 1)
    loss = logloss(d.unsqueeze(1), y).squeeze()
    loss = loss.mean()
    return loss, d
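
# A minimal sketch (not part of the original file) showing how cosine_loss
# behaves: with target y = 1 the loss is -log(d), so identical embeddings
# (d = 1) give zero loss while dissimilar ones are heavily penalized.
def _demo_cosine_loss():
    a = F.normalize(torch.randn(4, 512), dim=1)   # e.g. audio embeddings
    y = torch.ones(4, 1)                          # "in sync" labels
    loss_pos, d = cosine_loss(a, a, y)            # identical embeddings: d = 1
    loss_rand, _ = cosine_loss(a, F.normalize(torch.randn(4, 512), dim=1), y)
    print(loss_pos.item(), loss_rand.item())      # ~0.0 vs. a clearly larger value
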
def get_sync_loss(
    audio_embed,
    gt_frames,
    pred_frames,
    syncnet,
    adapted_weight,  # accepted by the caller's interface but unused here
    frames_left_index=0,
    frames_right_index=16,
):
    # Randomly splice the predicted frames into gt_frames (swapping one
    # sub-range) to save GPU memory.
    assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
    # 3-channel images, stacked along the channel axis
    frames_sync_loss = torch.cat(
        [gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
        axis=1
    )
    vision_embed = syncnet.get_image_embed(frames_sync_loss)
    y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
    loss, score = cosine_loss(audio_embed, vision_embed, y)
    return loss, score
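
# Hypothetical usage sketch for get_sync_loss (not from the repo). The real
# `syncnet` comes from MuseTalk's training code and exposes get_image_embed();
# here it is stubbed with random unit embeddings purely to show the expected
# tensor shapes and call pattern. All shapes below are illustrative assumptions.
class _StubSyncNet:
    def get_image_embed(self, frames):
        # frames: (B, 48, H, W), i.e. 16 channel-stacked RGB frames
        return F.normalize(torch.randn(frames.size(0), 512), dim=1)


def _demo_get_sync_loss():
    B, H, W = 2, 256, 256
    gt_frames = torch.rand(B, 48, H, W)    # 16 ground-truth frames x 3 channels
    pred_frames = torch.rand(B, 12, H, W)  # 4 predicted frames for indices [2, 6)
    audio_embed = F.normalize(torch.randn(B, 512), dim=1)
    loss, score = get_sync_loss(
        audio_embed, gt_frames, pred_frames, _StubSyncNet(),
        adapted_weight=1.0, frames_left_index=2, frames_right_index=6,
    )
    print(loss.item(), score.shape)  # scalar loss, per-sample similarity (B,)
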
class SyncNet_color(nn.Module):
    def __init__(self):
        super(SyncNet_color, self).__init__()

        # Visual encoder: input is 5 RGB frames stacked channel-wise (15 channels).
        self.face_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),

            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        # Audio encoder: input is a single-channel mel-spectrogram window.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, audio_sequences, face_sequences):  # audio_sequences := (B, dim, T)
        face_embedding = self.face_encoder(face_sequences)
        audio_embedding = self.audio_encoder(audio_sequences)

        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
        face_embedding = face_embedding.view(face_embedding.size(0), -1)

        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
        face_embedding = F.normalize(face_embedding, p=2, dim=1)

        return audio_embedding, face_embedding
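
# Quick smoke test (an assumption, not from the repo: Wav2Lip-style inputs,
# i.e. a (B, 1, 80, 16) mel window and a (B, 15, 48, 96) stack of 5 RGB
# half-face crops; both encoders reduce these to 512-d L2-normalized vectors).
if __name__ == "__main__":
    model = SyncNet_color()
    mel = torch.randn(2, 1, 80, 16)
    faces = torch.randn(2, 15, 48, 96)
    audio_emb, face_emb = model(mel, faces)
    print(audio_emb.shape, face_emb.shape)  # torch.Size([2, 512]) twice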