Mirror of https://github.com/aigc3d/LAM_Audio2Expression.git (synced 2026-02-04 17:39:24 +08:00)
feat: Initial commit
models/losses/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .builder import build_criteria
from .misc import CrossEntropyLoss, SmoothCELoss, DiceLoss, FocalLoss, BinaryFocalLoss, L1Loss
from .lovasz import LovaszLoss
models/losses/builder.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from utils.registry import Registry

LOSSES = Registry("losses")


class Criteria(object):
    def __init__(self, cfg=None):
        self.cfg = cfg if cfg is not None else []
        self.criteria = []
        for loss_cfg in self.cfg:
            self.criteria.append(LOSSES.build(cfg=loss_cfg))

    def __call__(self, pred, target):
        if len(self.criteria) == 0:
            # loss computation occurs inside the model; pass the prediction through
            return pred
        loss = 0
        for c in self.criteria:
            loss += c(pred, target)
        return loss


def build_criteria(cfg):
    return Criteria(cfg)
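A minimal usage sketch (not part of the commit), assuming the Pointcept-style config convention in which each entry passed to `LOSSES.build(cfg=...)` is a dict whose "type" key names a registered loss class; the exact `Registry.build` contract lives in utils/registry.py, which is outside this diff:

    import torch
    from models.losses import build_criteria

    # two weighted criteria summed by Criteria.__call__
    criteria = build_criteria([
        dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1),
        dict(type="LovaszLoss", mode="multiclass", loss_weight=1.0),
    ])

    pred = torch.randn(8, 5)             # [N, C] logits over 5 classes
    target = torch.randint(0, 5, (8,))   # [N] class indices
    loss = criteria(pred, target)        # weighted sum of the configured losses

With an empty cfg list, `criteria(pred, target)` simply returns `pred`, matching the "loss computed inside the model" escape hatch above.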
models/losses/lovasz.py (new file, 253 lines)
@@ -0,0 +1,253 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

from typing import Optional
from itertools import filterfalse
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from .builder import LOSSES

BINARY_MODE: str = "binary"
MULTICLASS_MODE: str = "multiclass"
MULTILABEL_MODE: str = "multilabel"


def _lovasz_grad(gt_sorted):
    """Compute the gradient of the Lovasz extension w.r.t. sorted errors.
    See Alg. 1 in the paper.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # cover the 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """Binary Lovasz hinge loss.

    Args:
        logits: [B, H, W] logits at each pixel (between -infinity and +infinity)
        labels: [B, H, W] tensor, binary ground-truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id
    """
    if per_image:
        loss = mean(
            _lovasz_hinge_flat(
                *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)
            )
            for log, lab in zip(logits, labels)
        )
    else:
        loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))
    return loss


def _lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss.

    Args:
        logits: [P] logits at each prediction (between -infinity and +infinity)
        labels: [P] tensor, binary ground-truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.0
    signs = 2.0 * labels.float() - 1.0
    errors = 1.0 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = _lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), grad)
    return loss


def _flatten_binary_scores(scores, labels, ignore=None):
    """Flatten predictions in the batch (binary case).
    Removes labels equal to 'ignore'.
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = labels != ignore
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


def _lovasz_softmax(
    probas, labels, classes="present", class_seen=None, per_image=False, ignore=None
):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probas: [B, C, H, W] class probabilities at each prediction (between 0 and 1).
            A 3D input is interpreted as binary (sigmoid) output of size [B, H, W].
        labels: [B, H, W] tensor, ground-truth labels (between 0 and C - 1)
        classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average
        class_seen: restrict the loss to this subset of classes if given
        per_image: compute the loss per image instead of per batch
        ignore: void class labels
    """
    if per_image:
        loss = mean(
            _lovasz_softmax_flat(
                *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore),
                classes=classes
            )
            for prob, lab in zip(probas, labels)
        )
    else:
        loss = _lovasz_softmax_flat(
            *_flatten_probas(probas, labels, ignore),
            classes=classes,
            class_seen=class_seen
        )
    return loss


def _lovasz_softmax_flat(probas, labels, classes="present", class_seen=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probas: [P, C] class probabilities at each prediction (between 0 and 1)
        labels: [P] tensor, ground-truth labels (between 0 and C - 1)
        classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average
        class_seen: restrict the loss to this subset of classes if given
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.0
    C = probas.size(1)
    losses = []
    # iterate only over the classes that actually appear in the labels
    for c in labels.unique():
        # skip classes outside the seen subset, if one was given
        if class_seen is not None and c not in class_seen:
            continue
        fg = (labels == c).type_as(probas)  # foreground for class c
        if classes == "present" and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError("Sigmoid output possible only with 1 class")
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted)))
    return mean(losses)


def _flatten_probas(probas, labels, ignore=None):
    """Flatten predictions in the batch."""
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)

    C = probas.size(1)
    probas = torch.movedim(probas, 1, -1)  # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
    probas = probas.contiguous().view(-1, C)  # [P, C]

    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels


def isnan(x):
    return x != x


def mean(values, ignore_nan=False, empty=0):
    """NaN-mean compatible with generators."""
    values = iter(values)
    if ignore_nan:
        values = filterfalse(isnan, values)
    try:
        n = 1
        acc = next(values)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean")
        return empty
    for n, v in enumerate(values, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


@LOSSES.register_module()
class LovaszLoss(_Loss):
    def __init__(
        self,
        mode: str,
        class_seen: Optional[int] = None,
        per_image: bool = False,
        ignore_index: Optional[int] = None,
        loss_weight: float = 1.0,
    ):
        """Lovasz loss for segmentation tasks.

        Supports the binary, multiclass and multilabel cases.

        Args:
            mode: loss mode, 'binary', 'multiclass' or 'multilabel'
            ignore_index: label that indicates ignored pixels (does not contribute to the loss)
            per_image: if True, the loss is computed per image and then averaged; otherwise it is computed over the whole batch

        Shape:
            - y_pred: torch.Tensor of shape (N, C, H, W)
            - y_true: torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference:
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.per_image = per_image
        self.class_seen = class_seen
        self.loss_weight = loss_weight

    def forward(self, y_pred, y_true):
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            loss = _lovasz_hinge(
                y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
            )
        elif self.mode == MULTICLASS_MODE:
            y_pred = y_pred.softmax(dim=1)
            loss = _lovasz_softmax(
                y_pred,
                y_true,
                class_seen=self.class_seen,
                per_image=self.per_image,
                ignore=self.ignore_index,
            )
        else:
            raise ValueError("Wrong mode {}.".format(self.mode))
        return loss * self.loss_weight
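A minimal usage sketch (not part of the commit) for the multiclass path; the shapes follow the docstring above, and 255 as the ignore label is only an illustrative choice:

    import torch
    from models.losses import LovaszLoss

    loss_fn = LovaszLoss(mode="multiclass", per_image=False, ignore_index=255)
    y_pred = torch.randn(2, 4, 8, 8, requires_grad=True)  # (N, C, H, W) logits
    y_true = torch.randint(0, 4, (2, 8, 8))               # (N, H, W) labels
    loss = loss_fn(y_pred, y_true)                        # softmax, flatten, Lovasz extension
    loss.backward()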
models/losses/misc.py (new file, 241 lines)
@@ -0,0 +1,241 @@
"""
The code is based on https://github.com/Pointcept/Pointcept
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super(CrossEntropyLoss, self).__init__()
        # class weights are moved to the GPU up front (assumes CUDA training)
        weight = torch.tensor(weight).cuda() if weight is not None else None
        self.loss_weight = loss_weight
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, target):
        return self.loss(pred, target) * self.loss_weight


@LOSSES.register_module()
class L1Loss(nn.Module):
    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        # weight, size_average, reduce, label_smoothing and ignore_index are
        # accepted for config compatibility but are not used by nn.L1Loss
        super(L1Loss, self).__init__()
        self.loss_weight = loss_weight
        self.loss = nn.L1Loss(reduction=reduction)

    def forward(self, pred, target):
        # add a trailing singleton dimension so target matches pred's shape
        return self.loss(pred, target[:, None]) * self.loss_weight


@LOSSES.register_module()
class SmoothCELoss(nn.Module):
    def __init__(self, smoothing_ratio=0.1):
        super(SmoothCELoss, self).__init__()
        self.smoothing_ratio = smoothing_ratio

    def forward(self, pred, target):
        eps = self.smoothing_ratio
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss[torch.isfinite(loss)].mean()
        return loss


@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
        """Binary Focal Loss <https://arxiv.org/abs/1708.02002>."""
        super(BinaryFocalLoss, self).__init__()
        assert 0 < alpha < 1
        self.gamma = gamma
        self.alpha = alpha
        self.logits = logits
        self.reduce = reduce
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): the prediction with shape (N)
            target (torch.Tensor): the ground truth. If containing class
                indices, shape (N) where each value is 0 <= targets[i] <= 1; if
                containing class probabilities, same shape as the input.
        Returns:
            torch.Tensor: the calculated loss
        """
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
        else:
            bce = F.binary_cross_entropy(pred, target, reduction="none")
        pt = torch.exp(-bce)
        alpha = self.alpha * target + (1 - self.alpha) * (1 - target)
        focal_loss = alpha * (1 - pt) ** self.gamma * bce

        if self.reduce:
            focal_loss = torch.mean(focal_loss)
        return focal_loss * self.loss_weight


@LOSSES.register_module()
class FocalLoss(nn.Module):
    def __init__(
        self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
    ):
        """Focal Loss <https://arxiv.org/abs/1708.02002>."""
        super(FocalLoss, self).__init__()
        assert reduction in (
            "mean",
            "sum",
        ), "AssertionError: reduction should be 'mean' or 'sum'"
        assert isinstance(
            alpha, (float, list)
        ), "AssertionError: alpha should be of type float or list"
        assert isinstance(gamma, float), "AssertionError: gamma should be of type float"
        assert isinstance(
            loss_weight, float
        ), "AssertionError: loss_weight should be of type float"
        assert isinstance(ignore_index, int), "ignore_index must be of type int"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): the prediction with shape (N, C), where C is the number of classes.
            target (torch.Tensor): the ground truth. If containing class
                indices, shape (N) where each value is 0 <= targets[i] <= C - 1; if
                containing class probabilities, same shape as the input.
        Returns:
            torch.Tensor: the calculated loss
        """
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) -> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        if len(target) == 0:
            return 0.0

        num_classes = pred.size(1)
        target = F.one_hot(target, num_classes=num_classes)

        alpha = self.alpha
        if isinstance(alpha, list):
            alpha = pred.new_tensor(alpha)
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
            self.gamma
        )

        loss = (
            F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            * focal_weight
        )
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return self.loss_weight * loss


@LOSSES.register_module()
class DiceLoss(nn.Module):
    def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
        """DiceLoss.

        This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
        Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, d_2, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, d_2, ..., d_k) -> (B * d_1 * d_2 * ... * d_k,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes
        )

        total_loss = 0
        for i in range(num_classes):
            if i != self.ignore_index:
                num = torch.sum(torch.mul(pred[:, i], target[:, i])) * 2 + self.smooth
                den = (
                    torch.sum(
                        pred[:, i].pow(self.exponent) + target[:, i].pow(self.exponent)
                    )
                    + self.smooth
                )
                dice_loss = 1 - num / den
                total_loss += dice_loss
        loss = total_loss / num_classes
        return self.loss_weight * loss
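A CPU-only smoke test (not part of the commit) for two of the registered losses; CrossEntropyLoss only calls .cuda() when a class-weight vector is given, so no GPU is needed here:

    import torch
    from models.losses import CrossEntropyLoss, FocalLoss

    ce = CrossEntropyLoss(loss_weight=1.0, ignore_index=-1)
    focal = FocalLoss(gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0)

    pred = torch.randn(16, 10)             # [N, C] logits
    target = torch.randint(0, 10, (16,))   # [N] class indices
    print(ce(pred, target).item(), focal(pred, target).item())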