"""
|
|
The code is base on https://github.com/Pointcept/Pointcept
|
|
"""

from typing import Optional
from itertools import filterfalse

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from .builder import LOSSES

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"


def _lovasz_grad(gt_sorted):
    """Compute gradient of the Lovasz extension w.r.t sorted errors.

    See Alg. 1 in the Lovasz-Softmax paper.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
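

# Worked example (added for illustration, not part of the original module):
# for an already-sorted binary ground-truth vector, the Lovasz gradient above
# is the first-difference of the cumulative Jaccard errors, e.g.
#
#     >>> _lovasz_grad(torch.tensor([1, 1, 0, 1]))
#     tensor([0.3333, 0.3333, 0.0833, 0.2500])
#
# Here gts = 3, the cumulative intersection is [2, 1, 1, 0], the cumulative
# union is [3, 3, 4, 4], so the Jaccard errors are [1/3, 2/3, 3/4, 1] and the
# returned gradient is their difference vector (values shown rounded).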


def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """Binary Lovasz hinge loss.

    Args:
        logits: [B, H, W] logits at each pixel (between -infinity and +infinity)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id
    """
    if per_image:
        loss = mean(
            _lovasz_hinge_flat(
                *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
            )
            for log, lab in zip(logits, labels)
        )
    else:
        loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
    return loss
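

# Usage sketch (added for illustration, not part of the original module): the
# hinge variant expects raw, unbounded scores and {0, 1} masks. A minimal,
# assumed call with random tensors might look like:
#
#     >>> logits = torch.randn(2, 8, 8)             # [B, H, W] unbounded scores
#     >>> masks = torch.randint(0, 2, (2, 8, 8))    # [B, H, W] values in {0, 1}
#     >>> loss = _lovasz_hinge(logits, masks, per_image=True)   # scalar tensor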


def _lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss.

    Args:
        logits: [P] logits at each prediction (between -infinity and +infinity)
        labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.0
    signs = 2.0 * labels.float() - 1.0
    errors = 1.0 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = _lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), grad)
    return loss


def _flatten_binary_scores(scores, labels, ignore=None):
    """Flattens predictions in the batch (binary case).

    Removes labels equal to 'ignore'.
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = labels != ignore
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


def _lovasz_softmax(
    probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
):
    """Multi-class Lovasz-Softmax loss.

    Args:
        @param probas: [B, C, H, W] Class probabilities at each prediction (between 0 and 1).
            If given as [B, H, W], it is interpreted as binary (sigmoid) output.
        @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
        @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
        @param class_seen: optional collection of class indices; if given, only these classes contribute to the loss.
        @param per_image: compute the loss per image instead of per batch
        @param ignore: void class labels
    """
    if per_image:
        loss = mean(
            _lovasz_softmax_flat(
                *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                classes=classes,
                class_seen=class_seen,
            )
            for prob, lab in zip(probas, labels)
        )
    else:
        loss = _lovasz_softmax_flat(
            *_flatten_probas(probas, labels, ignore),
            classes=classes,
            class_seen=class_seen,
        )
    return loss


def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        @param probas: [P, C] Class probabilities at each prediction (between 0 and 1)
        @param labels: [P] Tensor, ground truth labels (between 0 and C - 1)
        @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
        @param class_seen: optional collection of class indices; if given, only these classes contribute to the loss.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.0
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
    # for c in class_to_sum:
    for c in labels.unique():
        # skip classes that are not in the optional `class_seen` whitelist
        if class_seen is not None and c not in class_seen:
            continue
        fg = (labels == c).type_as(probas)  # foreground for class c
        if classes == "present" and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError("Sigmoid output possible only with 1 class")
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
    return mean(losses)


def _flatten_probas(probas, labels, ignore=None):
    """Flattens predictions in the batch."""
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)

    C = probas.size(1)
    probas = torch.movedim(probas, 1, -1)  # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
    probas = probas.contiguous().view(-1, C)  # [P, C]

    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
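

# Shape note (added for illustration, not part of the original module): with a
# typical 2D segmentation batch the flattening above behaves roughly as
#
#     probas: [2, 5, 16, 16] -> [512, 5]    # P = B * H * W predictions
#     labels: [2, 16, 16]    -> [512]
#
# and any position whose label equals `ignore` is dropped from both tensors.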


def isnan(x):
    return x != x


def mean(values, ignore_nan=False, empty=0):
    """Nan-mean compatible with generators."""
    values = iter(values)
    if ignore_nan:
        values = filterfalse(isnan, values)
    try:
        n = 1
        acc = next(values)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean")
        return empty
    for n, v in enumerate(values, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


@LOSSES.register_module()
class LovaszLoss(_Loss):
    def __init__(
        self,
        mode: str,
        class_seen: Optional[list] = None,
        per_image: bool = False,
        ignore_index: Optional[int] = None,
        loss_weight: float = 1.0,
    ):
        """Lovasz loss for segmentation tasks.

        It supports binary, multiclass and multilabel cases.

        Args:
            mode: Loss mode, 'binary', 'multiclass' or 'multilabel'
            class_seen: Optional list of class indices; if given, only these classes contribute to the loss
            per_image: If True, the loss is computed for each image and then averaged; otherwise it is computed over the whole batch
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            loss_weight: Scalar weight multiplied onto the returned loss

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.per_image = per_image
        self.class_seen = class_seen
        self.loss_weight = loss_weight

    def forward(self, y_pred, y_true):
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            loss = _lovasz_hinge(
                y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
            )
        elif self.mode == MULTICLASS_MODE:
            y_pred = y_pred.softmax(dim=1)
            loss = _lovasz_softmax(
                y_pred,
                y_true,
                class_seen=self.class_seen,
                per_image=self.per_image,
                ignore=self.ignore_index,
            )
        else:
            raise ValueError("Wrong mode {}.".format(self.mode))
        return loss * self.loss_weight
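

# Usage sketch (added for illustration, not part of the original module): the
# loss is registered in the LOSSES registry above, so it is normally built from
# a config by the surrounding framework; a direct, assumed instantiation for a
# 5-class segmentation head could look like:
#
#     >>> criterion = LovaszLoss(mode="multiclass", ignore_index=255)
#     >>> y_pred = torch.randn(2, 5, 64, 64)            # (N, C, H, W) raw logits
#     >>> y_true = torch.randint(0, 5, (2, 64, 64))     # (N, H, W) class indices
#     >>> loss = criterion(y_pred, y_true)              # scalar tensor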