mirror of
https://github.com/aigc3d/LAM_Audio2Expression.git
synced 2026-02-04 17:39:24 +08:00
feat: Initial commit
This commit is contained in:
7
models/__init__.py
Normal file
7
models/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .builder import build_model
|
||||
|
||||
from .default import DefaultEstimator
|
||||
|
||||
# Backbones
|
||||
from .network import Audio2Expression
|
||||
|
||||
13
models/builder.py
Normal file
13
models/builder.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Modified by https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
from utils.registry import Registry
|
||||
|
||||
MODELS = Registry("models")
|
||||
MODULES = Registry("modules")
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
"""Build models."""
|
||||
return MODELS.build(cfg)
|
||||
25
models/default.py
Normal file
25
models/default.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from models.losses import build_criteria
|
||||
from .builder import MODELS, build_model
|
||||
|
||||
@MODELS.register_module()
|
||||
class DefaultEstimator(nn.Module):
|
||||
def __init__(self, backbone=None, criteria=None):
|
||||
super().__init__()
|
||||
self.backbone = build_model(backbone)
|
||||
self.criteria = build_criteria(criteria)
|
||||
|
||||
def forward(self, input_dict):
|
||||
pred_exp = self.backbone(input_dict)
|
||||
# train
|
||||
if self.training:
|
||||
loss = self.criteria(pred_exp, input_dict["gt_exp"])
|
||||
return dict(loss=loss)
|
||||
# eval
|
||||
elif "gt_exp" in input_dict.keys():
|
||||
loss = self.criteria(pred_exp, input_dict["gt_exp"])
|
||||
return dict(loss=loss, pred_exp=pred_exp)
|
||||
# infer
|
||||
else:
|
||||
return dict(pred_exp=pred_exp)
|
||||
248
models/encoder/wav2vec.py
Normal file
248
models/encoder/wav2vec.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from dataclasses import dataclass
|
||||
from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.file_utils import ModelOutput
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||
_HIDDEN_STATES_START_POSITION = 2
|
||||
|
||||
|
||||
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
|
||||
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
min_masks: int = 0,
|
||||
) -> np.ndarray:
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
all_num_mask = int(
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
mask_idcs = []
|
||||
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
return mask
|
||||
|
||||
|
||||
# linear interpolation layer
|
||||
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
||||
features = features.transpose(1, 2)
|
||||
seq_len = features.shape[2] / float(input_fps)
|
||||
if output_len is None:
|
||||
output_len = int(seq_len * output_fps)
|
||||
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
|
||||
return output_features.transpose(1, 2)
|
||||
|
||||
|
||||
class Wav2Vec2Model(Wav2Vec2Model):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.lm_head = nn.Linear(1024, 32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
frame_num=None
|
||||
):
|
||||
self.config.output_attentions = True
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.feature_extractor(input_values)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
|
||||
|
||||
if attention_mask is not None:
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
attention_mask = torch.zeros(
|
||||
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
attention_mask[
|
||||
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
|
||||
] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states)[0]
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = encoder_outputs[0]
|
||||
if not return_dict:
|
||||
return (hidden_states,) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechClassifierOutput(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class Wav2Vec2ClassificationHead(nn.Module):
|
||||
"""Head for wav2vec classification task."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = features
|
||||
x = self.dropout(x)
|
||||
x = self.dense(x)
|
||||
x = torch.tanh(x)
|
||||
x = self.dropout(x)
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.pooling_mode = config.pooling_mode
|
||||
self.config = config
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.classifier = Wav2Vec2ClassificationHead(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||
|
||||
def merged_strategy(
|
||||
self,
|
||||
hidden_states,
|
||||
mode="mean"
|
||||
):
|
||||
if mode == "mean":
|
||||
outputs = torch.mean(hidden_states, dim=1)
|
||||
elif mode == "sum":
|
||||
outputs = torch.sum(hidden_states, dim=1)
|
||||
elif mode == "max":
|
||||
outputs = torch.max(hidden_states, dim=1)[0]
|
||||
else:
|
||||
raise Exception(
|
||||
"The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
|
||||
|
||||
return outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
frame_num=None,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
outputs = self.wav2vec2(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
hidden_states1 = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
|
||||
hidden_states = self.merged_strategy(hidden_states1, mode=self.pooling_mode)
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SpeechClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=hidden_states1,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
87
models/encoder/wavlm.py
Normal file
87
models/encoder/wavlm.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import WavLMModel
|
||||
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
|
||||
from typing import Optional, Tuple, Union
|
||||
import torch.nn.functional as F
|
||||
|
||||
def linear_interpolation(features, output_len: int):
|
||||
features = features.transpose(1, 2)
|
||||
output_features = F.interpolate(
|
||||
features, size=output_len, align_corners=True, mode='linear')
|
||||
return output_features.transpose(1, 2)
|
||||
|
||||
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
|
||||
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
|
||||
|
||||
|
||||
class WavLMModel(WavLMModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = (not do_freeze)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_values: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
mask_time_indices: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
frame_num=None,
|
||||
interpolate_pos: int = 0,
|
||||
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
extract_features = self.feature_extractor(input_values)
|
||||
extract_features = extract_features.transpose(1, 2)
|
||||
|
||||
if interpolate_pos == 0:
|
||||
extract_features = linear_interpolation(
|
||||
extract_features, output_len=frame_num)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute reduced attention_mask corresponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(
|
||||
extract_features.shape[1], attention_mask, add_adapter=False
|
||||
)
|
||||
|
||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||
hidden_states = self._mask_hidden_states(
|
||||
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if interpolate_pos == 1:
|
||||
hidden_states = linear_interpolation(
|
||||
hidden_states, output_len=frame_num)
|
||||
|
||||
if self.adapter is not None:
|
||||
hidden_states = self.adapter(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states, extract_features) + encoder_outputs[1:]
|
||||
|
||||
return Wav2Vec2BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
extract_features=extract_features,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
4
models/losses/__init__.py
Normal file
4
models/losses/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .builder import build_criteria
|
||||
|
||||
from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss
|
||||
from .lovasz import LovaszLoss
|
||||
28
models/losses/builder.py
Normal file
28
models/losses/builder.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
from utils.registry import Registry
|
||||
|
||||
LOSSES = Registry("losses")
|
||||
|
||||
|
||||
class Criteria(object):
|
||||
def __init__(self, cfg=None):
|
||||
self.cfg = cfg if cfg is not None else []
|
||||
self.criteria = []
|
||||
for loss_cfg in self.cfg:
|
||||
self.criteria.append(LOSSES.build(cfg=loss_cfg))
|
||||
|
||||
def __call__(self, pred, target):
|
||||
if len(self.criteria) == 0:
|
||||
# loss computation occur in model
|
||||
return pred
|
||||
loss = 0
|
||||
for c in self.criteria:
|
||||
loss += c(pred, target)
|
||||
return loss
|
||||
|
||||
|
||||
def build_criteria(cfg):
|
||||
return Criteria(cfg)
|
||||
253
models/losses/lovasz.py
Normal file
253
models/losses/lovasz.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from itertools import filterfalse
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from .builder import LOSSES
|
||||
|
||||
BINARY_MODE: str = "binary"
|
||||
MULTICLASS_MODE: str = "multiclass"
|
||||
MULTILABEL_MODE: str = "multilabel"
|
||||
|
||||
|
||||
def _lovasz_grad(gt_sorted):
|
||||
"""Compute gradient of the Lovasz extension w.r.t sorted errors
|
||||
See Alg. 1 in paper
|
||||
"""
|
||||
p = len(gt_sorted)
|
||||
gts = gt_sorted.sum()
|
||||
intersection = gts - gt_sorted.float().cumsum(0)
|
||||
union = gts + (1 - gt_sorted).float().cumsum(0)
|
||||
jaccard = 1.0 - intersection / union
|
||||
if p > 1: # cover 1-pixel case
|
||||
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
|
||||
return jaccard
|
||||
|
||||
|
||||
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
|
||||
"""
|
||||
Binary Lovasz hinge loss
|
||||
logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
|
||||
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
|
||||
per_image: compute the loss per image instead of per batch
|
||||
ignore: void class id
|
||||
"""
|
||||
if per_image:
|
||||
loss = mean(
|
||||
_lovasz_hinge_flat(
|
||||
*_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
|
||||
)
|
||||
for log, lab in zip(logits, labels)
|
||||
)
|
||||
else:
|
||||
loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
|
||||
return loss
|
||||
|
||||
|
||||
def _lovasz_hinge_flat(logits, labels):
|
||||
"""Binary Lovasz hinge loss
|
||||
Args:
|
||||
logits: [P] Logits at each prediction (between -infinity and +infinity)
|
||||
labels: [P] Tensor, binary ground truth labels (0 or 1)
|
||||
"""
|
||||
if len(labels) == 0:
|
||||
# only void pixels, the gradients should be 0
|
||||
return logits.sum() * 0.0
|
||||
signs = 2.0 * labels.float() - 1.0
|
||||
errors = 1.0 - logits * signs
|
||||
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
|
||||
perm = perm.data
|
||||
gt_sorted = labels[perm]
|
||||
grad = _lovasz_grad(gt_sorted)
|
||||
loss = torch.dot(F.relu(errors_sorted), grad)
|
||||
return loss
|
||||
|
||||
|
||||
def _flatten_binary_scores(scores, labels, ignore=None):
|
||||
"""Flattens predictions in the batch (binary case)
|
||||
Remove labels equal to 'ignore'
|
||||
"""
|
||||
scores = scores.view(-1)
|
||||
labels = labels.view(-1)
|
||||
if ignore is None:
|
||||
return scores, labels
|
||||
valid = labels != ignore
|
||||
vscores = scores[valid]
|
||||
vlabels = labels[valid]
|
||||
return vscores, vlabels
|
||||
|
||||
|
||||
def _lovasz_softmax(
|
||||
probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
|
||||
):
|
||||
"""Multi-class Lovasz-Softmax loss
|
||||
Args:
|
||||
@param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
|
||||
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
|
||||
@param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
|
||||
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
|
||||
@param per_image: compute the loss per image instead of per batch
|
||||
@param ignore: void class labels
|
||||
"""
|
||||
if per_image:
|
||||
loss = mean(
|
||||
_lovasz_softmax_flat(
|
||||
*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
|
||||
classes=classes
|
||||
)
|
||||
for prob, lab in zip(probas, labels)
|
||||
)
|
||||
else:
|
||||
loss = _lovasz_softmax_flat(
|
||||
*_flatten_probas(probas, labels, ignore),
|
||||
classes=classes,
|
||||
class_seen=class_seen
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
|
||||
"""Multi-class Lovasz-Softmax loss
|
||||
Args:
|
||||
@param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
|
||||
@param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
|
||||
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
|
||||
"""
|
||||
if probas.numel() == 0:
|
||||
# only void pixels, the gradients should be 0
|
||||
return probas * 0.0
|
||||
C = probas.size(1)
|
||||
losses = []
|
||||
class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
|
||||
# for c in class_to_sum:
|
||||
for c in labels.unique():
|
||||
if class_seen is None:
|
||||
fg = (labels == c).type_as(probas) # foreground for class c
|
||||
if classes == "present" and fg.sum() == 0:
|
||||
continue
|
||||
if C == 1:
|
||||
if len(classes) > 1:
|
||||
raise ValueError("Sigmoid output possible only with 1 class")
|
||||
class_pred = probas[:, 0]
|
||||
else:
|
||||
class_pred = probas[:, c]
|
||||
errors = (fg - class_pred).abs()
|
||||
errors_sorted, perm = torch.sort(errors, 0, descending=True)
|
||||
perm = perm.data
|
||||
fg_sorted = fg[perm]
|
||||
losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
|
||||
else:
|
||||
if c in class_seen:
|
||||
fg = (labels == c).type_as(probas) # foreground for class c
|
||||
if classes == "present" and fg.sum() == 0:
|
||||
continue
|
||||
if C == 1:
|
||||
if len(classes) > 1:
|
||||
raise ValueError("Sigmoid output possible only with 1 class")
|
||||
class_pred = probas[:, 0]
|
||||
else:
|
||||
class_pred = probas[:, c]
|
||||
errors = (fg - class_pred).abs()
|
||||
errors_sorted, perm = torch.sort(errors, 0, descending=True)
|
||||
perm = perm.data
|
||||
fg_sorted = fg[perm]
|
||||
losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
|
||||
return mean(losses)
|
||||
|
||||
|
||||
def _flatten_probas(probas, labels, ignore=None):
|
||||
"""Flattens predictions in the batch"""
|
||||
if probas.dim() == 3:
|
||||
# assumes output of a sigmoid layer
|
||||
B, H, W = probas.size()
|
||||
probas = probas.view(B, 1, H, W)
|
||||
|
||||
C = probas.size(1)
|
||||
probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
|
||||
probas = probas.contiguous().view(-1, C) # [P, C]
|
||||
|
||||
labels = labels.view(-1)
|
||||
if ignore is None:
|
||||
return probas, labels
|
||||
valid = labels != ignore
|
||||
vprobas = probas[valid]
|
||||
vlabels = labels[valid]
|
||||
return vprobas, vlabels
|
||||
|
||||
|
||||
def isnan(x):
|
||||
return x != x
|
||||
|
||||
|
||||
def mean(values, ignore_nan=False, empty=0):
|
||||
"""Nan-mean compatible with generators."""
|
||||
values = iter(values)
|
||||
if ignore_nan:
|
||||
values = filterfalse(isnan, values)
|
||||
try:
|
||||
n = 1
|
||||
acc = next(values)
|
||||
except StopIteration:
|
||||
if empty == "raise":
|
||||
raise ValueError("Empty mean")
|
||||
return empty
|
||||
for n, v in enumerate(values, 2):
|
||||
acc += v
|
||||
if n == 1:
|
||||
return acc
|
||||
return acc / n
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class LovaszLoss(_Loss):
|
||||
def __init__(
|
||||
self,
|
||||
mode: str,
|
||||
class_seen: Optional[int] = None,
|
||||
per_image: bool = False,
|
||||
ignore_index: Optional[int] = None,
|
||||
loss_weight: float = 1.0,
|
||||
):
|
||||
"""Lovasz loss for segmentation task.
|
||||
It supports binary, multiclass and multilabel cases
|
||||
Args:
|
||||
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
|
||||
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
|
||||
per_image: If True loss computed per each image and then averaged, else computed per whole batch
|
||||
Shape
|
||||
- **y_pred** - torch.Tensor of shape (N, C, H, W)
|
||||
- **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
|
||||
Reference
|
||||
https://github.com/BloodAxe/pytorch-toolbelt
|
||||
"""
|
||||
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.ignore_index = ignore_index
|
||||
self.per_image = per_image
|
||||
self.class_seen = class_seen
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
def forward(self, y_pred, y_true):
|
||||
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
|
||||
loss = _lovasz_hinge(
|
||||
y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
|
||||
)
|
||||
elif self.mode == MULTICLASS_MODE:
|
||||
y_pred = y_pred.softmax(dim=1)
|
||||
loss = _lovasz_softmax(
|
||||
y_pred,
|
||||
y_true,
|
||||
class_seen=self.class_seen,
|
||||
per_image=self.per_image,
|
||||
ignore=self.ignore_index,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Wrong mode {}.".format(self.mode))
|
||||
return loss * self.loss_weight
|
||||
241
models/losses/misc.py
Normal file
241
models/losses/misc.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
The code is base on https://github.com/Pointcept/Pointcept
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .builder import LOSSES
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight=None,
|
||||
size_average=None,
|
||||
reduce=None,
|
||||
reduction="mean",
|
||||
label_smoothing=0.0,
|
||||
loss_weight=1.0,
|
||||
ignore_index=-1,
|
||||
):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
weight = torch.tensor(weight).cuda() if weight is not None else None
|
||||
self.loss_weight = loss_weight
|
||||
self.loss = nn.CrossEntropyLoss(
|
||||
weight=weight,
|
||||
size_average=size_average,
|
||||
ignore_index=ignore_index,
|
||||
reduce=reduce,
|
||||
reduction=reduction,
|
||||
label_smoothing=label_smoothing,
|
||||
)
|
||||
|
||||
def forward(self, pred, target):
|
||||
return self.loss(pred, target) * self.loss_weight
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class L1Loss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight=None,
|
||||
size_average=None,
|
||||
reduce=None,
|
||||
reduction="mean",
|
||||
label_smoothing=0.0,
|
||||
loss_weight=1.0,
|
||||
ignore_index=-1,
|
||||
):
|
||||
super(L1Loss, self).__init__()
|
||||
weight = torch.tensor(weight).cuda() if weight is not None else None
|
||||
self.loss_weight = loss_weight
|
||||
self.loss = nn.L1Loss(reduction='mean')
|
||||
|
||||
def forward(self, pred, target):
|
||||
return self.loss(pred, target[:,None]) * self.loss_weight
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class SmoothCELoss(nn.Module):
|
||||
def __init__(self, smoothing_ratio=0.1):
|
||||
super(SmoothCELoss, self).__init__()
|
||||
self.smoothing_ratio = smoothing_ratio
|
||||
|
||||
def forward(self, pred, target):
|
||||
eps = self.smoothing_ratio
|
||||
n_class = pred.size(1)
|
||||
one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||
log_prb = F.log_softmax(pred, dim=1)
|
||||
loss = -(one_hot * log_prb).total(dim=1)
|
||||
loss = loss[torch.isfinite(loss)].mean()
|
||||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class BinaryFocalLoss(nn.Module):
|
||||
def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
|
||||
"""Binary Focal Loss
|
||||
<https://arxiv.org/abs/1708.02002>`
|
||||
"""
|
||||
super(BinaryFocalLoss, self).__init__()
|
||||
assert 0 < alpha < 1
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
self.logits = logits
|
||||
self.reduce = reduce
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
def forward(self, pred, target, **kwargs):
|
||||
"""Forward function.
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N)
|
||||
target (torch.Tensor): The ground truth. If containing class
|
||||
indices, shape (N) where each value is 0≤targets[i]≤1, If containing class probabilities,
|
||||
same shape as the input.
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
if self.logits:
|
||||
bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
|
||||
else:
|
||||
bce = F.binary_cross_entropy(pred, target, reduction="none")
|
||||
pt = torch.exp(-bce)
|
||||
alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
|
||||
focal_loss = alpha * (1 - pt) ** self.gamma * bce
|
||||
|
||||
if self.reduce:
|
||||
focal_loss = torch.mean(focal_loss)
|
||||
return focal_loss * self.loss_weight
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class FocalLoss(nn.Module):
|
||||
def __init__(
|
||||
self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
|
||||
):
|
||||
"""Focal Loss
|
||||
<https://arxiv.org/abs/1708.02002>`
|
||||
"""
|
||||
super(FocalLoss, self).__init__()
|
||||
assert reduction in (
|
||||
"mean",
|
||||
"sum",
|
||||
), "AssertionError: reduction should be 'mean' or 'sum'"
|
||||
assert isinstance(
|
||||
alpha, (float, list)
|
||||
), "AssertionError: alpha should be of type float"
|
||||
assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
|
||||
assert isinstance(
|
||||
loss_weight, float
|
||||
), "AssertionError: loss_weight should be of type float"
|
||||
assert isinstance(ignore_index, int), "ignore_index must be of type int"
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self, pred, target, **kwargs):
|
||||
"""Forward function.
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes.
|
||||
target (torch.Tensor): The ground truth. If containing class
|
||||
indices, shape (N) where each value is 0≤targets[i]≤C−1, If containing class probabilities,
|
||||
same shape as the input.
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
|
||||
pred = pred.transpose(0, 1)
|
||||
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
|
||||
pred = pred.reshape(pred.size(0), -1)
|
||||
# [C, N] -> [N, C]
|
||||
pred = pred.transpose(0, 1).contiguous()
|
||||
# (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
|
||||
target = target.view(-1).contiguous()
|
||||
assert pred.size(0) == target.size(
|
||||
0
|
||||
), "The shape of pred doesn't match the shape of target"
|
||||
valid_mask = target != self.ignore_index
|
||||
target = target[valid_mask]
|
||||
pred = pred[valid_mask]
|
||||
|
||||
if len(target) == 0:
|
||||
return 0.0
|
||||
|
||||
num_classes = pred.size(1)
|
||||
target = F.one_hot(target, num_classes=num_classes)
|
||||
|
||||
alpha = self.alpha
|
||||
if isinstance(alpha, list):
|
||||
alpha = pred.new_tensor(alpha)
|
||||
pred_sigmoid = pred.sigmoid()
|
||||
target = target.type_as(pred)
|
||||
one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
|
||||
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
|
||||
self.gamma
|
||||
)
|
||||
|
||||
loss = (
|
||||
F.binary_cross_entropy_with_logits(pred, target, reduction="none")
|
||||
* focal_weight
|
||||
)
|
||||
if self.reduction == "mean":
|
||||
loss = loss.mean()
|
||||
elif self.reduction == "sum":
|
||||
loss = loss.total()
|
||||
return self.loss_weight * loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class DiceLoss(nn.Module):
|
||||
def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
|
||||
"""DiceLoss.
|
||||
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
|
||||
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
|
||||
"""
|
||||
super(DiceLoss, self).__init__()
|
||||
self.smooth = smooth
|
||||
self.exponent = exponent
|
||||
self.loss_weight = loss_weight
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self, pred, target, **kwargs):
|
||||
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
|
||||
pred = pred.transpose(0, 1)
|
||||
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
|
||||
pred = pred.reshape(pred.size(0), -1)
|
||||
# [C, N] -> [N, C]
|
||||
pred = pred.transpose(0, 1).contiguous()
|
||||
# (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
|
||||
target = target.view(-1).contiguous()
|
||||
assert pred.size(0) == target.size(
|
||||
0
|
||||
), "The shape of pred doesn't match the shape of target"
|
||||
valid_mask = target != self.ignore_index
|
||||
target = target[valid_mask]
|
||||
pred = pred[valid_mask]
|
||||
|
||||
pred = F.softmax(pred, dim=1)
|
||||
num_classes = pred.shape[1]
|
||||
target = F.one_hot(
|
||||
torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
|
||||
)
|
||||
|
||||
total_loss = 0
|
||||
for i in range(num_classes):
|
||||
if i != self.ignore_index:
|
||||
num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
|
||||
den = (
|
||||
torch.sum(
|
||||
pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
|
||||
)
|
||||
+ self.smooth
|
||||
)
|
||||
dice_loss = 1 - num / den
|
||||
total_loss += dice_loss
|
||||
loss = total_loss / num_classes
|
||||
return self.loss_weight * loss
|
||||
646
models/network.py
Normal file
646
models/network.py
Normal file
@@ -0,0 +1,646 @@
|
||||
import math
|
||||
import os.path
|
||||
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio as ta
|
||||
|
||||
from models.encoder.wav2vec import Wav2Vec2Model
|
||||
from models.encoder.wavlm import WavLMModel
|
||||
|
||||
from models.builder import MODELS
|
||||
|
||||
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
|
||||
|
||||
@MODELS.register_module("Audio2Expression")
|
||||
class Audio2Expression(nn.Module):
|
||||
def __init__(self,
|
||||
device: torch.device = None,
|
||||
pretrained_encoder_type: str = 'wav2vec',
|
||||
pretrained_encoder_path: str = '',
|
||||
wav2vec2_config_path: str = '',
|
||||
num_identity_classes: int = 0,
|
||||
identity_feat_dim: int = 64,
|
||||
hidden_dim: int = 512,
|
||||
expression_dim: int = 52,
|
||||
norm_type: str = 'ln',
|
||||
decoder_depth: int = 3,
|
||||
use_transformer: bool = False,
|
||||
num_attention_heads: int = 8,
|
||||
num_transformer_layers: int = 6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.device = device
|
||||
|
||||
# Initialize audio feature encoder
|
||||
if pretrained_encoder_type == 'wav2vec':
|
||||
if os.path.exists(pretrained_encoder_path):
|
||||
self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path)
|
||||
else:
|
||||
config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path)
|
||||
self.audio_encoder = Wav2Vec2Model(config)
|
||||
encoder_output_dim = 768
|
||||
elif pretrained_encoder_type == 'wavlm':
|
||||
self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path)
|
||||
encoder_output_dim = 768
|
||||
else:
|
||||
raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported")
|
||||
|
||||
self.audio_encoder.feature_extractor._freeze_parameters()
|
||||
self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim)
|
||||
|
||||
self.identity_encoder = AudioIdentityEncoder(
|
||||
hidden_dim,
|
||||
num_identity_classes,
|
||||
identity_feat_dim,
|
||||
use_transformer,
|
||||
num_attention_heads,
|
||||
num_transformer_layers
|
||||
)
|
||||
|
||||
self.decoder = nn.ModuleList([
|
||||
nn.Sequential(*[
|
||||
ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type)
|
||||
for _ in range(decoder_depth)
|
||||
])
|
||||
])
|
||||
|
||||
self.output_proj = nn.Linear(hidden_dim, expression_dim)
|
||||
|
||||
def freeze_encoder_parameters(self, do_freeze=False):
|
||||
|
||||
for name, param in self.audio_encoder.named_parameters():
|
||||
if('feature_extractor' in name):
|
||||
param.requires_grad = False
|
||||
else:
|
||||
param.requires_grad = (not do_freeze)
|
||||
|
||||
def forward(self, input_dict):
|
||||
|
||||
if 'time_steps' not in input_dict:
|
||||
audio_length = input_dict['input_audio_array'].shape[1]
|
||||
time_steps = math.ceil(audio_length / 16000 * 30)
|
||||
else:
|
||||
time_steps = input_dict['time_steps']
|
||||
|
||||
# Process audio through encoder
|
||||
audio_input = input_dict['input_audio_array'].flatten(start_dim=1)
|
||||
hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state
|
||||
|
||||
# Project features to hidden dimension
|
||||
audio_features = self.feature_projection(hidden_states).transpose(1, 2)
|
||||
|
||||
# Process identity-conditioned features
|
||||
audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx'])
|
||||
|
||||
# Refine features through decoder
|
||||
audio_features = self.decoder[0](audio_features)
|
||||
|
||||
# Generate output parameters
|
||||
audio_features = audio_features.permute(0, 2, 1)
|
||||
expression_params = self.output_proj(audio_features)
|
||||
|
||||
return torch.sigmoid(expression_params)
|
||||
|
||||
|
||||
class AudioIdentityEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_dim,
|
||||
num_identity_classes=0,
|
||||
identity_feat_dim=64,
|
||||
use_transformer=False,
|
||||
num_attention_heads = 8,
|
||||
num_transformer_layers = 6,
|
||||
dropout_ratio=0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
in_dim = hidden_dim + identity_feat_dim
|
||||
self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1)
|
||||
self.first_net = SeqTranslator1D(in_dim, hidden_dim,
|
||||
min_layers_num=3,
|
||||
residual=True,
|
||||
norm='ln'
|
||||
)
|
||||
self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
|
||||
self.dropout = nn.Dropout(dropout_ratio)
|
||||
|
||||
self.use_transformer = use_transformer
|
||||
if(self.use_transformer):
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_attention_heads, dim_feedforward= 2 * hidden_dim, batch_first=True)
|
||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
|
||||
|
||||
def forward(self,
|
||||
audio_features: torch.Tensor,
|
||||
identity: torch.Tensor = None,
|
||||
time_steps: int = None) -> tuple:
|
||||
|
||||
audio_features = self.dropout(audio_features)
|
||||
identity = identity.reshape(identity.shape[0], -1, 1).repeat(1, 1, audio_features.shape[2]).to(torch.float32)
|
||||
identity = self.id_mlp(identity)
|
||||
audio_features = torch.cat([audio_features, identity], dim=1)
|
||||
|
||||
x = self.first_net(audio_features)
|
||||
|
||||
if time_steps is not None:
|
||||
x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear')
|
||||
|
||||
if(self.use_transformer):
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.transformer_encoder(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
return x
|
||||
|
||||
class ConvNormRelu(nn.Module):
|
||||
'''
|
||||
(B,C_in,H,W) -> (B, C_out, H, W)
|
||||
there exist some kernel size that makes the result is not H/s
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
type='1d',
|
||||
leaky=False,
|
||||
downsample=False,
|
||||
kernel_size=None,
|
||||
stride=None,
|
||||
padding=None,
|
||||
p=0,
|
||||
groups=1,
|
||||
residual=False,
|
||||
norm='bn'):
|
||||
'''
|
||||
conv-bn-relu
|
||||
'''
|
||||
super(ConvNormRelu, self).__init__()
|
||||
self.residual = residual
|
||||
self.norm_type = norm
|
||||
# kernel_size = k
|
||||
# stride = s
|
||||
|
||||
if kernel_size is None and stride is None:
|
||||
if not downsample:
|
||||
kernel_size = 3
|
||||
stride = 1
|
||||
else:
|
||||
kernel_size = 4
|
||||
stride = 2
|
||||
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int) and isinstance(stride, tuple):
|
||||
padding = tuple(int((kernel_size - st) / 2) for st in stride)
|
||||
elif isinstance(kernel_size, tuple) and isinstance(stride, int):
|
||||
padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
|
||||
elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
|
||||
padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
|
||||
else:
|
||||
padding = int((kernel_size - stride) / 2)
|
||||
|
||||
if self.residual:
|
||||
if downsample:
|
||||
if type == '1d':
|
||||
self.residual_layer = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding
|
||||
)
|
||||
)
|
||||
elif type == '2d':
|
||||
self.residual_layer = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding
|
||||
)
|
||||
)
|
||||
else:
|
||||
if in_channels == out_channels:
|
||||
self.residual_layer = nn.Identity()
|
||||
else:
|
||||
if type == '1d':
|
||||
self.residual_layer = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding
|
||||
)
|
||||
)
|
||||
elif type == '2d':
|
||||
self.residual_layer = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding
|
||||
)
|
||||
)
|
||||
|
||||
in_channels = in_channels * groups
|
||||
out_channels = out_channels * groups
|
||||
if type == '1d':
|
||||
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
groups=groups)
|
||||
self.norm = nn.BatchNorm1d(out_channels)
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
elif type == '2d':
|
||||
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
groups=groups)
|
||||
self.norm = nn.BatchNorm2d(out_channels)
|
||||
self.dropout = nn.Dropout2d(p=p)
|
||||
if norm == 'gn':
|
||||
self.norm = nn.GroupNorm(2, out_channels)
|
||||
elif norm == 'ln':
|
||||
self.norm = nn.LayerNorm(out_channels)
|
||||
if leaky:
|
||||
self.relu = nn.LeakyReLU(negative_slope=0.2)
|
||||
else:
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
if self.norm_type == 'ln':
|
||||
out = self.dropout(self.conv(x))
|
||||
out = self.norm(out.transpose(1,2)).transpose(1,2)
|
||||
else:
|
||||
out = self.norm(self.dropout(self.conv(x)))
|
||||
if self.residual:
|
||||
residual = self.residual_layer(x)
|
||||
out += residual
|
||||
return self.relu(out)
|
||||
|
||||
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
|
||||
class SeqTranslator1D(nn.Module):
|
||||
'''
|
||||
(B, C, T)->(B, C_out, T)
|
||||
'''
|
||||
def __init__(self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size=None,
|
||||
stride=None,
|
||||
min_layers_num=None,
|
||||
residual=True,
|
||||
norm='bn'
|
||||
):
|
||||
super(SeqTranslator1D, self).__init__()
|
||||
|
||||
conv_layers = nn.ModuleList([])
|
||||
conv_layers.append(ConvNormRelu(
|
||||
in_channels=C_in,
|
||||
out_channels=C_out,
|
||||
type='1d',
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
residual=residual,
|
||||
norm=norm
|
||||
))
|
||||
self.num_layers = 1
|
||||
if min_layers_num is not None and self.num_layers < min_layers_num:
|
||||
while self.num_layers < min_layers_num:
|
||||
conv_layers.append(ConvNormRelu(
|
||||
in_channels=C_out,
|
||||
out_channels=C_out,
|
||||
type='1d',
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
residual=residual,
|
||||
norm=norm
|
||||
))
|
||||
self.num_layers += 1
|
||||
self.conv_layers = nn.Sequential(*conv_layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv_layers(x)
|
||||
|
||||
|
||||
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
|
||||
"""
|
||||
:param audio: 1 x T tensor containing a 16kHz audio signal
|
||||
:param frame_rate: frame rate for video (we need one audio chunk per video frame)
|
||||
:param chunk_size: number of audio samples per chunk
|
||||
:return: num_chunks x chunk_size tensor containing sliced audio
|
||||
"""
|
||||
samples_per_frame = 16000 // frame_rate
|
||||
padding = (chunk_size - samples_per_frame) // 2
|
||||
audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
|
||||
anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
|
||||
audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
|
||||
return audio
|
||||
|
||||
""" https://github.com/facebookresearch/meshtalk """
|
||||
class MeshtalkEncoder(nn.Module):
|
||||
def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
|
||||
"""
|
||||
:param latent_dim: size of the latent audio embedding
|
||||
:param model_name: name of the model, used to load and save the model
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.melspec = ta.transforms.MelSpectrogram(
|
||||
sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
|
||||
)
|
||||
|
||||
conv_len = 5
|
||||
self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
|
||||
self.weights_init(self.convert_dimensions)
|
||||
self.receptive_field = conv_len
|
||||
|
||||
convs = []
|
||||
for i in range(6):
|
||||
dilation = 2 * (i % 3 + 1)
|
||||
self.receptive_field += (conv_len - 1) * dilation
|
||||
convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
|
||||
self.weights_init(convs[-1])
|
||||
self.convs = torch.nn.ModuleList(convs)
|
||||
self.code = torch.nn.Linear(128, latent_dim)
|
||||
|
||||
self.apply(lambda x: self.weights_init(x))
|
||||
|
||||
def weights_init(self, m):
|
||||
if isinstance(m, torch.nn.Conv1d):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
try:
|
||||
torch.nn.init.constant_(m.bias, .01)
|
||||
except:
|
||||
pass
|
||||
|
||||
def forward(self, audio: torch.Tensor):
|
||||
"""
|
||||
:param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
|
||||
:return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
|
||||
"""
|
||||
B, T = audio.shape[0], audio.shape[1]
|
||||
x = self.melspec(audio).squeeze(1)
|
||||
x = torch.log(x.clamp(min=1e-10, max=None))
|
||||
if T == 1:
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
# Convert to the right dimensionality
|
||||
x = x.view(-1, x.shape[2], x.shape[3])
|
||||
x = F.leaky_relu(self.convert_dimensions(x), .2)
|
||||
|
||||
# Process stacks
|
||||
for conv in self.convs:
|
||||
x_ = F.leaky_relu(conv(x), .2)
|
||||
if self.training:
|
||||
x_ = F.dropout(x_, .2)
|
||||
l = (x.shape[2] - x_.shape[2]) // 2
|
||||
x = (x[:, :, l:-l] + x_) / 2
|
||||
|
||||
x = torch.mean(x, dim=-1)
|
||||
x = x.view(B, T, x.shape[-1])
|
||||
x = self.code(x)
|
||||
|
||||
return {"code": x}
|
||||
|
||||
class PeriodicPositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64):
|
||||
super(PeriodicPositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
pe = torch.zeros(period, d_model)
|
||||
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0) # (1, period, d_model)
|
||||
repeat_num = (max_seq_len//period) + 1
|
||||
pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model)
|
||||
self.register_buffer('pe', pe)
|
||||
def forward(self, x):
|
||||
# print(self.pe.shape, x.shape)
|
||||
x = x + self.pe[:, :x.size(1), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class GeneratorTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
n_poses,
|
||||
each_dim: list,
|
||||
dim_list: list,
|
||||
training=True,
|
||||
device=None,
|
||||
identity=False,
|
||||
num_classes=0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.training = training
|
||||
self.device = device
|
||||
self.gen_length = n_poses
|
||||
|
||||
norm = 'ln'
|
||||
in_dim = 256
|
||||
out_dim = 256
|
||||
|
||||
self.encoder_choice = 'faceformer'
|
||||
|
||||
self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
|
||||
self.audio_encoder.feature_extractor._freeze_parameters()
|
||||
self.audio_feature_map = nn.Linear(768, in_dim)
|
||||
|
||||
self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes)
|
||||
|
||||
self.dim_list = dim_list
|
||||
|
||||
self.decoder = nn.ModuleList()
|
||||
self.final_out = nn.ModuleList()
|
||||
|
||||
self.hidden_size = 768
|
||||
self.transformer_de_layer = nn.TransformerDecoderLayer(
|
||||
d_model=self.hidden_size,
|
||||
nhead=4,
|
||||
dim_feedforward=self.hidden_size*2,
|
||||
batch_first=True
|
||||
)
|
||||
self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)
|
||||
self.feature2face = nn.Linear(256, self.hidden_size)
|
||||
|
||||
self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64)
|
||||
self.id_maping = nn.Linear(12,self.hidden_size)
|
||||
|
||||
|
||||
self.decoder.append(self.face_decoder)
|
||||
self.final_out.append(nn.Linear(self.hidden_size, 32))
|
||||
|
||||
def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
|
||||
if gt_poses is None:
|
||||
time_steps = 64
|
||||
else:
|
||||
time_steps = gt_poses.shape[1]
|
||||
|
||||
# vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
|
||||
if self.encoder_choice == 'meshtalk':
|
||||
in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
|
||||
feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
|
||||
elif self.encoder_choice == 'faceformer':
|
||||
hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
|
||||
feature = self.audio_feature_map(hidden_states).transpose(1, 2)
|
||||
else:
|
||||
feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
|
||||
|
||||
feature, _ = self.audio_middle(feature, id=None)
|
||||
feature = self.feature2face(feature.permute(0,2,1))
|
||||
|
||||
id = id.unsqueeze(1).repeat(1,64,1).to(torch.float32)
|
||||
id_feature = self.id_maping(id)
|
||||
id_feature = self.position_embeddings(id_feature)
|
||||
|
||||
for i in range(self.decoder.__len__()):
|
||||
mid = self.decoder[i](tgt=id_feature, memory=feature)
|
||||
out = self.final_out[i](mid)
|
||||
|
||||
return out, None
|
||||
|
||||
def linear_interpolation(features, output_len: int):
|
||||
features = features.transpose(1, 2)
|
||||
output_features = F.interpolate(
|
||||
features, size=output_len, align_corners=True, mode='linear')
|
||||
return output_features.transpose(1, 2)
|
||||
|
||||
def init_biased_mask(n_head, max_seq_len, period):
|
||||
|
||||
def get_slopes(n):
|
||||
|
||||
def get_slopes_power_of_2(n):
|
||||
start = (2**(-2**-(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(n))
|
||||
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
|
||||
2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
|
||||
|
||||
slopes = torch.Tensor(get_slopes(n_head))
|
||||
bias = torch.div(
|
||||
torch.arange(start=0, end=max_seq_len,
|
||||
step=period).unsqueeze(1).repeat(1, period).view(-1),
|
||||
period,
|
||||
rounding_mode='floor')
|
||||
bias = -torch.flip(bias, dims=[0])
|
||||
alibi = torch.zeros(max_seq_len, max_seq_len)
|
||||
for i in range(max_seq_len):
|
||||
alibi[i, :i + 1] = bias[-(i + 1):]
|
||||
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
|
||||
mask = (torch.triu(torch.ones(max_seq_len,
|
||||
max_seq_len)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
|
||||
mask == 1, float(0.0))
|
||||
mask = mask.unsqueeze(0) + alibi
|
||||
return mask
|
||||
|
||||
|
||||
# Alignment Bias
|
||||
def enc_dec_mask(device, T, S):
|
||||
mask = torch.ones(T, S)
|
||||
for i in range(T):
|
||||
mask[i, i] = 0
|
||||
return (mask == 1).to(device=device)
|
||||
|
||||
|
||||
# Periodic Positional Encoding
|
||||
class PeriodicPositionalEncoding(nn.Module):
|
||||
|
||||
def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
|
||||
super(PeriodicPositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
pe = torch.zeros(period, d_model)
|
||||
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, d_model, 2).float() *
|
||||
(-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0) # (1, period, d_model)
|
||||
repeat_num = (max_seq_len // period) + 1
|
||||
pe = pe.repeat(1, repeat_num, 1)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[:, :x.size(1), :]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
"""Base class for all models."""
|
||||
|
||||
def __init__(self):
|
||||
super(BaseModel, self).__init__()
|
||||
# self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
def forward(self, *x):
|
||||
"""Forward pass logic.
|
||||
|
||||
:return: Model output
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def freeze_model(self, do_freeze: bool = True):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = (not do_freeze)
|
||||
|
||||
def summary(self, logger, writer=None):
|
||||
"""Model summary."""
|
||||
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
params = sum([np.prod(p.size())
|
||||
for p in model_parameters]) / 1e6 # Unit is Mega
|
||||
logger.info('===>Trainable parameters: %.3f M' % params)
|
||||
if writer is not None:
|
||||
writer.add_text('Model Summary',
|
||||
'Trainable parameters: %.3f M' % params)
|
||||
|
||||
|
||||
"""https://github.com/X-niper/UniTalker"""
|
||||
class UniTalkerDecoderTransformer(BaseModel):
|
||||
|
||||
def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None:
|
||||
super().__init__()
|
||||
self.learnable_style_emb = nn.Embedding(identity_num, out_dim)
|
||||
self.PPE = PeriodicPositionalEncoding(
|
||||
out_dim, period=period, max_seq_len=3000)
|
||||
self.biased_mask = init_biased_mask(
|
||||
n_head=4, max_seq_len=3000, period=period)
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=out_dim,
|
||||
nhead=4,
|
||||
dim_feedforward=2 * out_dim,
|
||||
batch_first=True)
|
||||
self.transformer_decoder = nn.TransformerDecoder(
|
||||
decoder_layer, num_layers=1)
|
||||
self.interpolate_pos = interpolate_pos
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
|
||||
frame_num: int):
|
||||
style_idx = torch.argmax(style_idx, dim=1)
|
||||
obj_embedding = self.learnable_style_emb(style_idx)
|
||||
obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
|
||||
style_input = self.PPE(obj_embedding)
|
||||
tgt_mask = self.biased_mask.repeat(style_idx.shape[0], 1, 1)[:, :style_input.shape[1], :style_input.
|
||||
shape[1]].clone().detach().to(
|
||||
device=style_input.device)
|
||||
memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
|
||||
frame_num)
|
||||
feat_out = self.transformer_decoder(
|
||||
style_input,
|
||||
hidden_states,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask)
|
||||
if self.interpolate_pos == 2:
|
||||
feat_out = linear_interpolation(feat_out, output_len=frame_num)
|
||||
return feat_out
|
||||
752
models/utils.py
Normal file
752
models/utils.py
Normal file
@@ -0,0 +1,752 @@
import json
import time
import warnings
import numpy as np
from typing import List, Optional, Tuple
from scipy.signal import savgol_filter


ARKitLeftRightPair = [
    ("jawLeft", "jawRight"),
    ("mouthLeft", "mouthRight"),
    ("mouthSmileLeft", "mouthSmileRight"),
    ("mouthFrownLeft", "mouthFrownRight"),
    ("mouthDimpleLeft", "mouthDimpleRight"),
    ("mouthStretchLeft", "mouthStretchRight"),
    ("mouthPressLeft", "mouthPressRight"),
    ("mouthLowerDownLeft", "mouthLowerDownRight"),
    ("mouthUpperUpLeft", "mouthUpperUpRight"),
    ("cheekSquintLeft", "cheekSquintRight"),
    ("noseSneerLeft", "noseSneerRight"),
    ("browDownLeft", "browDownRight"),
    ("browOuterUpLeft", "browOuterUpRight"),
    ("eyeBlinkLeft", "eyeBlinkRight"),
    ("eyeLookDownLeft", "eyeLookDownRight"),
    ("eyeLookInLeft", "eyeLookInRight"),
    ("eyeLookOutLeft", "eyeLookOutRight"),
    ("eyeLookUpLeft", "eyeLookUpRight"),
    ("eyeSquintLeft", "eyeSquintRight"),
    ("eyeWideLeft", "eyeWideRight")
]

ARKitBlendShape = [
    "browDownLeft",
    "browDownRight",
    "browInnerUp",
    "browOuterUpLeft",
    "browOuterUpRight",
    "cheekPuff",
    "cheekSquintLeft",
    "cheekSquintRight",
    "eyeBlinkLeft",
    "eyeBlinkRight",
    "eyeLookDownLeft",
    "eyeLookDownRight",
    "eyeLookInLeft",
    "eyeLookInRight",
    "eyeLookOutLeft",
    "eyeLookOutRight",
    "eyeLookUpLeft",
    "eyeLookUpRight",
    "eyeSquintLeft",
    "eyeSquintRight",
    "eyeWideLeft",
    "eyeWideRight",
    "jawForward",
    "jawLeft",
    "jawOpen",
    "jawRight",
    "mouthClose",
    "mouthDimpleLeft",
    "mouthDimpleRight",
    "mouthFrownLeft",
    "mouthFrownRight",
    "mouthFunnel",
    "mouthLeft",
    "mouthLowerDownLeft",
    "mouthLowerDownRight",
    "mouthPressLeft",
    "mouthPressRight",
    "mouthPucker",
    "mouthRight",
    "mouthRollLower",
    "mouthRollUpper",
    "mouthShrugLower",
    "mouthShrugUpper",
    "mouthSmileLeft",
    "mouthSmileRight",
    "mouthStretchLeft",
    "mouthStretchRight",
    "mouthUpperUpLeft",
    "mouthUpperUpRight",
    "noseSneerLeft",
    "noseSneerRight",
    "tongueOut"
]

MOUTH_BLENDSHAPES = [
    "mouthDimpleLeft",
    "mouthDimpleRight",
    "mouthFrownLeft",
    "mouthFrownRight",
    "mouthFunnel",
    "mouthLeft",
    "mouthLowerDownLeft",
    "mouthLowerDownRight",
    "mouthPressLeft",
    "mouthPressRight",
    "mouthPucker",
    "mouthRight",
    "mouthRollLower",
    "mouthRollUpper",
    "mouthShrugLower",
    "mouthShrugUpper",
    "mouthSmileLeft",
    "mouthSmileRight",
    "mouthStretchLeft",
    "mouthStretchRight",
    "mouthUpperUpLeft",
    "mouthUpperUpRight",
    "jawForward",
    "jawLeft",
    "jawOpen",
    "jawRight",
    "noseSneerLeft",
    "noseSneerRight",
    "cheekPuff",
]

DEFAULT_CONTEXT = {
    'is_initial_input': True,
    'previous_audio': None,
    'previous_expression': None,
    'previous_volume': None,
    'previous_headpose': None,
}

RETURN_CODE = {
    "SUCCESS": 0,
    "AUDIO_LENGTH_ERROR": 1,
    "CHECKPOINT_PATH_ERROR": 2,
    "MODEL_INFERENCE_ERROR": 3,
}

DEFAULT_CONTEXTRETURN = {
    "code": RETURN_CODE['SUCCESS'],
    "expression": None,
    "headpose": None,
}

BLINK_PATTERNS = [
    np.array([0.365, 0.950, 0.956, 0.917, 0.367, 0.119, 0.025]),
    np.array([0.235, 0.910, 0.945, 0.778, 0.191, 0.235, 0.089]),
    np.array([0.870, 0.950, 0.949, 0.696, 0.191, 0.073, 0.007]),
    np.array([0.000, 0.557, 0.953, 0.942, 0.426, 0.148, 0.018])
]


# Postprocess
def symmetrize_blendshapes(
    bs_params: np.ndarray,
    mode: str = "average",
    symmetric_pairs: list = ARKitLeftRightPair
) -> np.ndarray:
    """
    Apply symmetrization to ARKit blendshape parameters (batched version).

    Args:
        bs_params: numpy array of shape (N, 52), batch of ARKit parameters
        mode: symmetrization mode ["average", "max", "min", "left_dominant", "right_dominant"]
        symmetric_pairs: list of left-right parameter pairs

    Returns:
        Symmetrized parameters with same shape (N, 52)
    """
    name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)}

    # Input validation
    if bs_params.ndim != 2 or bs_params.shape[1] != 52:
        raise ValueError("Input must be of shape (N, 52)")

    symmetric_bs = bs_params.copy()  # Shape (N, 52)

    # Precompute valid index pairs
    valid_pairs = []
    for left, right in symmetric_pairs:
        left_idx = name_to_idx.get(left)
        right_idx = name_to_idx.get(right)
        if None not in (left_idx, right_idx):
            valid_pairs.append((left_idx, right_idx))

    # Vectorized processing
    for l_idx, r_idx in valid_pairs:
        left_col = symmetric_bs[:, l_idx]
        right_col = symmetric_bs[:, r_idx]

        if mode == "average":
            new_vals = (left_col + right_col) / 2
        elif mode == "max":
            new_vals = np.maximum(left_col, right_col)
        elif mode == "min":
            new_vals = np.minimum(left_col, right_col)
        elif mode == "left_dominant":
            new_vals = left_col
        elif mode == "right_dominant":
            new_vals = right_col
        else:
            raise ValueError(f"Invalid mode: {mode}")

        # Update both columns simultaneously
        symmetric_bs[:, l_idx] = new_vals
        symmetric_bs[:, r_idx] = new_vals

    return symmetric_bs

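
# Illustrative usage sketch (not part of the original file): average each left/right
# ARKit pair on a hypothetical batch of predicted frames so the two face halves match.
def _example_symmetrize_blendshapes():
    frames = np.random.rand(30, 52).astype(np.float32)  # 30 frames x 52 blendshapes
    return symmetrize_blendshapes(frames, mode="average")
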
def apply_random_eye_blinks(
    input: np.ndarray,
    blink_scale: tuple = (0.8, 1.0),
    blink_interval: tuple = (60, 120),
    blink_duration: int = 7
) -> np.ndarray:
    """
    Apply randomized eye blinks to blendshape parameters.

    Args:
        input: Input array of shape (N, 52) containing blendshape parameters
        blink_scale: Tuple (min, max) for random blink intensity scaling
        blink_interval: Tuple (min, max) for random blink spacing in frames
        blink_duration: Number of frames for blink animation (fixed)

    Returns:
        The input array with blink values written into columns 8 and 9
        (eyeBlinkLeft / eyeBlinkRight); the array is modified in-place.
    """
    # Initialize parameters: clear any existing blink values
    n_frames = input.shape[0]
    input[:, 8:10] = np.zeros((n_frames, 2))
    current_frame = 0

    # Main blink application loop
    while current_frame < n_frames - blink_duration:
        # Randomize blink parameters
        scale = np.random.uniform(*blink_scale)
        pattern = BLINK_PATTERNS[np.random.randint(0, 4)]

        # Apply blink animation
        blink_values = pattern * scale
        input[current_frame:current_frame + blink_duration, 8] = blink_values
        input[current_frame:current_frame + blink_duration, 9] = blink_values

        # Advance to next blink position
        current_frame += blink_duration + np.random.randint(*blink_interval)

    return input

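
# Illustrative usage sketch (not part of the original file): overwrite the blink
# channels of a hypothetical 10-second, 30 fps sequence with randomized blinks.
def _example_apply_random_eye_blinks():
    sequence = np.random.rand(300, 52).astype(np.float32)
    return apply_random_eye_blinks(sequence, blink_scale=(0.8, 1.0), blink_interval=(60, 120))
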
def apply_random_eye_blinks_context(
    animation_params: np.ndarray,
    processed_frames: int = 0,
    intensity_range: tuple = (0.8, 1.0)
) -> np.ndarray:
    """Applies random eye blink patterns to facial animation parameters.

    Args:
        animation_params: Input facial animation parameters array with shape [num_frames, num_features].
            Columns 8 and 9 typically represent left/right eye blink parameters.
        processed_frames: Number of already processed frames that shouldn't be modified
        intensity_range: Tuple defining (min, max) scaling for blink intensity

    Returns:
        Modified animation parameters array with random eye blinks added to unprocessed frames
    """
    remaining_frames = animation_params.shape[0] - processed_frames

    # Only apply blinks if there are enough remaining frames (blink pattern requires 7 frames)
    if remaining_frames <= 7:
        return animation_params

    # Configure blink timing parameters
    min_blink_interval = 40  # Minimum frames between blinks
    max_blink_interval = 100  # Maximum frames between blinks

    # Find last blink in previously processed frames (column 8 > 0.5 indicates blink)
    previous_blink_indices = np.where(animation_params[:processed_frames, 8] > 0.5)[0]
    last_processed_blink = previous_blink_indices[-1] - 7 if previous_blink_indices.size > 0 else processed_frames

    # Calculate first new blink position
    blink_interval = np.random.randint(min_blink_interval, max_blink_interval)
    first_blink_start = max(0, blink_interval - last_processed_blink)

    # Apply first blink if there's enough space
    if first_blink_start <= (remaining_frames - 7):
        # Randomly select blink pattern and intensity
        blink_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
        intensity = np.random.uniform(*intensity_range)

        # Calculate blink frame range
        blink_start = processed_frames + first_blink_start
        blink_end = blink_start + 7

        # Apply pattern to both eyes
        animation_params[blink_start:blink_end, 8] = blink_pattern * intensity
        animation_params[blink_start:blink_end, 9] = blink_pattern * intensity

        # Check space for additional blink
        remaining_after_blink = animation_params.shape[0] - blink_end
        if remaining_after_blink > min_blink_interval:
            # Calculate second blink position
            second_intensity = np.random.uniform(*intensity_range)
            second_interval = np.random.randint(min_blink_interval, max_blink_interval)

            if (remaining_after_blink - 7) > second_interval:
                second_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
                second_blink_start = blink_end + second_interval
                second_blink_end = second_blink_start + 7

                # Apply second blink
                animation_params[second_blink_start:second_blink_end, 8] = second_pattern * second_intensity
                animation_params[second_blink_start:second_blink_end, 9] = second_pattern * second_intensity

    return animation_params

def export_blendshape_animation(
    blendshape_weights: np.ndarray,
    output_path: str,
    blendshape_names: List[str],
    fps: float,
    rotation_data: Optional[np.ndarray] = None
) -> None:
    """
    Export blendshape animation data to JSON format compatible with ARKit.

    Args:
        blendshape_weights: 2D numpy array of shape (N, 52) containing animation frames
        output_path: Full path for output JSON file (including .json extension)
        blendshape_names: Ordered list of 52 ARKit-standard blendshape names
        fps: Frame rate for timing calculations (frames per second)
        rotation_data: Optional 3D rotation data array of shape (N, 3)

    Raises:
        ValueError: If input dimensions are incompatible
        IOError: If file writing fails
    """
    # Validate input dimensions
    if blendshape_weights.shape[1] != 52:
        raise ValueError(f"Expected 52 blendshapes, got {blendshape_weights.shape[1]}")
    if len(blendshape_names) != 52:
        raise ValueError(f"Requires 52 blendshape names, got {len(blendshape_names)}")
    if rotation_data is not None and len(rotation_data) != len(blendshape_weights):
        raise ValueError("Rotation data length must match animation frames")

    # Build animation data structure
    animation_data = {
        "names": blendshape_names,
        "metadata": {
            "fps": fps,
            "frame_count": len(blendshape_weights),
            "blendshape_names": blendshape_names
        },
        "frames": []
    }

    # Convert numpy array to serializable format
    for frame_idx in range(blendshape_weights.shape[0]):
        frame_data = {
            "weights": blendshape_weights[frame_idx].tolist(),
            "time": frame_idx / fps,
            "rotation": rotation_data[frame_idx].tolist() if rotation_data is not None else []
        }
        animation_data["frames"].append(frame_data)

    # Safeguard against data loss
    if not output_path.endswith('.json'):
        output_path += '.json'

    # Write to file with error handling
    try:
        with open(output_path, 'w', encoding='utf-8') as json_file:
            json.dump(animation_data, json_file, indent=2, ensure_ascii=False)
    except Exception as e:
        raise IOError(f"Failed to write animation data: {str(e)}") from e

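
# Illustrative usage sketch (not part of the original file): dump a hypothetical
# 2-second, 30 fps clip to ./demo_animation.json using the module-level ARKit name list.
def _example_export_blendshape_animation():
    weights = np.random.rand(60, 52).astype(np.float32)
    export_blendshape_animation(
        blendshape_weights=weights,
        output_path="./demo_animation.json",
        blendshape_names=ARKitBlendShape,
        fps=30.0)
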
def apply_savitzky_golay_smoothing(
    input_data: np.ndarray,
    window_length: int = 5,
    polyorder: int = 2,
    axis: int = 0,
    validate: bool = True
) -> Tuple[np.ndarray, Optional[float]]:
    """
    Apply Savitzky-Golay filter smoothing along specified axis of input data.

    Args:
        input_data: 2D numpy array of shape (n_samples, n_features)
        window_length: Length of the filter window (must be odd and > polyorder)
        polyorder: Order of the polynomial fit
        axis: Axis along which to filter (0: column-wise, 1: row-wise)
        validate: Enable input validation checks when True

    Returns:
        tuple: (smoothed_data, processing_time)
            - smoothed_data: Smoothed output array, clipped to [0, 1]
            - processing_time: Execution time of the filtering step in seconds

    Raises:
        ValueError: For invalid input dimensions or filter parameters
    """
    processing_time = None

    if validate:
        # Input integrity checks
        if input_data.ndim != 2:
            raise ValueError(f"Expected 2D input, got {input_data.ndim}D array")

        if window_length % 2 == 0 or window_length < 3:
            raise ValueError("Window length must be odd integer ≥ 3")

        if polyorder >= window_length:
            raise ValueError("Polynomial order must be < window length")

    # Store original dtype and convert to float64 for numerical stability
    original_dtype = input_data.dtype
    working_data = input_data.astype(np.float64)

    # Start performance timer
    timer_start = time.perf_counter()

    try:
        # Vectorized Savitzky-Golay application
        smoothed_data = savgol_filter(working_data,
                                      window_length=window_length,
                                      polyorder=polyorder,
                                      axis=axis,
                                      mode='mirror')
    except Exception as e:
        raise RuntimeError(f"Filtering failed: {str(e)}") from e

    # Stop timer and calculate duration
    processing_time = time.perf_counter() - timer_start

    # Clip to the valid blendshape range and restore the original data type
    return (
        np.clip(smoothed_data, 0.0, 1.0).astype(original_dtype),
        processing_time
    )

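
# Illustrative usage sketch (not part of the original file): smooth each blendshape
# channel over time (axis 0) with a 5-frame window before export.
def _example_apply_savitzky_golay_smoothing():
    noisy = np.random.rand(120, 52).astype(np.float32)
    smoothed, elapsed = apply_savitzky_golay_smoothing(noisy, window_length=5, polyorder=2)
    return smoothed, elapsed
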
def _blend_region_start(
    array: np.ndarray,
    region: np.ndarray,
    processed_boundary: int,
    blend_frames: int
) -> None:
    """Applies linear blend between last active frame and silent region start."""
    blend_length = min(blend_frames, region[0] - processed_boundary)
    if blend_length <= 0:
        return

    pre_frame = array[region[0] - 1]
    for i in range(blend_length):
        weight = (i + 1) / (blend_length + 1)
        array[region[0] + i] = pre_frame * (1 - weight) + array[region[0] + i] * weight


def _blend_region_end(
    array: np.ndarray,
    region: np.ndarray,
    blend_frames: int
) -> None:
    """Applies linear blend between silent region end and next active frame."""
    blend_length = min(blend_frames, array.shape[0] - region[-1] - 1)
    if blend_length <= 0:
        return

    post_frame = array[region[-1] + 1]
    for i in range(blend_length):
        weight = (i + 1) / (blend_length + 1)
        array[region[-1] - i] = post_frame * (1 - weight) + array[region[-1] - i] * weight

def find_low_value_regions(
    signal: np.ndarray,
    threshold: float,
    min_region_length: int = 5
) -> list:
    """Identifies contiguous regions in a signal where values fall below a threshold.

    Args:
        signal: Input 1D array of numerical values
        threshold: Value threshold for identifying low regions
        min_region_length: Minimum consecutive samples required to qualify as a region

    Returns:
        List of numpy arrays, each containing indices for a qualifying low-value region
    """
    low_value_indices = np.where(signal < threshold)[0]
    contiguous_regions = []
    if low_value_indices.size == 0:
        return contiguous_regions

    current_region_length = 1
    region_start_idx = 0

    for i in range(1, len(low_value_indices)):
        # Check if current index continues a consecutive sequence
        if low_value_indices[i] != low_value_indices[i - 1] + 1:
            # Finalize previous region if it meets length requirement
            if current_region_length >= min_region_length:
                contiguous_regions.append(low_value_indices[region_start_idx:i])
            # Reset tracking for new potential region
            region_start_idx = i
            current_region_length = 0
        current_region_length += 1

    # Add the final region if it qualifies
    if current_region_length >= min_region_length:
        contiguous_regions.append(low_value_indices[region_start_idx:])

    return contiguous_regions

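
# Illustrative usage sketch (not part of the original file): find stretches of at
# least 7 near-silent frames in a hypothetical per-frame volume curve.
def _example_find_low_value_regions():
    volume = np.ones(300) * 0.05
    volume[100:130] = 0.0  # a 30-frame silent stretch
    return find_low_value_regions(volume, threshold=0.001, min_region_length=7)
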
def smooth_mouth_movements(
    blend_shapes: np.ndarray,
    processed_frames: int,
    volume: Optional[np.ndarray] = None,
    silence_threshold: float = 0.001,
    min_silence_duration: int = 7,
    blend_window: int = 3
) -> np.ndarray:
    """Reduces jaw movement artifacts during silent periods in audio-driven animation.

    Args:
        blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
        processed_frames: Number of already processed frames that shouldn't be modified
        volume: Audio volume array used to detect silent periods
        silence_threshold: Volume threshold for considering a frame silent
        min_silence_duration: Minimum consecutive silent frames to qualify for processing
        blend_window: Number of frames to smooth at region boundaries

    Returns:
        Modified blend shape array with reduced mouth movements during silence
    """
    if volume is None:
        return blend_shapes

    # Detect silence periods using volume data
    silent_regions = find_low_value_regions(
        volume,
        threshold=silence_threshold,
        min_region_length=min_silence_duration
    )

    # Column indices of the mouth/jaw blend shapes to attenuate
    mouth_blend_indices = [ARKitBlendShape.index(name) for name in MOUTH_BLENDSHAPES]

    for region_indices in silent_regions:
        # Reduce mouth blend shapes in silent region
        for region_indice in region_indices.tolist():
            blend_shapes[region_indice, mouth_blend_indices] *= 0.1

        try:
            # Smooth transition into silent region
            _blend_region_start(
                blend_shapes,
                region_indices,
                processed_frames,
                blend_window
            )

            # Smooth transition out of silent region
            _blend_region_end(
                blend_shapes,
                region_indices,
                blend_window
            )
        except IndexError as e:
            warnings.warn(f"Edge blending skipped at region {region_indices}: {str(e)}")

    return blend_shapes

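
# Illustrative usage sketch (not part of the original file): damp mouth/jaw channels
# wherever a hypothetical volume curve stays below the silence threshold.
def _example_smooth_mouth_movements():
    frames = np.random.rand(300, 52).astype(np.float32)
    volume = np.ones(300) * 0.05
    volume[100:130] = 0.0  # silent stretch whose mouth motion should be suppressed
    return smooth_mouth_movements(frames, processed_frames=0, volume=volume)
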
def apply_frame_blending(
    blend_shapes: np.ndarray,
    processed_frames: int,
    initial_blend_window: int = 3,
    subsequent_blend_window: int = 5
) -> np.ndarray:
    """Smooths transitions between processed and unprocessed animation frames using linear blending.

    Args:
        blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
        processed_frames: Number of already processed frames (0 means no previous processing)
        initial_blend_window: Max frames to blend at sequence start
        subsequent_blend_window: Max frames to blend between processed and new frames

    Returns:
        Modified blend shape array with smoothed transitions
    """
    if processed_frames > 0:
        # Blend transition between existing and new animation
        _blend_animation_segment(
            blend_shapes,
            transition_start=processed_frames,
            blend_window=subsequent_blend_window,
            reference_frame=blend_shapes[processed_frames - 1]
        )
    else:
        # Smooth initial frames from neutral expression (zeros)
        _blend_animation_segment(
            blend_shapes,
            transition_start=0,
            blend_window=initial_blend_window,
            reference_frame=np.zeros_like(blend_shapes[0])
        )
    return blend_shapes

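
# Illustrative usage sketch (not part of the original file): in streaming inference,
# blend the first few frames of a new chunk toward the last frame already emitted.
def _example_apply_frame_blending():
    frames = np.random.rand(300, 52).astype(np.float32)
    return apply_frame_blending(frames, processed_frames=150)
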
def _blend_animation_segment(
    array: np.ndarray,
    transition_start: int,
    blend_window: int,
    reference_frame: np.ndarray
) -> None:
    """Applies linear interpolation between reference frame and target frames.

    Args:
        array: Blend shape array to modify
        transition_start: Starting index for blending
        blend_window: Maximum number of frames to blend
        reference_frame: The reference frame to blend from
    """
    actual_blend_length = min(blend_window, array.shape[0] - transition_start)

    for frame_offset in range(actual_blend_length):
        current_idx = transition_start + frame_offset
        blend_weight = (frame_offset + 1) / (actual_blend_length + 1)

        # Linear interpolation: ref_frame * (1 - weight) + current_frame * weight
        array[current_idx] = (reference_frame * (1 - blend_weight)
                              + array[current_idx] * blend_weight)


BROW1 = np.array([[0.05597309, 0.05727929, 0.07995935, 0.        , 0.        ],
                  [0.00757574, 0.00936678, 0.12242376, 0.        , 0.        ],
                  [0.        , 0.        , 0.14943372, 0.04535687, 0.04264118],
                  [0.        , 0.        , 0.18015374, 0.09019445, 0.08736137],
                  [0.        , 0.        , 0.20549579, 0.12802747, 0.12450772],
                  [0.        , 0.        , 0.21098022, 0.1369939 , 0.13343132],
                  [0.        , 0.        , 0.20904602, 0.13903855, 0.13562402],
                  [0.        , 0.        , 0.20365039, 0.13977394, 0.13653506],
                  [0.        , 0.        , 0.19714841, 0.14096624, 0.13805152],
                  [0.        , 0.        , 0.20325482, 0.17303431, 0.17028868],
                  [0.        , 0.        , 0.21990852, 0.20164253, 0.19818163],
                  [0.        , 0.        , 0.23858181, 0.21908803, 0.21540019],
                  [0.        , 0.        , 0.2567876 , 0.23762083, 0.23396946],
                  [0.        , 0.        , 0.34093422, 0.27898848, 0.27651772],
                  [0.        , 0.        , 0.45288125, 0.35008961, 0.34887788],
                  [0.        , 0.        , 0.48076251, 0.36878952, 0.36778417],
                  [0.        , 0.        , 0.47798249, 0.36362219, 0.36145973],
                  [0.        , 0.        , 0.46186113, 0.33865979, 0.33597934],
                  [0.        , 0.        , 0.45264384, 0.33152157, 0.32891783],
                  [0.        , 0.        , 0.40986338, 0.29646468, 0.2945672 ],
                  [0.        , 0.        , 0.35628179, 0.23356403, 0.23155804],
                  [0.        , 0.        , 0.30870566, 0.1780673 , 0.17637439],
                  [0.        , 0.        , 0.25293985, 0.10710219, 0.10622486],
                  [0.        , 0.        , 0.18743332, 0.03252602, 0.03244236],
                  [0.02340254, 0.02364671, 0.15736724, 0.        , 0.        ]])

BROW2 = np.array([
    [0.        , 0.        , 0.09799323, 0.05944436, 0.05002545],
    [0.        , 0.        , 0.09780276, 0.07674237, 0.01636653],
    [0.        , 0.        , 0.11136199, 0.1027964 , 0.04249811],
    [0.        , 0.        , 0.26883412, 0.15861984, 0.15832305],
    [0.        , 0.        , 0.42191629, 0.27038204, 0.27007768],
    [0.        , 0.        , 0.3404977 , 0.21633868, 0.21597538],
    [0.        , 0.        , 0.27301185, 0.17176409, 0.17134669],
    [0.        , 0.        , 0.25960442, 0.15670464, 0.15622253],
    [0.        , 0.        , 0.22877269, 0.11805892, 0.11754539],
    [0.        , 0.        , 0.1451605 , 0.06389034, 0.0636282 ]])

BROW3 = np.array([
    [0.        , 0.        , 0.124 , 0.0295, 0.0295],
    [0.        , 0.        , 0.267 , 0.184 , 0.184 ],
    [0.        , 0.        , 0.359 , 0.2765, 0.2765],
    [0.        , 0.        , 0.3945, 0.3125, 0.3125],
    [0.        , 0.        , 0.4125, 0.331 , 0.331 ],
    [0.        , 0.        , 0.4235, 0.3445, 0.3445],
    [0.        , 0.        , 0.4085, 0.3305, 0.3305],
    [0.        , 0.        , 0.3695, 0.294 , 0.294 ],
    [0.        , 0.        , 0.2835, 0.213 , 0.213 ],
    [0.        , 0.        , 0.1795, 0.1005, 0.1005],
    [0.        , 0.        , 0.108 , 0.014 , 0.014 ]])


from scipy.ndimage import label


def apply_random_brow_movement(input_exp, volume):
    """Superimposes random brow-raise animations onto the first five expression
    channels during loud speech segments and returns the result clipped to [0, 1]."""
    FRAME_SEGMENT = 150
    HOLD_THRESHOLD = 10
    VOLUME_THRESHOLD = 0.08
    MIN_REGION_LENGTH = 6
    STRENGTH_RANGE = (0.7, 1.3)

    BROW_PEAKS = {
        0: np.argmax(BROW1[:, 2]),
        1: np.argmax(BROW2[:, 2])
    }

    for seg_start in range(0, len(volume), FRAME_SEGMENT):
        seg_end = min(seg_start + FRAME_SEGMENT, len(volume))
        seg_volume = volume[seg_start:seg_end]

        candidate_regions = []

        # Label contiguous runs of frames whose volume exceeds the threshold
        high_vol_mask = seg_volume > VOLUME_THRESHOLD
        labeled_array, num_features = label(high_vol_mask)

        for i in range(1, num_features + 1):
            region = (labeled_array == i)
            region_indices = np.where(region)[0]
            if len(region_indices) >= MIN_REGION_LENGTH:
                candidate_regions.append(region_indices)

        if candidate_regions:
            selected_region = candidate_regions[np.random.choice(len(candidate_regions))]
            region_start = selected_region[0]
            region_end = selected_region[-1]
            region_length = region_end - region_start + 1

            brow_idx = np.random.randint(0, 2)
            base_brow = BROW1 if brow_idx == 0 else BROW2
            peak_idx = BROW_PEAKS[brow_idx]

            if region_length > HOLD_THRESHOLD:
                # Long region: rise to the brow peak, then hold until the region ends
                local_max_pos = seg_volume[selected_region].argmax()
                global_peak_frame = seg_start + selected_region[local_max_pos]

                rise_anim = base_brow[:peak_idx + 1]
                hold_frame = base_brow[peak_idx:peak_idx + 1]

                insert_start = max(global_peak_frame - peak_idx, seg_start)
                insert_end = min(global_peak_frame + (region_length - local_max_pos), seg_end)

                strength = np.random.uniform(*STRENGTH_RANGE)

                if insert_start + len(rise_anim) <= seg_end:
                    input_exp[insert_start:insert_start + len(rise_anim), :5] += rise_anim * strength
                    hold_duration = insert_end - (insert_start + len(rise_anim))
                    if hold_duration > 0:
                        input_exp[insert_start + len(rise_anim):insert_end, :5] += np.tile(hold_frame * strength,
                                                                                           (hold_duration, 1))
            else:
                # Short region: center the full brow animation inside the region
                anim_length = base_brow.shape[0]
                insert_pos = seg_start + region_start + (region_length - anim_length) // 2
                insert_pos = max(seg_start, min(insert_pos, seg_end - anim_length))

                if insert_pos + anim_length <= seg_end:
                    strength = np.random.uniform(*STRENGTH_RANGE)
                    input_exp[insert_pos:insert_pos + anim_length, :5] += base_brow * strength

    return np.clip(input_exp, 0, 1)

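
# Illustrative usage sketch (not part of the original file): add occasional brow
# raises to a hypothetical expression sequence, driven by its per-frame volume.
def _example_apply_random_brow_movement():
    expressions = np.random.rand(300, 52).astype(np.float32) * 0.2
    volume = np.zeros(300)
    volume[50:80] = 0.2  # a loud stretch that should trigger a brow raise
    return apply_random_brow_movement(expressions, volume)
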