mirror of
https://github.com/TMElyralab/MuseTalk.git
synced 2026-02-05 18:09:19 +08:00
feat: data preprocessing and training (#294)
* docs: update readme * docs: update readme * feat: training codes * feat: data preprocess * docs: release training
This commit is contained in:
44
musetalk/loss/conv.py
Executable file
44
musetalk/loss/conv.py
Executable file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
if self.residual:
|
||||
out += x
|
||||
return self.act(out)
|
||||
|
||||
class nonorm_Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
)
|
||||
self.act = nn.LeakyReLU(0.01, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
|
||||
class Conv2dTranspose(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
Reference in New Issue
Block a user