import math
import os.path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta

from models.encoder.wav2vec import Wav2Vec2Model
from models.encoder.wavlm import WavLMModel
from models.builder import MODELS
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config


@MODELS.register_module("Audio2Expression")
class Audio2Expression(nn.Module):

    def __init__(self,
                 device: torch.device = None,
                 pretrained_encoder_type: str = 'wav2vec',
                 pretrained_encoder_path: str = '',
                 wav2vec2_config_path: str = '',
                 num_identity_classes: int = 0,
                 identity_feat_dim: int = 64,
                 hidden_dim: int = 512,
                 expression_dim: int = 52,
                 norm_type: str = 'ln',
                 decoder_depth: int = 3,
                 use_transformer: bool = False,
                 num_attention_heads: int = 8,
                 num_transformer_layers: int = 6,
                 ):
        super().__init__()
        self.device = device

        # Initialize audio feature encoder
        if pretrained_encoder_type == 'wav2vec':
            if os.path.exists(pretrained_encoder_path):
                self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path)
            else:
                config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path)
                self.audio_encoder = Wav2Vec2Model(config)
            encoder_output_dim = 768
        elif pretrained_encoder_type == 'wavlm':
            self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path)
            encoder_output_dim = 768
        else:
            raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported")
        self.audio_encoder.feature_extractor._freeze_parameters()

        self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim)
        self.identity_encoder = AudioIdentityEncoder(hidden_dim,
                                                     num_identity_classes,
                                                     identity_feat_dim,
                                                     use_transformer,
                                                     num_attention_heads,
                                                     num_transformer_layers)
        self.decoder = nn.ModuleList([
            nn.Sequential(*[
                ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type)
                for _ in range(decoder_depth)
            ])
        ])
        self.output_proj = nn.Linear(hidden_dim, expression_dim)

    def freeze_encoder_parameters(self, do_freeze=False):
        for name, param in self.audio_encoder.named_parameters():
            if 'feature_extractor' in name:
                param.requires_grad = False
            else:
                param.requires_grad = (not do_freeze)

    def forward(self, input_dict):
        if 'time_steps' not in input_dict:
            audio_length = input_dict['input_audio_array'].shape[1]
            time_steps = math.ceil(audio_length / 16000 * 30)
        else:
            time_steps = input_dict['time_steps']

        # Process audio through encoder
        audio_input = input_dict['input_audio_array'].flatten(start_dim=1)
        hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state

        # Project features to hidden dimension
        audio_features = self.feature_projection(hidden_states).transpose(1, 2)

        # Process identity-conditioned features
        audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx'])

        # Refine features through decoder
        audio_features = self.decoder[0](audio_features)

        # Generate output parameters
        audio_features = audio_features.permute(0, 2, 1)
        expression_params = self.output_proj(audio_features)
        return torch.sigmoid(expression_params)
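# Usage sketch (illustrative only, not executed on import): expected input/output
# shapes for Audio2Expression. The config path and identity count below are
# hypothetical placeholders; substitute values that exist in your setup.
def _example_audio2expression_forward():
    model = Audio2Expression(
        pretrained_encoder_type='wav2vec',
        wav2vec2_config_path='path/to/wav2vec2_config',  # placeholder path
        num_identity_classes=8,                          # hypothetical speaker count
        hidden_dim=512,
        expression_dim=52,
    )
    model.eval()
    batch = {
        'input_audio_array': torch.randn(1, 32000),               # 2 s of 16 kHz audio
        'id_idx': F.one_hot(torch.tensor([0]), num_classes=8),    # one-hot identity
    }
    with torch.no_grad():
        expr = model(batch)
    # time_steps = ceil(32000 / 16000 * 30) = 60 frames at 30 fps, so
    # expr has shape (1, 60, 52) with values in [0, 1] after the sigmoid.
    return expr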
class AudioIdentityEncoder(nn.Module):

    def __init__(self,
                 hidden_dim,
                 num_identity_classes=0,
                 identity_feat_dim=64,
                 use_transformer=False,
                 num_attention_heads=8,
                 num_transformer_layers=6,
                 dropout_ratio=0.1,
                 ):
        super().__init__()
        in_dim = hidden_dim + identity_feat_dim
        self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1)
        self.first_net = SeqTranslator1D(in_dim,
                                         hidden_dim,
                                         min_layers_num=3,
                                         residual=True,
                                         norm='ln')
        self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
        self.dropout = nn.Dropout(dropout_ratio)
        self.use_transformer = use_transformer
        if self.use_transformer:
            encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim,
                                                       nhead=num_attention_heads,
                                                       dim_feedforward=2 * hidden_dim,
                                                       batch_first=True)
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer,
                                                             num_layers=num_transformer_layers)

    def forward(self,
                audio_features: torch.Tensor,
                identity: torch.Tensor = None,
                time_steps: int = None) -> torch.Tensor:
        audio_features = self.dropout(audio_features)
        identity = identity.reshape(identity.shape[0], -1, 1).repeat(
            1, 1, audio_features.shape[2]).to(torch.float32)
        identity = self.id_mlp(identity)
        audio_features = torch.cat([audio_features, identity], dim=1)
        x = self.first_net(audio_features)
        if time_steps is not None:
            x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear')
        if self.use_transformer:
            x = x.permute(0, 2, 1)
            x = self.transformer_encoder(x)
            x = x.permute(0, 2, 1)
        return x


class ConvNormRelu(nn.Module):
    '''
    (B, C_in, H, W) -> (B, C_out, H, W)
    Note: some kernel size / stride combinations do not yield an output length of exactly H / s.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 type='1d',
                 leaky=False,
                 downsample=False,
                 kernel_size=None,
                 stride=None,
                 padding=None,
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        '''
        conv-norm-relu
        '''
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm

        if kernel_size is None and stride is None:
            if not downsample:
                kernel_size = 3
                stride = 1
            else:
                kernel_size = 4
                stride = 2

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(stride, tuple):
                padding = tuple(int((kernel_size - st) / 2) for st in stride)
            elif isinstance(kernel_size, tuple) and isinstance(stride, int):
                padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
            elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
                padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
            else:
                padding = int((kernel_size - stride) / 2)

        if self.residual:
            if downsample:
                if type == '1d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv1d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding))
                elif type == '2d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding))
            else:
                if in_channels == out_channels:
                    self.residual_layer = nn.Identity()
                else:
                    if type == '1d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv1d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=padding))
                    elif type == '2d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=padding))

        in_channels = in_channels * groups
        out_channels = out_channels * groups
        if type == '1d':
            self.conv = nn.Conv1d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm1d(out_channels)
            self.dropout = nn.Dropout(p=p)
        elif type == '2d':
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm2d(out_channels)
            self.dropout = nn.Dropout2d(p=p)
        if norm == 'gn':
            self.norm = nn.GroupNorm(2, out_channels)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        if self.norm_type == 'ln':
            # LayerNorm expects the channel dimension last, so normalize on (B, T, C)
            out = self.dropout(self.conv(x))
            out = self.norm(out.transpose(1, 2)).transpose(1, 2)
        else:
            out = self.norm(self.dropout(self.conv(x)))
        if self.residual:
            residual = self.residual_layer(x)
            out += residual
        return self.relu(out)
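# Shape sketch (illustrative only, not executed on import): with the default
# kernel_size=3, stride=1 the padding works out to (3 - 1) // 2 = 1, so the temporal
# length is preserved; downsample=True (kernel_size=4, stride=2) halves it.
def _example_conv_norm_relu_shapes():
    x = torch.randn(2, 64, 100)                         # (B, C, T)
    keep = ConvNormRelu(64, 64, norm='ln', residual=True)
    half = ConvNormRelu(64, 128, downsample=True)       # BatchNorm1d branch
    return keep(x).shape, half(x).shape                 # (2, 64, 100), (2, 128, 50)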
"""
from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git
"""


class SeqTranslator1D(nn.Module):
    '''
    (B, C, T) -> (B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layers_num=None,
                 residual=True,
                 norm='bn'
                 ):
        super(SeqTranslator1D, self).__init__()
        conv_layers = nn.ModuleList([])
        conv_layers.append(ConvNormRelu(in_channels=C_in,
                                        out_channels=C_out,
                                        type='1d',
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        residual=residual,
                                        norm=norm))
        self.num_layers = 1
        if min_layers_num is not None and self.num_layers < min_layers_num:
            while self.num_layers < min_layers_num:
                conv_layers.append(ConvNormRelu(in_channels=C_out,
                                                out_channels=C_out,
                                                type='1d',
                                                kernel_size=kernel_size,
                                                stride=stride,
                                                residual=residual,
                                                norm=norm))
                self.num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        return self.conv_layers(x)


def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
    """
    :param audio: 1 x T tensor containing a 16kHz audio signal
    :param frame_rate: frame rate for video (we need one audio chunk per video frame)
    :param chunk_size: number of audio samples per chunk
    :return: num_chunks x chunk_size tensor containing sliced audio
    """
    samples_per_frame = 16000 // frame_rate
    padding = (chunk_size - samples_per_frame) // 2
    audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
    anchor_points = list(range(chunk_size // 2, audio.shape[-1] - chunk_size // 2, samples_per_frame))
    audio = torch.cat([audio[:, i - chunk_size // 2:i + chunk_size // 2] for i in anchor_points], dim=0)
    return audio


"""
https://github.com/facebookresearch/meshtalk
"""


class MeshtalkEncoder(nn.Module):

    def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
        """
        :param latent_dim: size of the latent audio embedding
        :param model_name: name of the model, used to load and save the model
        """
        super().__init__()

        self.melspec = ta.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
        )

        conv_len = 5
        self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
        self.weights_init(self.convert_dimensions)
        self.receptive_field = conv_len

        convs = []
        for i in range(6):
            dilation = 2 * (i % 3 + 1)
            self.receptive_field += (conv_len - 1) * dilation
            convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
            self.weights_init(convs[-1])
        self.convs = torch.nn.ModuleList(convs)
        self.code = torch.nn.Linear(128, latent_dim)

        self.apply(lambda x: self.weights_init(x))

    def weights_init(self, m):
        if isinstance(m, torch.nn.Conv1d):
            torch.nn.init.xavier_uniform_(m.weight)
            try:
                torch.nn.init.constant_(m.bias, .01)
            except:
                pass

    def forward(self, audio: torch.Tensor):
        """
        :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
        :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
        """
        B, T = audio.shape[0], audio.shape[1]
        x = self.melspec(audio).squeeze(1)
        x = torch.log(x.clamp(min=1e-10, max=None))
        if T == 1:
            x = x.unsqueeze(1)

        # Convert to the right dimensionality
        x = x.view(-1, x.shape[2], x.shape[3])
        x = F.leaky_relu(self.convert_dimensions(x), .2)

        # Process stacks
        for conv in self.convs:
            x_ = F.leaky_relu(conv(x), .2)
            if self.training:
                x_ = F.dropout(x_, .2)
            l = (x.shape[2] - x_.shape[2]) // 2
            x = (x[:, :, l:-l] + x_) / 2

        x = torch.mean(x, dim=-1)
        x = x.view(B, T, x.shape[-1])
        x = self.code(x)

        return {"code": x}
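# Worked example (illustrative only, not executed on import): audio_chunking slides a
# 1-second window (chunk_size=16000) over the padded signal with a hop of
# samples_per_frame = 16000 // 30 = 533 samples, i.e. roughly one chunk per 30 fps frame.
def _example_audio_chunking():
    audio = torch.randn(1, 32000)                       # 2 s of 16 kHz audio
    chunks = audio_chunking(audio, frame_rate=30, chunk_size=16000)
    return chunks.shape                                 # torch.Size([60, 16000])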
class PeriodicPositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)  # (1, repeat_num * period, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class GeneratorTransformer(nn.Module):

    def __init__(self,
                 n_poses,
                 each_dim: list,
                 dim_list: list,
                 training=True,
                 device=None,
                 identity=False,
                 num_classes=0,
                 ):
        super().__init__()
        self.training = training
        self.device = device
        self.gen_length = n_poses

        norm = 'ln'
        in_dim = 256
        out_dim = 256

        self.encoder_choice = 'faceformer'

        # Alternative checkpoints: "vitouphy/wav2vec2-xls-r-300m-phoneme", "facebook/wav2vec2-base-960h"
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.audio_encoder.feature_extractor._freeze_parameters()
        self.audio_feature_map = nn.Linear(768, in_dim)
        # NOTE: AudioEncoder is not defined or imported in this file; it is assumed to be
        # provided elsewhere in the codebase.
        self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes)

        self.dim_list = dim_list

        self.decoder = nn.ModuleList()
        self.final_out = nn.ModuleList()

        self.hidden_size = 768
        self.transformer_de_layer = nn.TransformerDecoderLayer(
            d_model=self.hidden_size,
            nhead=4,
            dim_feedforward=self.hidden_size * 2,
            batch_first=True
        )
        self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)
        self.feature2face = nn.Linear(256, self.hidden_size)
        self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64)
        self.id_maping = nn.Linear(12, self.hidden_size)

        self.decoder.append(self.face_decoder)
        self.final_out.append(nn.Linear(self.hidden_size, 32))

    def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
        if gt_poses is None:
            time_steps = 64
        else:
            time_steps = gt_poses.shape[1]

        # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
        if self.encoder_choice == 'meshtalk':
            in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
            feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
        elif self.encoder_choice == 'faceformer':
            hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1),
                                               frame_num=time_steps).last_hidden_state
            feature = self.audio_feature_map(hidden_states).transpose(1, 2)
        else:
            feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)

        feature, _ = self.audio_middle(feature, id=None)

        feature = self.feature2face(feature.permute(0, 2, 1))
        id = id.unsqueeze(1).repeat(1, 64, 1).to(torch.float32)
        id_feature = self.id_maping(id)
        id_feature = self.position_embeddings(id_feature)

        for i in range(len(self.decoder)):
            mid = self.decoder[i](tgt=id_feature, memory=feature)
            out = self.final_out[i](mid)

        return out, None


def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
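# Usage sketch (illustrative only, not executed on import): linear_interpolation
# resamples a (B, T, C) feature sequence along the time axis, e.g. from 50 Hz
# wav2vec frames down to a 30 fps target length.
def _example_linear_interpolation():
    feats = torch.randn(2, 100, 256)                    # (B, T=100, C)
    resampled = linear_interpolation(feats, output_len=60)
    return resampled.shape                              # torch.Size([2, 60, 256])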
def init_biased_mask(n_head, max_seq_len, period):

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2]

    slopes = torch.Tensor(get_slopes(n_head))
    bias = torch.div(
        torch.arange(start=0, end=max_seq_len,
                     step=period).unsqueeze(1).repeat(1, period).view(-1),
        period,
        rounding_mode='floor')
    bias = -torch.flip(bias, dims=[0])
    alibi = torch.zeros(max_seq_len, max_seq_len)
    for i in range(max_seq_len):
        alibi[i, :i + 1] = bias[-(i + 1):]
    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
    mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0))
    mask = mask.unsqueeze(0) + alibi
    return mask


# Alignment Bias
def enc_dec_mask(device, T, S):
    mask = torch.ones(T, S)
    for i in range(T):
        mask[i, i] = 0
    return (mask == 1).to(device=device)
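# Shape sketch (illustrative only, not executed on import): init_biased_mask builds an
# ALiBi-style causal bias of shape (n_head, max_seq_len, max_seq_len), with -inf above
# the diagonal and per-head distance penalties below it; enc_dec_mask lets each target
# step attend only to the memory frame with the same index.
def _example_attention_masks():
    tgt_mask = init_biased_mask(n_head=4, max_seq_len=8, period=2)
    mem_mask = enc_dec_mask(torch.device('cpu'), T=8, S=8)
    return tgt_mask.shape, mem_mask.shape               # (4, 8, 8); (8, 8) bool, False = allowed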
# Periodic Positional Encoding
# NOTE: this redefines (and therefore shadows) the PeriodicPositionalEncoding class
# above; the two differ only in their default period / max_seq_len values.
class PeriodicPositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
        super(PeriodicPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, period, d_model)
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)  # (1, repeat_num * period, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class BaseModel(nn.Module):
    """Base class for all models."""

    def __init__(self):
        super(BaseModel, self).__init__()
        # self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *x):
        """Forward pass logic.

        :return: Model output
        """
        raise NotImplementedError

    def freeze_model(self, do_freeze: bool = True):
        for param in self.parameters():
            param.requires_grad = (not do_freeze)

    def summary(self, logger, writer=None):
        """Model summary."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters]) / 1e6  # unit: millions of parameters
        logger.info('===>Trainable parameters: %.3f M' % params)
        if writer is not None:
            writer.add_text('Model Summary', 'Trainable parameters: %.3f M' % params)


"""https://github.com/X-niper/UniTalker"""


class UniTalkerDecoderTransformer(BaseModel):

    def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None:
        super().__init__()
        self.learnable_style_emb = nn.Embedding(identity_num, out_dim)
        self.PPE = PeriodicPositionalEncoding(out_dim, period=period, max_seq_len=3000)
        self.biased_mask = init_biased_mask(n_head=4, max_seq_len=3000, period=period)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=out_dim,
            nhead=4,
            dim_feedforward=2 * out_dim,
            batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
        self.interpolate_pos = interpolate_pos

    def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor, frame_num: int):
        style_idx = torch.argmax(style_idx, dim=1)
        obj_embedding = self.learnable_style_emb(style_idx)
        obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
        style_input = self.PPE(obj_embedding)
        tgt_mask = self.biased_mask.repeat(
            style_idx.shape[0], 1, 1)[:, :style_input.shape[1], :style_input.shape[1]].clone().detach().to(
                device=style_input.device)
        memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1], frame_num)
        feat_out = self.transformer_decoder(
            style_input,
            hidden_states,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask)
        if self.interpolate_pos == 2:
            feat_out = linear_interpolation(feat_out, output_len=frame_num)
        return feat_out
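# Usage sketch (illustrative only, not executed on import): decoding 32 frames of
# style-conditioned features from pre-computed audio hidden states. All dimensions
# below are hypothetical examples.
def _example_unitalker_decoder():
    decoder = UniTalkerDecoderTransformer(out_dim=256, identity_num=10, period=30)
    hidden_states = torch.randn(2, 32, 256)             # (B, frame_num, out_dim)
    style_idx = F.one_hot(torch.tensor([0, 3]), num_classes=10).float()
    feat_out = decoder(hidden_states, style_idx, frame_num=32)
    return feat_out.shape                               # torch.Size([2, 32, 256])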