feat: Initial commit

Author: fdyuandong
Date: 2025-04-17 23:14:24 +08:00
Commit: ca93dd0572
51 changed files with 7904 additions and 0 deletions

models/__init__.py Normal file

@@ -0,0 +1,7 @@
from .builder import build_model
from .default import DefaultEstimator
# Backbones
from .network import Audio2Expression

models/builder.py Normal file

@@ -0,0 +1,13 @@
"""
Modified from https://github.com/Pointcept/Pointcept
"""
from utils.registry import Registry
MODELS = Registry("models")
MODULES = Registry("modules")
def build_model(cfg):
"""Build models."""
return MODELS.build(cfg)
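A usage sketch (illustrative only, assuming the Pointcept-style Registry API behind utils.registry, where the "type" key of a cfg dict names a registered class and the remaining keys become constructor kwargs):
# Illustrative sketch; TinyBackbone is a hypothetical module, not part of this commit.
import torch.nn as nn
from models.builder import MODELS, build_model

@MODELS.register_module()
class TinyBackbone(nn.Module):
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.proj(x)

# "type" selects the registered class; the other keys are passed to __init__.
model = build_model(dict(type="TinyBackbone", hidden_dim=32))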

models/default.py Normal file

@@ -0,0 +1,25 @@
import torch.nn as nn
from models.losses import build_criteria
from .builder import MODELS, build_model
@MODELS.register_module()
class DefaultEstimator(nn.Module):
def __init__(self, backbone=None, criteria=None):
super().__init__()
self.backbone = build_model(backbone)
self.criteria = build_criteria(criteria)
def forward(self, input_dict):
pred_exp = self.backbone(input_dict)
# train
if self.training:
loss = self.criteria(pred_exp, input_dict["gt_exp"])
return dict(loss=loss)
# eval
elif "gt_exp" in input_dict.keys():
loss = self.criteria(pred_exp, input_dict["gt_exp"])
return dict(loss=loss, pred_exp=pred_exp)
# infer
else:
return dict(pred_exp=pred_exp)
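For reference, an illustrative config sketch (the exact keys used by the training configs are assumptions) showing how DefaultEstimator is typically wired and how its three forward branches behave:
# Hypothetical config sketch; values are placeholders.
expected_cfg = dict(
    type="DefaultEstimator",
    backbone=dict(type="Audio2Expression"),           # built via build_model()
    criteria=[dict(type="L1Loss", loss_weight=1.0)],  # built via build_criteria()
)
# Forward behaviour of DefaultEstimator:
#   model.train(); model(batch with "gt_exp")  -> {"loss": ...}
#   model.eval();  model(batch with "gt_exp")  -> {"loss": ..., "pred_exp": ...}
#   model.eval();  model(batch w/o  "gt_exp")  -> {"pred_exp": ...}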

models/encoder/wav2vec.py Normal file

@@ -0,0 +1,248 @@
import numpy as np
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from dataclasses import dataclass
from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.file_utils import ModelOutput
_CONFIG_FOR_DOC = "Wav2Vec2Config"
_HIDDEN_STATES_START_POSITION = 2
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
# linear interpolation layer
def linear_interpolation(features, input_fps, output_fps, output_len=None):
features = features.transpose(1, 2)
seq_len = features.shape[2] / float(input_fps)
if output_len is None:
output_len = int(seq_len * output_fps)
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
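A quick shape check of linear_interpolation, assuming wav2vec features at roughly 50 Hz being aligned to 30 fps video frames:
# Illustrative shape check (requires torch and transformers to be installed).
import torch
from models.encoder.wav2vec import linear_interpolation

feats = torch.randn(2, 100, 768)                      # (B, T, C): 100 frames at ~50 Hz = 2 s of audio
out = linear_interpolation(feats, input_fps=50, output_fps=30)
print(out.shape)                                      # torch.Size([2, 60, 768]): one feature per video frame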
class Wav2Vec2Model(Wav2Vec2Model):
def __init__(self, config):
super().__init__(config)
self.lm_head = nn.Linear(1024, 32)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
frame_num=None
):
self.config.output_attentions = True
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.feature_extractor(input_values)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
if attention_mask is not None:
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
)
attention_mask[
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states = self.feature_projection(hidden_states)[0]
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@dataclass
class SpeechClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class Wav2Vec2ClassificationHead(nn.Module):
"""Head for wav2vec classification task."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.final_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pooling_mode = config.pooling_mode
self.config = config
self.wav2vec2 = Wav2Vec2Model(config)
self.classifier = Wav2Vec2ClassificationHead(config)
self.init_weights()
def freeze_feature_extractor(self):
self.wav2vec2.feature_extractor._freeze_parameters()
def merged_strategy(
self,
hidden_states,
mode="mean"
):
if mode == "mean":
outputs = torch.mean(hidden_states, dim=1)
elif mode == "sum":
outputs = torch.sum(hidden_states, dim=1)
elif mode == "max":
outputs = torch.max(hidden_states, dim=1)[0]
else:
raise Exception(
"The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
return outputs
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
frame_num=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states1 = linear_interpolation(hidden_states, 50, 30, output_len=frame_num)
hidden_states = self.merged_strategy(hidden_states1, mode=self.pooling_mode)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SpeechClassifierOutput(
loss=loss,
logits=logits,
hidden_states=hidden_states1,
attentions=outputs.attentions,
)
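An inference sketch for the overridden Wav2Vec2Model; the checkpoint name follows the one referenced later in models/network.py, and downloading it is an assumption:
# Illustrative only; frame_num resamples the ~50 Hz features to the requested video frame count.
import torch
from models.encoder.wav2vec import Wav2Vec2Model

encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
encoder.eval()
audio = torch.randn(1, 32000)                         # 2 s of 16 kHz audio
with torch.no_grad():
    out = encoder(audio, frame_num=60)                # align to 60 frames (30 fps)
print(out.last_hidden_state.shape)                    # torch.Size([1, 60, 768])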

models/encoder/wavlm.py Normal file

@@ -0,0 +1,87 @@
import numpy as np
import torch
from transformers import WavLMModel
from transformers.modeling_outputs import Wav2Vec2BaseModelOutput
from typing import Optional, Tuple, Union
import torch.nn.functional as F
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(
features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
# the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model # noqa: E501
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class WavLMModel(WavLMModel):
def __init__(self, config):
super().__init__(config)
def _freeze_wav2vec2_parameters(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def forward(
self,
input_values: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
mask_time_indices: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
frame_num=None,
interpolate_pos: int = 0,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose(1, 2)
if interpolate_pos == 0:
extract_features = linear_interpolation(
extract_features, output_len=frame_num)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if interpolate_pos == 1:
hidden_states = linear_interpolation(
hidden_states, output_len=frame_num)
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)

models/losses/__init__.py Normal file

@@ -0,0 +1,4 @@
from .builder import build_criteria
from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss
from .lovasz import LovaszLoss

models/losses/builder.py Normal file

@@ -0,0 +1,28 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from utils.registry import Registry
LOSSES = Registry("losses")
class Criteria(object):
def __init__(self, cfg=None):
self.cfg = cfg if cfg is not None else []
self.criteria = []
for loss_cfg in self.cfg:
self.criteria.append(LOSSES.build(cfg=loss_cfg))
def __call__(self, pred, target):
if len(self.criteria) == 0:
# loss computation occurs in the model
return pred
loss = 0
for c in self.criteria:
loss += c(pred, target)
return loss
def build_criteria(cfg):
return Criteria(cfg)
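A usage sketch for build_criteria (illustrative; loss choices and weights are placeholders): each list entry is a registry cfg dict, and the resulting Criteria sums the individual, already weighted losses.
import torch
from models.losses import build_criteria

criteria = build_criteria([
    dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1),
    dict(type="LovaszLoss", mode="multiclass", loss_weight=0.5),
])
pred = torch.randn(8, 4)                              # (N, C) logits
target = torch.randint(0, 4, (8,))                    # (N,) class indices
loss = criteria(pred, target)                         # sum of the configured losses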

models/losses/lovasz.py Normal file

@@ -0,0 +1,253 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
from typing import Optional
from itertools import filterfalse
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from .builder import LOSSES
BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"
def _lovasz_grad(gt_sorted):
"""Compute gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1.0 - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
"""
Binary Lovasz hinge loss
logits: [B, H, W] Logits at each pixel (between -infinity and +infinity)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
per_image: compute the loss per image instead of per batch
ignore: void class id
"""
if per_image:
loss = mean(
_lovasz_hinge_flat(
*_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
)
for log, lab in zip(logits, labels)
)
else:
loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
return loss
def _lovasz_hinge_flat(logits, labels):
"""Binary Lovasz hinge loss
Args:
logits: [P] Logits at each prediction (between -infinity and +infinity)
labels: [P] Tensor, binary ground truth labels (0 or 1)
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.0
signs = 2.0 * labels.float() - 1.0
errors = 1.0 - logits * signs
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = _lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), grad)
return loss
def _flatten_binary_scores(scores, labels, ignore=None):
"""Flattens predictions in the batch (binary case)
Remove labels equal to 'ignore'
"""
scores = scores.view(-1)
labels = labels.view(-1)
if ignore is None:
return scores, labels
valid = labels != ignore
vscores = scores[valid]
vlabels = labels[valid]
return vscores, vlabels
def _lovasz_softmax(
probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
):
"""Multi-class Lovasz-Softmax loss
Args:
@param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
@param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
@param per_image: compute the loss per image instead of per batch
@param ignore: void class labels
"""
if per_image:
loss = mean(
_lovasz_softmax_flat(
*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
classes=classes
)
for prob, lab in zip(probas, labels)
)
else:
loss = _lovasz_softmax_flat(
*_flatten_probas(probas, labels, ignore),
classes=classes,
class_seen=class_seen
)
return loss
def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
"""Multi-class Lovasz-Softmax loss
Args:
@param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
@param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
@param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if probas.numel() == 0:
# only void pixels, the gradients should be 0
return probas * 0.0
C = probas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
# for c in class_to_sum:
for c in labels.unique():
if class_seen is None:
fg = (labels == c).type_as(probas) # foreground for class c
if classes == "present" and fg.sum() == 0:
continue
if C == 1:
if len(classes) > 1:
raise ValueError("Sigmoid output possible only with 1 class")
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (fg - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
else:
if c in class_seen:
fg = (labels == c).type_as(probas) # foreground for class c
if classes == "present" and fg.sum() == 0:
continue
if C == 1:
if len(classes) > 1:
raise ValueError("Sigmoid output possible only with 1 class")
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (fg - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
return mean(losses)
def _flatten_probas(probas, labels, ignore=None):
"""Flattens predictions in the batch"""
if probas.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probas.size()
probas = probas.view(B, 1, H, W)
C = probas.size(1)
probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
probas = probas.contiguous().view(-1, C) # [P, C]
labels = labels.view(-1)
if ignore is None:
return probas, labels
valid = labels != ignore
vprobas = probas[valid]
vlabels = labels[valid]
return vprobas, vlabels
def isnan(x):
return x != x
def mean(values, ignore_nan=False, empty=0):
"""Nan-mean compatible with generators."""
values = iter(values)
if ignore_nan:
values = filterfalse(isnan, values)
try:
n = 1
acc = next(values)
except StopIteration:
if empty == "raise":
raise ValueError("Empty mean")
return empty
for n, v in enumerate(values, 2):
acc += v
if n == 1:
return acc
return acc / n
@LOSSES.register_module()
class LovaszLoss(_Loss):
def __init__(
self,
mode: str,
class_seen: Optional[int] = None,
per_image: bool = False,
ignore_index: Optional[int] = None,
loss_weight: float = 1.0,
):
"""Lovasz loss for segmentation task.
It supports binary, multiclass and multilabel cases
Args:
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
per_image: If True loss computed per each image and then averaged, else computed per whole batch
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
- **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)
Reference
https://github.com/BloodAxe/pytorch-toolbelt
"""
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__()
self.mode = mode
self.ignore_index = ignore_index
self.per_image = per_image
self.class_seen = class_seen
self.loss_weight = loss_weight
def forward(self, y_pred, y_true):
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
loss = _lovasz_hinge(
y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
)
elif self.mode == MULTICLASS_MODE:
y_pred = y_pred.softmax(dim=1)
loss = _lovasz_softmax(
y_pred,
y_true,
class_seen=self.class_seen,
per_image=self.per_image,
ignore=self.ignore_index,
)
else:
raise ValueError("Wrong mode {}.".format(self.mode))
return loss * self.loss_weight
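A small shape sketch for LovaszLoss in binary mode (values are illustrative):
import torch
from models.losses.lovasz import LovaszLoss

loss_fn = LovaszLoss(mode="binary", per_image=True)
logits = torch.randn(2, 16, 16)                       # (B, H, W) raw scores
labels = torch.randint(0, 2, (2, 16, 16))             # (B, H, W) binary ground truth
print(loss_fn(logits, labels).item())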

models/losses/misc.py Normal file

@@ -0,0 +1,241 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
def __init__(
self,
weight=None,
size_average=None,
reduce=None,
reduction="mean",
label_smoothing=0.0,
loss_weight=1.0,
ignore_index=-1,
):
super(CrossEntropyLoss, self).__init__()
weight = torch.tensor(weight).cuda() if weight is not None else None
self.loss_weight = loss_weight
self.loss = nn.CrossEntropyLoss(
weight=weight,
size_average=size_average,
ignore_index=ignore_index,
reduce=reduce,
reduction=reduction,
label_smoothing=label_smoothing,
)
def forward(self, pred, target):
return self.loss(pred, target) * self.loss_weight
@LOSSES.register_module()
class L1Loss(nn.Module):
def __init__(
self,
weight=None,
size_average=None,
reduce=None,
reduction="mean",
label_smoothing=0.0,
loss_weight=1.0,
ignore_index=-1,
):
super(L1Loss, self).__init__()
weight = torch.tensor(weight).cuda() if weight is not None else None
self.loss_weight = loss_weight
self.loss = nn.L1Loss(reduction='mean')
def forward(self, pred, target):
return self.loss(pred, target[:,None]) * self.loss_weight
@LOSSES.register_module()
class SmoothCELoss(nn.Module):
def __init__(self, smoothing_ratio=0.1):
super(SmoothCELoss, self).__init__()
self.smoothing_ratio = smoothing_ratio
def forward(self, pred, target):
eps = self.smoothing_ratio
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1)
loss = loss[torch.isfinite(loss)].mean()
return loss
@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
"""Binary Focal Loss
<https://arxiv.org/abs/1708.02002>`
"""
super(BinaryFocalLoss, self).__init__()
assert 0 < alpha < 1
self.gamma = gamma
self.alpha = alpha
self.logits = logits
self.reduce = reduce
self.loss_weight = loss_weight
def forward(self, pred, target, **kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction with shape (N)
target (torch.Tensor): The ground truth. If containing class
indices, shape (N) where each value is 0 ≤ targets[i] ≤ 1; if containing class probabilities,
same shape as the input.
Returns:
torch.Tensor: The calculated loss
"""
if self.logits:
bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
else:
bce = F.binary_cross_entropy(pred, target, reduction="none")
pt = torch.exp(-bce)
alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
focal_loss = alpha * (1 - pt) ** self.gamma * bce
if self.reduce:
focal_loss = torch.mean(focal_loss)
return focal_loss * self.loss_weight
@LOSSES.register_module()
class FocalLoss(nn.Module):
def __init__(
self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
):
"""Focal Loss
<https://arxiv.org/abs/1708.02002>`
"""
super(FocalLoss, self).__init__()
assert reduction in (
"mean",
"sum",
), "AssertionError: reduction should be 'mean' or 'sum'"
assert isinstance(
alpha, (float, list)
), "AssertionError: alpha should be of type float"
assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
assert isinstance(
loss_weight, float
), "AssertionError: loss_weight should be of type float"
assert isinstance(ignore_index, int), "ignore_index must be of type int"
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
self.ignore_index = ignore_index
def forward(self, pred, target, **kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction with shape (N, C) where C = number of classes.
target (torch.Tensor): The ground truth. If containing class
indices, shape (N) where each value is 0 ≤ targets[i] ≤ C-1; if containing class probabilities,
same shape as the input.
Returns:
torch.Tensor: The calculated loss
"""
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
pred = pred.transpose(0, 1)
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
pred = pred.reshape(pred.size(0), -1)
# [C, N] -> [N, C]
pred = pred.transpose(0, 1).contiguous()
# (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
target = target.view(-1).contiguous()
assert pred.size(0) == target.size(
0
), "The shape of pred doesn't match the shape of target"
valid_mask = target != self.ignore_index
target = target[valid_mask]
pred = pred[valid_mask]
if len(target) == 0:
return 0.0
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes)
alpha = self.alpha
if isinstance(alpha, list):
alpha = pred.new_tensor(alpha)
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
self.gamma
)
loss = (
F.binary_cross_entropy_with_logits(pred, target, reduction="none")
* focal_weight
)
if self.reduction == "mean":
loss = loss.mean()
elif self.reduction == "sum":
loss = loss.sum()
return self.loss_weight * loss
@LOSSES.register_module()
class DiceLoss(nn.Module):
def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
"""DiceLoss.
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
"""
super(DiceLoss, self).__init__()
self.smooth = smooth
self.exponent = exponent
self.loss_weight = loss_weight
self.ignore_index = ignore_index
def forward(self, pred, target, **kwargs):
# [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
pred = pred.transpose(0, 1)
# [C, B, d_1, d_2, ..., d_k] -> [C, N]
pred = pred.reshape(pred.size(0), -1)
# [C, N] -> [N, C]
pred = pred.transpose(0, 1).contiguous()
# (B, d_1, d_2, ..., d_k) --> (B * d_1 * d_2 * ... * d_k,)
target = target.view(-1).contiguous()
assert pred.size(0) == target.size(
0
), "The shape of pred doesn't match the shape of target"
valid_mask = target != self.ignore_index
target = target[valid_mask]
pred = pred[valid_mask]
pred = F.softmax(pred, dim=1)
num_classes = pred.shape[1]
target = F.one_hot(
torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
)
total_loss = 0
for i in range(num_classes):
if i != self.ignore_index:
num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
den = (
torch.sum(
pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
)
+ self.smooth
)
dice_loss = 1 - num / den
total_loss += dice_loss
loss = total_loss / num_classes
return self.loss_weight * loss
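Any of these losses can also be built directly from the LOSSES registry; a sketch assuming the Pointcept-style build(cfg=...) signature used in builder.py (parameter values are placeholders):
import torch
from models.losses.builder import LOSSES

focal = LOSSES.build(cfg=dict(type="FocalLoss", gamma=2.0, alpha=0.25, loss_weight=1.0))
pred = torch.randn(8, 4)                              # (N, C) logits
target = torch.randint(0, 4, (8,))                    # (N,) class indices
print(focal(pred, target).item())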

models/network.py Normal file

@@ -0,0 +1,646 @@
import math
import os.path
import numpy as np  # used by BaseModel.summary
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
from models.encoder.wav2vec import Wav2Vec2Model
from models.encoder.wavlm import WavLMModel
from models.builder import MODELS
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
@MODELS.register_module("Audio2Expression")
class Audio2Expression(nn.Module):
def __init__(self,
device: torch.device = None,
pretrained_encoder_type: str = 'wav2vec',
pretrained_encoder_path: str = '',
wav2vec2_config_path: str = '',
num_identity_classes: int = 0,
identity_feat_dim: int = 64,
hidden_dim: int = 512,
expression_dim: int = 52,
norm_type: str = 'ln',
decoder_depth: int = 3,
use_transformer: bool = False,
num_attention_heads: int = 8,
num_transformer_layers: int = 6,
):
super().__init__()
self.device = device
# Initialize audio feature encoder
if pretrained_encoder_type == 'wav2vec':
if os.path.exists(pretrained_encoder_path):
self.audio_encoder = Wav2Vec2Model.from_pretrained(pretrained_encoder_path)
else:
config = Wav2Vec2Config.from_pretrained(wav2vec2_config_path)
self.audio_encoder = Wav2Vec2Model(config)
encoder_output_dim = 768
elif pretrained_encoder_type == 'wavlm':
self.audio_encoder = WavLMModel.from_pretrained(pretrained_encoder_path)
encoder_output_dim = 768
else:
raise NotImplementedError(f"Encoder type {pretrained_encoder_type} not supported")
self.audio_encoder.feature_extractor._freeze_parameters()
self.feature_projection = nn.Linear(encoder_output_dim, hidden_dim)
self.identity_encoder = AudioIdentityEncoder(
hidden_dim,
num_identity_classes,
identity_feat_dim,
use_transformer,
num_attention_heads,
num_transformer_layers
)
self.decoder = nn.ModuleList([
nn.Sequential(*[
ConvNormRelu(hidden_dim, hidden_dim, norm=norm_type)
for _ in range(decoder_depth)
])
])
self.output_proj = nn.Linear(hidden_dim, expression_dim)
def freeze_encoder_parameters(self, do_freeze=False):
for name, param in self.audio_encoder.named_parameters():
if('feature_extractor' in name):
param.requires_grad = False
else:
param.requires_grad = (not do_freeze)
def forward(self, input_dict):
if 'time_steps' not in input_dict:
audio_length = input_dict['input_audio_array'].shape[1]
time_steps = math.ceil(audio_length / 16000 * 30)
else:
time_steps = input_dict['time_steps']
# Process audio through encoder
audio_input = input_dict['input_audio_array'].flatten(start_dim=1)
hidden_states = self.audio_encoder(audio_input, frame_num=time_steps).last_hidden_state
# Project features to hidden dimension
audio_features = self.feature_projection(hidden_states).transpose(1, 2)
# Process identity-conditioned features
audio_features = self.identity_encoder(audio_features, identity=input_dict['id_idx'])
# Refine features through decoder
audio_features = self.decoder[0](audio_features)
# Generate output parameters
audio_features = audio_features.permute(0, 2, 1)
expression_params = self.output_proj(audio_features)
return torch.sigmoid(expression_params)
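An end-to-end inference sketch for Audio2Expression; the encoder paths are placeholders and the one-hot identity layout is an assumption based on how identities are used elsewhere in this file:
# Illustrative only; paths must point at a real wav2vec 2.0 checkpoint/config.
import torch
from models.network import Audio2Expression

model = Audio2Expression(
    pretrained_encoder_type="wav2vec",
    pretrained_encoder_path="/path/to/wav2vec2",      # falls back to wav2vec2_config_path if missing
    wav2vec2_config_path="/path/to/wav2vec2",
    num_identity_classes=4,
    expression_dim=52,
)
model.eval()
batch = {
    "input_audio_array": torch.randn(1, 32000),       # 2 s of 16 kHz audio
    "id_idx": torch.nn.functional.one_hot(torch.tensor([0]), 4),
}
with torch.no_grad():
    exp = model(batch)                                # (1, 60, 52) blendshape weights in [0, 1] at 30 fps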
class AudioIdentityEncoder(nn.Module):
def __init__(self,
hidden_dim,
num_identity_classes=0,
identity_feat_dim=64,
use_transformer=False,
num_attention_heads = 8,
num_transformer_layers = 6,
dropout_ratio=0.1,
):
super().__init__()
in_dim = hidden_dim + identity_feat_dim
self.id_mlp = nn.Conv1d(num_identity_classes, identity_feat_dim, 1, 1)
self.first_net = SeqTranslator1D(in_dim, hidden_dim,
min_layers_num=3,
residual=True,
norm='ln'
)
self.grus = nn.GRU(hidden_dim, hidden_dim, 1, batch_first=True)
self.dropout = nn.Dropout(dropout_ratio)
self.use_transformer = use_transformer
if(self.use_transformer):
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_attention_heads, dim_feedforward= 2 * hidden_dim, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
def forward(self,
audio_features: torch.Tensor,
identity: torch.Tensor = None,
time_steps: int = None) -> tuple:
audio_features = self.dropout(audio_features)
identity = identity.reshape(identity.shape[0], -1, 1).repeat(1, 1, audio_features.shape[2]).to(torch.float32)
identity = self.id_mlp(identity)
audio_features = torch.cat([audio_features, identity], dim=1)
x = self.first_net(audio_features)
if time_steps is not None:
x = F.interpolate(x, size=time_steps, align_corners=False, mode='linear')
if(self.use_transformer):
x = x.permute(0, 2, 1)
x = self.transformer_encoder(x)
x = x.permute(0, 2, 1)
return x
class ConvNormRelu(nn.Module):
'''
(B,C_in,H,W) -> (B, C_out, H, W)
note: some kernel sizes do not produce an output size of exactly H/s
'''
def __init__(self,
in_channels,
out_channels,
type='1d',
leaky=False,
downsample=False,
kernel_size=None,
stride=None,
padding=None,
p=0,
groups=1,
residual=False,
norm='bn'):
'''
conv-bn-relu
'''
super(ConvNormRelu, self).__init__()
self.residual = residual
self.norm_type = norm
# kernel_size = k
# stride = s
if kernel_size is None and stride is None:
if not downsample:
kernel_size = 3
stride = 1
else:
kernel_size = 4
stride = 2
if padding is None:
if isinstance(kernel_size, int) and isinstance(stride, tuple):
padding = tuple(int((kernel_size - st) / 2) for st in stride)
elif isinstance(kernel_size, tuple) and isinstance(stride, int):
padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
else:
padding = int((kernel_size - stride) / 2)
if self.residual:
if downsample:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
else:
if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
if type == '1d':
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
elif type == '2d':
self.residual_layer = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
in_channels = in_channels * groups
out_channels = out_channels * groups
if type == '1d':
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
elif type == '2d':
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_channels)
self.dropout = nn.Dropout2d(p=p)
if norm == 'gn':
self.norm = nn.GroupNorm(2, out_channels)
elif norm == 'ln':
self.norm = nn.LayerNorm(out_channels)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, **kwargs):
if self.norm_type == 'ln':
out = self.dropout(self.conv(x))
out = self.norm(out.transpose(1,2)).transpose(1,2)
else:
out = self.norm(self.dropout(self.conv(x)))
if self.residual:
residual = self.residual_layer(x)
out += residual
return self.relu(out)
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
class SeqTranslator1D(nn.Module):
'''
(B, C, T)->(B, C_out, T)
'''
def __init__(self,
C_in,
C_out,
kernel_size=None,
stride=None,
min_layers_num=None,
residual=True,
norm='bn'
):
super(SeqTranslator1D, self).__init__()
conv_layers = nn.ModuleList([])
conv_layers.append(ConvNormRelu(
in_channels=C_in,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm
))
self.num_layers = 1
if min_layers_num is not None and self.num_layers < min_layers_num:
while self.num_layers < min_layers_num:
conv_layers.append(ConvNormRelu(
in_channels=C_out,
out_channels=C_out,
type='1d',
kernel_size=kernel_size,
stride=stride,
residual=residual,
norm=norm
))
self.num_layers += 1
self.conv_layers = nn.Sequential(*conv_layers)
def forward(self, x):
return self.conv_layers(x)
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
"""
:param audio: 1 x T tensor containing a 16kHz audio signal
:param frame_rate: frame rate for video (we need one audio chunk per video frame)
:param chunk_size: number of audio samples per chunk
:return: num_chunks x chunk_size tensor containing sliced audio
"""
samples_per_frame = 16000 // frame_rate
padding = (chunk_size - samples_per_frame) // 2
audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
return audio
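A worked example of audio_chunking: 2 s of 16 kHz audio at 30 fps yields one 1-second chunk centred on each video frame.
import torch
from models.network import audio_chunking

audio = torch.randn(1, 32000)                         # 1 x T tensor, 16 kHz
chunks = audio_chunking(audio, frame_rate=30, chunk_size=16000)
print(chunks.shape)                                   # torch.Size([60, 16000])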
""" https://github.com/facebookresearch/meshtalk """
class MeshtalkEncoder(nn.Module):
def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
"""
:param latent_dim: size of the latent audio embedding
:param model_name: name of the model, used to load and save the model
"""
super().__init__()
self.melspec = ta.transforms.MelSpectrogram(
sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
)
conv_len = 5
self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
self.weights_init(self.convert_dimensions)
self.receptive_field = conv_len
convs = []
for i in range(6):
dilation = 2 * (i % 3 + 1)
self.receptive_field += (conv_len - 1) * dilation
convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
self.weights_init(convs[-1])
self.convs = torch.nn.ModuleList(convs)
self.code = torch.nn.Linear(128, latent_dim)
self.apply(lambda x: self.weights_init(x))
def weights_init(self, m):
if isinstance(m, torch.nn.Conv1d):
torch.nn.init.xavier_uniform_(m.weight)
try:
torch.nn.init.constant_(m.bias, .01)
except:
pass
def forward(self, audio: torch.Tensor):
"""
:param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
:return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
"""
B, T = audio.shape[0], audio.shape[1]
x = self.melspec(audio).squeeze(1)
x = torch.log(x.clamp(min=1e-10, max=None))
if T == 1:
x = x.unsqueeze(1)
# Convert to the right dimensionality
x = x.view(-1, x.shape[2], x.shape[3])
x = F.leaky_relu(self.convert_dimensions(x), .2)
# Process stacks
for conv in self.convs:
x_ = F.leaky_relu(conv(x), .2)
if self.training:
x_ = F.dropout(x_, .2)
l = (x.shape[2] - x_.shape[2]) // 2
x = (x[:, :, l:-l] + x_) / 2
x = torch.mean(x, dim=-1)
x = x.view(B, T, x.shape[-1])
x = self.code(x)
return {"code": x}
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=64):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len//period) + 1
pe = pe.repeat(1, repeat_num, 1) # (1, repeat_num, period, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
# print(self.pe.shape, x.shape)
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class GeneratorTransformer(nn.Module):
def __init__(self,
n_poses,
each_dim: list,
dim_list: list,
training=True,
device=None,
identity=False,
num_classes=0,
):
super().__init__()
self.training = training
self.device = device
self.gen_length = n_poses
norm = 'ln'
in_dim = 256
out_dim = 256
self.encoder_choice = 'faceformer'
self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") # "vitouphy/wav2vec2-xls-r-300m-phoneme""facebook/wav2vec2-base-960h"
self.audio_encoder.feature_extractor._freeze_parameters()
self.audio_feature_map = nn.Linear(768, in_dim)
self.audio_middle = AudioEncoder(in_dim, out_dim, False, num_classes)
self.dim_list = dim_list
self.decoder = nn.ModuleList()
self.final_out = nn.ModuleList()
self.hidden_size = 768
self.transformer_de_layer = nn.TransformerDecoderLayer(
d_model=self.hidden_size,
nhead=4,
dim_feedforward=self.hidden_size*2,
batch_first=True
)
self.face_decoder = nn.TransformerDecoder(self.transformer_de_layer, num_layers=4)
self.feature2face = nn.Linear(256, self.hidden_size)
self.position_embeddings = PeriodicPositionalEncoding(self.hidden_size, period=64, max_seq_len=64)
self.id_maping = nn.Linear(12,self.hidden_size)
self.decoder.append(self.face_decoder)
self.final_out.append(nn.Linear(self.hidden_size, 32))
def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
if gt_poses is None:
time_steps = 64
else:
time_steps = gt_poses.shape[1]
# vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
if self.encoder_choice == 'meshtalk':
in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
elif self.encoder_choice == 'faceformer':
hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1), frame_num=time_steps).last_hidden_state
feature = self.audio_feature_map(hidden_states).transpose(1, 2)
else:
feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
feature, _ = self.audio_middle(feature, id=None)
feature = self.feature2face(feature.permute(0,2,1))
id = id.unsqueeze(1).repeat(1,64,1).to(torch.float32)
id_feature = self.id_maping(id)
id_feature = self.position_embeddings(id_feature)
for i in range(self.decoder.__len__()):
mid = self.decoder[i](tgt=id_feature, memory=feature)
out = self.final_out[i](mid)
return out, None
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(
features, size=output_len, align_corners=True, mode='linear')
return output_features.transpose(1, 2)
def init_biased_mask(n_head, max_seq_len, period):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
slopes = torch.Tensor(get_slopes(n_head))
bias = torch.div(
torch.arange(start=0, end=max_seq_len,
step=period).unsqueeze(1).repeat(1, period).view(-1),
period,
rounding_mode='floor')
bias = -torch.flip(bias, dims=[0])
alibi = torch.zeros(max_seq_len, max_seq_len)
for i in range(max_seq_len):
alibi[i, :i + 1] = bias[-(i + 1):]
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
mask = (torch.triu(torch.ones(max_seq_len,
max_seq_len)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
mask == 1, float(0.0))
mask = mask.unsqueeze(0) + alibi
return mask
# Alignment Bias
def enc_dec_mask(device, T, S):
mask = torch.ones(T, S)
for i in range(T):
mask[i, i] = 0
return (mask == 1).to(device=device)
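A shape sketch for the two masks (sizes are illustrative): the ALiBi-style target mask carries one bias slope per head, while enc_dec_mask exposes only the time-aligned memory frame to each query.
import torch
from models.network import init_biased_mask, enc_dec_mask

tgt_mask = init_biased_mask(n_head=4, max_seq_len=8, period=2)
mem_mask = enc_dec_mask(torch.device("cpu"), T=8, S=8)
print(tgt_mask.shape, mem_mask.shape)                 # torch.Size([4, 8, 8]) torch.Size([8, 8])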
# Periodic Positional Encoding
class PeriodicPositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, period=25, max_seq_len=3000):
super(PeriodicPositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(period, d_model)
position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, period, d_model)
repeat_num = (max_seq_len // period) + 1
pe = pe.repeat(1, repeat_num, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
class BaseModel(nn.Module):
"""Base class for all models."""
def __init__(self):
super(BaseModel, self).__init__()
# self.logger = logging.getLogger(self.__class__.__name__)
def forward(self, *x):
"""Forward pass logic.
:return: Model output
"""
raise NotImplementedError
def freeze_model(self, do_freeze: bool = True):
for param in self.parameters():
param.requires_grad = (not do_freeze)
def summary(self, logger, writer=None):
"""Model summary."""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size())
for p in model_parameters]) / 1e6 # Unit is Mega
logger.info('===>Trainable parameters: %.3f M' % params)
if writer is not None:
writer.add_text('Model Summary',
'Trainable parameters: %.3f M' % params)
"""https://github.com/X-niper/UniTalker"""
class UniTalkerDecoderTransformer(BaseModel):
def __init__(self, out_dim, identity_num, period=30, interpolate_pos=1) -> None:
super().__init__()
self.learnable_style_emb = nn.Embedding(identity_num, out_dim)
self.PPE = PeriodicPositionalEncoding(
out_dim, period=period, max_seq_len=3000)
self.biased_mask = init_biased_mask(
n_head=4, max_seq_len=3000, period=period)
decoder_layer = nn.TransformerDecoderLayer(
d_model=out_dim,
nhead=4,
dim_feedforward=2 * out_dim,
batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer, num_layers=1)
self.interpolate_pos = interpolate_pos
def forward(self, hidden_states: torch.Tensor, style_idx: torch.Tensor,
frame_num: int):
style_idx = torch.argmax(style_idx, dim=1)
obj_embedding = self.learnable_style_emb(style_idx)
obj_embedding = obj_embedding.unsqueeze(1).repeat(1, frame_num, 1)
style_input = self.PPE(obj_embedding)
tgt_mask = self.biased_mask.repeat(style_idx.shape[0], 1, 1)[:, :style_input.shape[1], :style_input.shape[1]]
tgt_mask = tgt_mask.clone().detach().to(device=style_input.device)
memory_mask = enc_dec_mask(hidden_states.device, style_input.shape[1],
frame_num)
feat_out = self.transformer_decoder(
style_input,
hidden_states,
tgt_mask=tgt_mask,
memory_mask=memory_mask)
if self.interpolate_pos == 2:
feat_out = linear_interpolation(feat_out, output_len=frame_num)
return feat_out

models/utils.py Normal file

@@ -0,0 +1,752 @@
import json
import time
import warnings
import numpy as np
from typing import List, Optional, Tuple
from scipy.signal import savgol_filter
ARKitLeftRightPair = [
("jawLeft", "jawRight"),
("mouthLeft", "mouthRight"),
("mouthSmileLeft", "mouthSmileRight"),
("mouthFrownLeft", "mouthFrownRight"),
("mouthDimpleLeft", "mouthDimpleRight"),
("mouthStretchLeft", "mouthStretchRight"),
("mouthPressLeft", "mouthPressRight"),
("mouthLowerDownLeft", "mouthLowerDownRight"),
("mouthUpperUpLeft", "mouthUpperUpRight"),
("cheekSquintLeft", "cheekSquintRight"),
("noseSneerLeft", "noseSneerRight"),
("browDownLeft", "browDownRight"),
("browOuterUpLeft", "browOuterUpRight"),
("eyeBlinkLeft","eyeBlinkRight"),
("eyeLookDownLeft","eyeLookDownRight"),
("eyeLookInLeft", "eyeLookInRight"),
("eyeLookOutLeft","eyeLookOutRight"),
("eyeLookUpLeft","eyeLookUpRight"),
("eyeSquintLeft","eyeSquintRight"),
("eyeWideLeft","eyeWideRight")
]
ARKitBlendShape =[
"browDownLeft",
"browDownRight",
"browInnerUp",
"browOuterUpLeft",
"browOuterUpRight",
"cheekPuff",
"cheekSquintLeft",
"cheekSquintRight",
"eyeBlinkLeft",
"eyeBlinkRight",
"eyeLookDownLeft",
"eyeLookDownRight",
"eyeLookInLeft",
"eyeLookInRight",
"eyeLookOutLeft",
"eyeLookOutRight",
"eyeLookUpLeft",
"eyeLookUpRight",
"eyeSquintLeft",
"eyeSquintRight",
"eyeWideLeft",
"eyeWideRight",
"jawForward",
"jawLeft",
"jawOpen",
"jawRight",
"mouthClose",
"mouthDimpleLeft",
"mouthDimpleRight",
"mouthFrownLeft",
"mouthFrownRight",
"mouthFunnel",
"mouthLeft",
"mouthLowerDownLeft",
"mouthLowerDownRight",
"mouthPressLeft",
"mouthPressRight",
"mouthPucker",
"mouthRight",
"mouthRollLower",
"mouthRollUpper",
"mouthShrugLower",
"mouthShrugUpper",
"mouthSmileLeft",
"mouthSmileRight",
"mouthStretchLeft",
"mouthStretchRight",
"mouthUpperUpLeft",
"mouthUpperUpRight",
"noseSneerLeft",
"noseSneerRight",
"tongueOut"
]
MOUTH_BLENDSHAPES = [ "mouthDimpleLeft",
"mouthDimpleRight",
"mouthFrownLeft",
"mouthFrownRight",
"mouthFunnel",
"mouthLeft",
"mouthLowerDownLeft",
"mouthLowerDownRight",
"mouthPressLeft",
"mouthPressRight",
"mouthPucker",
"mouthRight",
"mouthRollLower",
"mouthRollUpper",
"mouthShrugLower",
"mouthShrugUpper",
"mouthSmileLeft",
"mouthSmileRight",
"mouthStretchLeft",
"mouthStretchRight",
"mouthUpperUpLeft",
"mouthUpperUpRight",
"jawForward",
"jawLeft",
"jawOpen",
"jawRight",
"noseSneerLeft",
"noseSneerRight",
"cheekPuff",
]
DEFAULT_CONTEXT ={
'is_initial_input': True,
'previous_audio': None,
'previous_expression': None,
'previous_volume': None,
'previous_headpose': None,
}
RETURN_CODE = {
"SUCCESS": 0,
"AUDIO_LENGTH_ERROR": 1,
"CHECKPOINT_PATH_ERROR":2,
"MODEL_INFERENCE_ERROR":3,
}
DEFAULT_CONTEXTRETURN = {
"code": RETURN_CODE['SUCCESS'],
"expression": None,
"headpose": None,
}
BLINK_PATTERNS = [
np.array([0.365, 0.950, 0.956, 0.917, 0.367, 0.119, 0.025]),
np.array([0.235, 0.910, 0.945, 0.778, 0.191, 0.235, 0.089]),
np.array([0.870, 0.950, 0.949, 0.696, 0.191, 0.073, 0.007]),
np.array([0.000, 0.557, 0.953, 0.942, 0.426, 0.148, 0.018])
]
# Postprocess
def symmetrize_blendshapes(
bs_params: np.ndarray,
mode: str = "average",
symmetric_pairs: list = ARKitLeftRightPair
) -> np.ndarray:
"""
Apply symmetrization to ARKit blendshape parameters (batched version)
Args:
bs_params: numpy array of shape (N, 52), batch of ARKit parameters
mode: symmetrization mode ["average", "max", "min", "left_dominant", "right_dominant"]
symmetric_pairs: list of left-right parameter pairs
Returns:
Symmetrized parameters with same shape (N, 52)
"""
name_to_idx = {name: i for i, name in enumerate(ARKitBlendShape)}
# Input validation
if bs_params.ndim != 2 or bs_params.shape[1] != 52:
raise ValueError("Input must be of shape (N, 52)")
symmetric_bs = bs_params.copy() # Shape (N, 52)
# Precompute valid index pairs
valid_pairs = []
for left, right in symmetric_pairs:
left_idx = name_to_idx.get(left)
right_idx = name_to_idx.get(right)
if None not in (left_idx, right_idx):
valid_pairs.append((left_idx, right_idx))
# Vectorized processing
for l_idx, r_idx in valid_pairs:
left_col = symmetric_bs[:, l_idx]
right_col = symmetric_bs[:, r_idx]
if mode == "average":
new_vals = (left_col + right_col) / 2
elif mode == "max":
new_vals = np.maximum(left_col, right_col)
elif mode == "min":
new_vals = np.minimum(left_col, right_col)
elif mode == "left_dominant":
new_vals = left_col
elif mode == "right_dominant":
new_vals = right_col
else:
raise ValueError(f"Invalid mode: {mode}")
# Update both columns simultaneously
symmetric_bs[:, l_idx] = new_vals
symmetric_bs[:, r_idx] = new_vals
return symmetric_bs
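A quick usage sketch for symmetrize_blendshapes on a dummy clip:
import numpy as np
from models.utils import symmetrize_blendshapes

frames = np.random.rand(10, 52).astype(np.float32)    # 10 frames of ARKit weights
sym = symmetrize_blendshapes(frames, mode="average")
# Paired channels such as mouthSmileLeft / mouthSmileRight now hold identical values.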
def apply_random_eye_blinks(
input: np.ndarray,
blink_scale: tuple = (0.8, 1.0),
blink_interval: tuple = (60, 120),
blink_duration: int = 7
) -> np.ndarray:
"""
Apply randomized eye blinks to blendshape parameters
Args:
input: Input array of shape (N, 52) containing blendshape parameters
blink_scale: Tuple (min, max) for random blink intensity scaling
blink_interval: Tuple (min, max) for random blink spacing in frames
blink_duration: Number of frames for blink animation (fixed)
Returns:
The modified input array (blink values are also written in-place)
"""
# Define eye blink patterns (normalized 0-1)
# Initialize parameters
n_frames = input.shape[0]
input[:,8:10] = np.zeros((n_frames,2))
current_frame = 0
# Main blink application loop
while current_frame < n_frames - blink_duration:
# Randomize blink parameters
scale = np.random.uniform(*blink_scale)
pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
# Apply blink animation
blink_values = pattern * scale
input[current_frame:current_frame + blink_duration, 8] = blink_values
input[current_frame:current_frame + blink_duration, 9] = blink_values
# Advance to next blink position
current_frame += blink_duration + np.random.randint(*blink_interval)
return input
def apply_random_eye_blinks_context(
animation_params: np.ndarray,
processed_frames: int = 0,
intensity_range: tuple = (0.8, 1.0)
) -> np.ndarray:
"""Applies random eye blink patterns to facial animation parameters.
Args:
animation_params: Input facial animation parameters array with shape [num_frames, num_features].
Columns 8 and 9 typically represent left/right eye blink parameters.
processed_frames: Number of already processed frames that shouldn't be modified
intensity_range: Tuple defining (min, max) scaling for blink intensity
Returns:
Modified animation parameters array with random eye blinks added to unprocessed frames
"""
remaining_frames = animation_params.shape[0] - processed_frames
# Only apply blinks if there are enough remaining frames (a blink pattern requires 7 frames)
if remaining_frames <= 7:
return animation_params
# Configure blink timing parameters
min_blink_interval = 40 # Minimum frames between blinks
max_blink_interval = 100 # Maximum frames between blinks
# Find last blink in previously processed frames (column 8 > 0.5 indicates blink)
previous_blink_indices = np.where(animation_params[:processed_frames, 8] > 0.5)[0]
last_processed_blink = previous_blink_indices[-1] - 7 if previous_blink_indices.size > 0 else processed_frames
# Calculate first new blink position
blink_interval = np.random.randint(min_blink_interval, max_blink_interval)
first_blink_start = max(0, blink_interval - last_processed_blink)
# Apply first blink if there's enough space
if first_blink_start <= (remaining_frames - 7):
# Randomly select blink pattern and intensity
blink_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
intensity = np.random.uniform(*intensity_range)
# Calculate blink frame range
blink_start = processed_frames + first_blink_start
blink_end = blink_start + 7
# Apply pattern to both eyes
animation_params[blink_start:blink_end, 8] = blink_pattern * intensity
animation_params[blink_start:blink_end, 9] = blink_pattern * intensity
# Check space for additional blink
remaining_after_blink = animation_params.shape[0] - blink_end
if remaining_after_blink > min_blink_interval:
# Calculate second blink position
second_intensity = np.random.uniform(*intensity_range)
second_interval = np.random.randint(min_blink_interval, max_blink_interval)
if (remaining_after_blink - 7) > second_interval:
second_pattern = BLINK_PATTERNS[np.random.randint(0, 4)]
second_blink_start = blink_end + second_interval
second_blink_end = second_blink_start + 7
# Apply second blink
animation_params[second_blink_start:second_blink_end, 8] = second_pattern * second_intensity
animation_params[second_blink_start:second_blink_end, 9] = second_pattern * second_intensity
return animation_params
def export_blendshape_animation(
blendshape_weights: np.ndarray,
output_path: str,
blendshape_names: List[str],
fps: float,
rotation_data: Optional[np.ndarray] = None
) -> None:
"""
Export blendshape animation data to JSON format compatible with ARKit.
Args:
blendshape_weights: 2D numpy array of shape (N, 52) containing animation frames
output_path: Full path for output JSON file (including .json extension)
blendshape_names: Ordered list of 52 ARKit-standard blendshape names
fps: Frame rate for timing calculations (frames per second)
rotation_data: Optional 3D rotation data array of shape (N, 3)
Raises:
ValueError: If input dimensions are incompatible
IOError: If file writing fails
"""
# Validate input dimensions
if blendshape_weights.shape[1] != 52:
raise ValueError(f"Expected 52 blendshapes, got {blendshape_weights.shape[1]}")
if len(blendshape_names) != 52:
raise ValueError(f"Requires 52 blendshape names, got {len(blendshape_names)}")
if rotation_data is not None and len(rotation_data) != len(blendshape_weights):
raise ValueError("Rotation data length must match animation frames")
# Build animation data structure
animation_data = {
"names":blendshape_names,
"metadata": {
"fps": fps,
"frame_count": len(blendshape_weights),
"blendshape_names": blendshape_names
},
"frames": []
}
# Convert numpy array to serializable format
for frame_idx in range(blendshape_weights.shape[0]):
frame_data = {
"weights": blendshape_weights[frame_idx].tolist(),
"time": frame_idx / fps,
"rotation": rotation_data[frame_idx].tolist() if rotation_data else []
}
animation_data["frames"].append(frame_data)
# Safeguard against data loss
if not output_path.endswith('.json'):
output_path += '.json'
# Write to file with error handling
try:
with open(output_path, 'w', encoding='utf-8') as json_file:
json.dump(animation_data, json_file, indent=2, ensure_ascii=False)
except Exception as e:
raise IOError(f"Failed to write animation data: {str(e)}") from e
def apply_savitzky_golay_smoothing(
input_data: np.ndarray,
window_length: int = 5,
polyorder: int = 2,
axis: int = 0,
validate: bool = True
) -> Tuple[np.ndarray, Optional[float]]:
"""
Apply Savitzky-Golay filter smoothing along specified axis of input data.
Args:
input_data: 2D numpy array of shape (n_samples, n_features)
window_length: Length of the filter window (must be odd and > polyorder)
polyorder: Order of the polynomial fit
axis: Axis along which to filter (0: column-wise, 1: row-wise)
validate: Enable input validation checks when True
Returns:
tuple: (smoothed_data, processing_time)
- smoothed_data: Smoothed output array
- processing_time: Execution time in seconds (None in validation mode)
Raises:
ValueError: For invalid input dimensions or filter parameters
"""
# Validation mode timing bypass
processing_time = None
if validate:
# Input integrity checks
if input_data.ndim != 2:
raise ValueError(f"Expected 2D input, got {input_data.ndim}D array")
if window_length % 2 == 0 or window_length < 3:
raise ValueError("Window length must be odd integer ≥ 3")
if polyorder >= window_length:
raise ValueError("Polynomial order must be < window length")
# Store original dtype and convert to float64 for numerical stability
original_dtype = input_data.dtype
working_data = input_data.astype(np.float64)
# Start performance timer
timer_start = time.perf_counter()
try:
# Vectorized Savitzky-Golay application
smoothed_data = savgol_filter(working_data,
window_length=window_length,
polyorder=polyorder,
axis=axis,
mode='mirror')
except Exception as e:
raise RuntimeError(f"Filtering failed: {str(e)}") from e
# Stop timer and calculate duration
processing_time = time.perf_counter() - timer_start
# Restore original data type with overflow protection
return (
np.clip(smoothed_data,
0.0,
1.0
).astype(original_dtype),
processing_time
)
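A usage sketch for apply_savitzky_golay_smoothing on noisy blendshape curves:
import numpy as np
from models.utils import apply_savitzky_golay_smoothing

noisy = np.clip(np.random.rand(100, 52), 0.0, 1.0)
smoothed, elapsed = apply_savitzky_golay_smoothing(noisy, window_length=5, polyorder=2)
print(smoothed.shape, elapsed)                        # (100, 52) and the filter runtime in seconds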
def _blend_region_start(
array: np.ndarray,
region: np.ndarray,
processed_boundary: int,
blend_frames: int
) -> None:
"""Applies linear blend between last active frame and silent region start."""
blend_length = min(blend_frames, region[0] - processed_boundary)
if blend_length <= 0:
return
pre_frame = array[region[0] - 1]
for i in range(blend_length):
weight = (i + 1) / (blend_length + 1)
array[region[0] + i] = pre_frame * (1 - weight) + array[region[0] + i] * weight
def _blend_region_end(
array: np.ndarray,
region: np.ndarray,
blend_frames: int
) -> None:
"""Applies linear blend between silent region end and next active frame."""
blend_length = min(blend_frames, array.shape[0] - region[-1] - 1)
if blend_length <= 0:
return
post_frame = array[region[-1] + 1]
for i in range(blend_length):
weight = (i + 1) / (blend_length + 1)
array[region[-1] - i] = post_frame * (1 - weight) + array[region[-1] - i] * weight
def find_low_value_regions(
signal: np.ndarray,
threshold: float,
min_region_length: int = 5
) -> list:
"""Identifies contiguous regions in a signal where values fall below a threshold.
Args:
signal: Input 1D array of numerical values
threshold: Value threshold for identifying low regions
min_region_length: Minimum consecutive samples required to qualify as a region
Returns:
List of numpy arrays, each containing indices for a qualifying low-value region
"""
low_value_indices = np.where(signal < threshold)[0]
    contiguous_regions = []
    if len(low_value_indices) == 0:
        return contiguous_regions
    # The first low-value index always opens a run of length 1
    current_region_length = 1
    region_start_idx = 0
for i in range(1, len(low_value_indices)):
# Check if current index continues a consecutive sequence
if low_value_indices[i] != low_value_indices[i - 1] + 1:
# Finalize previous region if it meets length requirement
if current_region_length >= min_region_length:
contiguous_regions.append(low_value_indices[region_start_idx:i])
# Reset tracking for new potential region
region_start_idx = i
current_region_length = 0
current_region_length += 1
# Add the final region if it qualifies
if current_region_length >= min_region_length:
contiguous_regions.append(low_value_indices[region_start_idx:])
return contiguous_regions
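
# Example (illustrative): with threshold=0.5 and min_region_length=3, the run of
# low values at indices 2..5 qualifies, while the isolated low value at index 8 does not:
#
#   signal = np.array([0.9, 0.8, 0.1, 0.2, 0.1, 0.3, 0.9, 0.8, 0.2, 0.9])
#   find_low_value_regions(signal, threshold=0.5, min_region_length=3)
#   # -> [array([2, 3, 4, 5])]
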
def smooth_mouth_movements(
blend_shapes: np.ndarray,
processed_frames: int,
volume: np.ndarray = None,
silence_threshold: float = 0.001,
min_silence_duration: int = 7,
blend_window: int = 3
) -> np.ndarray:
"""Reduces jaw movement artifacts during silent periods in audio-driven animation.
Args:
blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
processed_frames: Number of already processed frames that shouldn't be modified
volume: Audio volume array used to detect silent periods
silence_threshold: Volume threshold for considering a frame silent
min_silence_duration: Minimum consecutive silent frames to qualify for processing
blend_window: Number of frames to smooth at region boundaries
Returns:
Modified blend shape array with reduced mouth movements during silence
"""
if volume is None:
return blend_shapes
# Detect silence periods using volume data
silent_regions = find_low_value_regions(
volume,
threshold=silence_threshold,
min_region_length=min_silence_duration
)
    # Indices of the mouth-related blend shapes (looked up once, reused for every region)
    mouth_blend_indices = [ARKitBlendShape.index(name) for name in MOUTH_BLENDSHAPES]
    for region_indices in silent_regions:
        # Damp mouth blend shapes inside the silent region
        blend_shapes[np.ix_(region_indices, mouth_blend_indices)] *= 0.1
try:
# Smooth transition into silent region
_blend_region_start(
blend_shapes,
region_indices,
processed_frames,
blend_window
)
# Smooth transition out of silent region
_blend_region_end(
blend_shapes,
region_indices,
blend_window
)
except IndexError as e:
warnings.warn(f"Edge blending skipped at region {region_indices}: {str(e)}")
return blend_shapes
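
# Example usage (a minimal sketch; `blend_shapes` is a (num_frames, 52) weight array
# and `volume` a per-frame loudness array of the same length, both assumed to exist):
#
#   blend_shapes = smooth_mouth_movements(blend_shapes, processed_frames=0, volume=volume)
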
def apply_frame_blending(
blend_shapes: np.ndarray,
processed_frames: int,
initial_blend_window: int = 3,
subsequent_blend_window: int = 5
) -> np.ndarray:
"""Smooths transitions between processed and unprocessed animation frames using linear blending.
Args:
blend_shapes: Array of facial blend shape weights [num_frames, num_blendshapes]
processed_frames: Number of already processed frames (0 means no previous processing)
initial_blend_window: Max frames to blend at sequence start
subsequent_blend_window: Max frames to blend between processed and new frames
Returns:
Modified blend shape array with smoothed transitions
"""
if processed_frames > 0:
# Blend transition between existing and new animation
_blend_animation_segment(
blend_shapes,
transition_start=processed_frames,
blend_window=subsequent_blend_window,
reference_frame=blend_shapes[processed_frames - 1]
)
else:
# Smooth initial frames from neutral expression (zeros)
_blend_animation_segment(
blend_shapes,
transition_start=0,
blend_window=initial_blend_window,
reference_frame=np.zeros_like(blend_shapes[0])
)
return blend_shapes
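
# Example usage (illustrative): blend a freshly predicted chunk against the last frame
# of the previously emitted animation. `prev_frame_count` is a placeholder name for
# however many frames were already processed.
#
#   blend_shapes = apply_frame_blending(blend_shapes, processed_frames=prev_frame_count)
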
def _blend_animation_segment(
array: np.ndarray,
transition_start: int,
blend_window: int,
reference_frame: np.ndarray
) -> None:
"""Applies linear interpolation between reference frame and target frames.
Args:
array: Blend shape array to modify
transition_start: Starting index for blending
blend_window: Maximum number of frames to blend
reference_frame: The reference frame to blend from
"""
actual_blend_length = min(blend_window, array.shape[0] - transition_start)
for frame_offset in range(actual_blend_length):
current_idx = transition_start + frame_offset
blend_weight = (frame_offset + 1) / (actual_blend_length + 1)
# Linear interpolation: ref_frame * (1 - weight) + current_frame * weight
array[current_idx] = (reference_frame * (1 - blend_weight)
+ array[current_idx] * blend_weight)
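
# Pre-baked eyebrow keyframe curves (BROW1-BROW3). BROW1 and BROW2 are sampled by
# apply_random_brow_movement below; each row is one frame of additive weights for the
# first five expression channels (input_exp[:, :5]), assumed here to correspond to the
# brow-related ARKit blendshapes.
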
BROW1 = np.array([[0.05597309, 0.05727929, 0.07995935, 0. , 0. ],
[0.00757574, 0.00936678, 0.12242376, 0. , 0. ],
[0. , 0. , 0.14943372, 0.04535687, 0.04264118],
[0. , 0. , 0.18015374, 0.09019445, 0.08736137],
[0. , 0. , 0.20549579, 0.12802747, 0.12450772],
[0. , 0. , 0.21098022, 0.1369939 , 0.13343132],
[0. , 0. , 0.20904602, 0.13903855, 0.13562402],
[0. , 0. , 0.20365039, 0.13977394, 0.13653506],
[0. , 0. , 0.19714841, 0.14096624, 0.13805152],
[0. , 0. , 0.20325482, 0.17303431, 0.17028868],
[0. , 0. , 0.21990852, 0.20164253, 0.19818163],
[0. , 0. , 0.23858181, 0.21908803, 0.21540019],
[0. , 0. , 0.2567876 , 0.23762083, 0.23396946],
[0. , 0. , 0.34093422, 0.27898848, 0.27651772],
[0. , 0. , 0.45288125, 0.35008961, 0.34887788],
[0. , 0. , 0.48076251, 0.36878952, 0.36778417],
[0. , 0. , 0.47798249, 0.36362219, 0.36145973],
[0. , 0. , 0.46186113, 0.33865979, 0.33597934],
[0. , 0. , 0.45264384, 0.33152157, 0.32891783],
[0. , 0. , 0.40986338, 0.29646468, 0.2945672 ],
[0. , 0. , 0.35628179, 0.23356403, 0.23155804],
[0. , 0. , 0.30870566, 0.1780673 , 0.17637439],
[0. , 0. , 0.25293985, 0.10710219, 0.10622486],
[0. , 0. , 0.18743332, 0.03252602, 0.03244236],
[0.02340254, 0.02364671, 0.15736724, 0. , 0. ]])
BROW2 = np.array([
[0. , 0. , 0.09799323, 0.05944436, 0.05002545],
[0. , 0. , 0.09780276, 0.07674237, 0.01636653],
[0. , 0. , 0.11136199, 0.1027964 , 0.04249811],
[0. , 0. , 0.26883412, 0.15861984, 0.15832305],
[0. , 0. , 0.42191629, 0.27038204, 0.27007768],
[0. , 0. , 0.3404977 , 0.21633868, 0.21597538],
[0. , 0. , 0.27301185, 0.17176409, 0.17134669],
[0. , 0. , 0.25960442, 0.15670464, 0.15622253],
[0. , 0. , 0.22877269, 0.11805892, 0.11754539],
[0. , 0. , 0.1451605 , 0.06389034, 0.0636282 ]])
BROW3 = np.array([
[0. , 0. , 0.124 , 0.0295, 0.0295],
[0. , 0. , 0.267 , 0.184 , 0.184 ],
[0. , 0. , 0.359 , 0.2765, 0.2765],
[0. , 0. , 0.3945, 0.3125, 0.3125],
[0. , 0. , 0.4125, 0.331 , 0.331 ],
[0. , 0. , 0.4235, 0.3445, 0.3445],
[0. , 0. , 0.4085, 0.3305, 0.3305],
[0. , 0. , 0.3695, 0.294 , 0.294 ],
[0. , 0. , 0.2835, 0.213 , 0.213 ],
[0. , 0. , 0.1795, 0.1005, 0.1005],
[0. , 0. , 0.108 , 0.014 , 0.014 ]])
import numpy as np
from scipy.ndimage import label
def apply_random_brow_movement(input_exp, volume):
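    """Injects occasional eyebrow-raise animations during loud speech segments.

    Scans the volume track in 150-frame segments, picks one sufficiently long
    high-volume region per segment at random, and additively blends a pre-baked
    brow curve (BROW1 or BROW2) into the first five expression channels with a
    randomized strength. Regions longer than the hold threshold keep the brow
    at its peak pose for the remainder of the region.

    Args:
        input_exp: Expression weight array [num_frames, num_channels].
        volume: Per-frame audio volume array of the same length.

    Returns:
        The modified expression array, clipped to [0, 1].
    """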
FRAME_SEGMENT = 150
HOLD_THRESHOLD = 10
VOLUME_THRESHOLD = 0.08
MIN_REGION_LENGTH = 6
STRENGTH_RANGE = (0.7, 1.3)
BROW_PEAKS = {
0: np.argmax(BROW1[:, 2]),
1: np.argmax(BROW2[:, 2])
}
for seg_start in range(0, len(volume), FRAME_SEGMENT):
seg_end = min(seg_start + FRAME_SEGMENT, len(volume))
seg_volume = volume[seg_start:seg_end]
candidate_regions = []
high_vol_mask = seg_volume > VOLUME_THRESHOLD
labeled_array, num_features = label(high_vol_mask)
for i in range(1, num_features + 1):
region = (labeled_array == i)
region_indices = np.where(region)[0]
if len(region_indices) >= MIN_REGION_LENGTH:
candidate_regions.append(region_indices)
if candidate_regions:
selected_region = candidate_regions[np.random.choice(len(candidate_regions))]
region_start = selected_region[0]
region_end = selected_region[-1]
region_length = region_end - region_start + 1
brow_idx = np.random.randint(0, 2)
base_brow = BROW1 if brow_idx == 0 else BROW2
peak_idx = BROW_PEAKS[brow_idx]
if region_length > HOLD_THRESHOLD:
local_max_pos = seg_volume[selected_region].argmax()
global_peak_frame = seg_start + selected_region[local_max_pos]
rise_anim = base_brow[:peak_idx + 1]
hold_frame = base_brow[peak_idx:peak_idx + 1]
insert_start = max(global_peak_frame - peak_idx, seg_start)
insert_end = min(global_peak_frame + (region_length - local_max_pos), seg_end)
strength = np.random.uniform(*STRENGTH_RANGE)
if insert_start + len(rise_anim) <= seg_end:
input_exp[insert_start:insert_start + len(rise_anim), :5] += rise_anim * strength
hold_duration = insert_end - (insert_start + len(rise_anim))
if hold_duration > 0:
input_exp[insert_start + len(rise_anim):insert_end, :5] += np.tile(hold_frame * strength,
(hold_duration, 1))
else:
anim_length = base_brow.shape[0]
insert_pos = seg_start + region_start + (region_length - anim_length) // 2
insert_pos = max(seg_start, min(insert_pos, seg_end - anim_length))
if insert_pos + anim_length <= seg_end:
strength = np.random.uniform(*STRENGTH_RANGE)
input_exp[insert_pos:insert_pos + anim_length, :5] += base_brow * strength
return np.clip(input_exp, 0, 1)
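
# Example usage (a minimal sketch; `exp` and `volume` are assumed to be per-frame
# arrays of matching length produced earlier in the pipeline):
#
#   exp = apply_random_brow_movement(exp, volume)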