Files
2025-04-17 23:14:24 +08:00

242 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
The code is based on https://github.com/Pointcept/Pointcept
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .builder import LOSSES
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss scaled by a constant ``loss_weight``.

    Thin wrapper around :class:`torch.nn.CrossEntropyLoss`. Every argument
    except ``loss_weight`` is forwarded to it unchanged; ``weight`` is first
    converted to a float tensor.

    Args:
        weight: Optional per-class rescaling weights (sequence or tensor).
        size_average: Deprecated PyTorch argument, forwarded as-is.
        reduce: Deprecated PyTorch argument, forwarded as-is.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``.
        label_smoothing: Amount of label smoothing, forwarded as-is.
        loss_weight: Multiplier applied to the computed loss.
        ignore_index: Target value that is excluded from the loss.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super(CrossEntropyLoss, self).__init__()
        # Do not force ``.cuda()`` here: that raised on CPU-only hosts.
        # ``forward`` moves the weights to the prediction's device on demand.
        # A float dtype is used because integer class weights are not
        # accepted alongside floating-point logits.
        if weight is not None:
            weight = torch.as_tensor(weight, dtype=torch.float)
        self.loss_weight = loss_weight
        self.loss = nn.CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, target):
        """Return ``loss_weight * cross_entropy(pred, target)``.

        Args:
            pred: Unnormalized logits, as accepted by ``nn.CrossEntropyLoss``.
            target: Targets for ``pred``, as accepted by ``nn.CrossEntropyLoss``.
        Returns:
            torch.Tensor: The (reduced) cross-entropy times ``loss_weight``.
        """
        class_weight = self.loss.weight
        if class_weight is not None and class_weight.device != pred.device:
            # ``weight`` is a registered buffer of ``self.loss``, so assigning
            # to it replaces the buffer. The copy therefore happens only when
            # the device changes, not on every call.
            self.loss.weight = class_weight.to(pred.device)
        return self.loss(pred, target) * self.loss_weight
@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 (absolute-error) loss scaled by ``loss_weight``.

    The constructor has the same signature as ``CrossEntropyLoss`` in this
    module. ``weight``, ``label_smoothing`` and ``ignore_index`` are accepted
    but have no effect on this loss.

    Args:
        weight: Unused.
        size_average: Deprecated PyTorch argument, forwarded to ``nn.L1Loss``.
        reduce: Deprecated PyTorch argument, forwarded to ``nn.L1Loss``.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``; forwarded to
            ``nn.L1Loss``.
        label_smoothing: Unused.
        loss_weight: Multiplier applied to the computed loss.
        ignore_index: Unused.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
        loss_weight=1.0,
        ignore_index=-1,
    ):
        super(L1Loss, self).__init__()
        # ``weight`` used to be turned into a CUDA tensor and then discarded.
        # That crashed on CPU-only hosts for no benefit, so it is now ignored.
        self.loss_weight = loss_weight
        # Forward the reduction settings instead of hard-coding "mean", so the
        # constructor arguments take effect. The default behaviour is unchanged.
        self.loss = nn.L1Loss(
            size_average=size_average, reduce=reduce, reduction=reduction
        )

    def forward(self, pred, target):
        """Return ``loss_weight * l1(pred, target[:, None])``.

        ``target`` gets a new axis at position 1 before the comparison.
        Presumably this lines up a target of shape (N,) with a prediction of
        shape (N, 1); TODO confirm against callers.
        """
        return self.loss(pred, target[:, None]) * self.loss_weight
@LOSSES.register_module()
class SmoothCELoss(nn.Module):
    """Cross-entropy with uniform label smoothing.

    The true class receives target probability ``1 - smoothing_ratio``. The
    remaining ``smoothing_ratio`` is spread evenly over the other ``C - 1``
    classes.

    Args:
        smoothing_ratio: Probability mass moved away from the true class.
    """

    def __init__(self, smoothing_ratio=0.1):
        super(SmoothCELoss, self).__init__()
        self.smoothing_ratio = smoothing_ratio

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy over finite samples.

        Args:
            pred: Logits of shape (N, C). ``scatter`` with the (N, 1) index
                below requires a 2-D tensor.
            target: Integer class indices; flattened to (N, 1).
        Returns:
            torch.Tensor: Scalar mean of the per-sample losses that are finite.
        """
        eps = self.smoothing_ratio
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        # Per-sample loss is the sum over the class axis. Tensor has no
        # ``total`` method, so the previous ``.total(dim=1)`` always raised.
        loss = -(one_hot * log_prb).sum(dim=1)
        # Drop non-finite per-sample losses before averaging.
        loss = loss[torch.isfinite(loss)].mean()
        return loss
@LOSSES.register_module()
class BinaryFocalLoss(nn.Module):
    """Focal loss for element-wise binary targets.

    See `Focal Loss for Dense Object Detection
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Weight of the positive label (``1 - alpha`` for the negative);
            must lie strictly between 0 and 1.
        logits: If True ``pred`` holds raw logits, otherwise probabilities.
        reduce: If True the element-wise loss is averaged to a scalar.
        loss_weight: Multiplier applied to the result.
    """

    def __init__(self, gamma=2.0, alpha=0.5, logits=True, reduce=True, loss_weight=1.0):
        super().__init__()
        assert 0 < alpha < 1
        self.gamma, self.alpha = gamma, alpha
        self.logits, self.reduce = logits, reduce
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Compute the binary focal loss.

        Args:
            pred (torch.Tensor): Predictions, e.g. of shape (N). They are
                logits or probabilities depending on ``self.logits``.
            target (torch.Tensor): Same shape as ``pred``. Values are hard
                labels in {0, 1} or probabilities in [0, 1].
        Returns:
            torch.Tensor: The mean loss if ``reduce`` is set, otherwise the
            element-wise loss, multiplied by ``loss_weight``.
        """
        bce_fn = (
            F.binary_cross_entropy_with_logits
            if self.logits
            else F.binary_cross_entropy
        )
        bce = bce_fn(pred, target, reduction="none")
        # For hard 0/1 targets, exp(-BCE) is the probability the model assigns
        # to the labelled outcome (p_t).
        p_t = torch.exp(-bce)
        # alpha on positives, 1 - alpha on negatives (interpolated for soft targets).
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
        focal = alpha_t * (1 - p_t) ** self.gamma * bce
        if self.reduce:
            focal = torch.mean(focal)
        return focal * self.loss_weight
@LOSSES.register_module()
class FocalLoss(nn.Module):
    """Sigmoid focal loss over one-hot encoded class targets.

    Each class logit is treated as an independent binary problem against a
    one-hot target, following `Focal Loss for Dense Object Detection
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Weight of the positive (one-hot) entries; ``1 - alpha`` weights
            the negatives. A list is turned into a tensor and broadcast over
            the class axis, so it should hold one value per class.
        reduction: ``"mean"`` or ``"sum"``.
        loss_weight: Multiplier applied to the reduced loss.
        ignore_index: Target value whose elements are excluded.
    """

    def __init__(
        self, gamma=2.0, alpha=0.5, reduction="mean", loss_weight=1.0, ignore_index=-1
    ):
        super(FocalLoss, self).__init__()
        assert reduction in (
            "mean",
            "sum",
        ), "AssertionError: reduction should be 'mean' or 'sum'"
        # Ints are accepted alongside floats (e.g. ``gamma=2`` in a config);
        # they behave identically in the arithmetic in ``forward``.
        assert isinstance(
            alpha, (int, float, list)
        ), "AssertionError: alpha should be a number or a list of numbers"
        assert isinstance(gamma, (int, float)), "AssertionError: gamma should be a number"
        assert isinstance(
            loss_weight, (int, float)
        ), "AssertionError: loss_weight should be a number"
        assert isinstance(ignore_index, int), "ignore_index must be of type int"
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Compute the focal loss.

        Args:
            pred (torch.Tensor): Logits of shape (B, C, d_1, ..., d_k) or
                (N, C), where C is the number of classes.
            target (torch.Tensor): Class indices of shape (B, d_1, ..., d_k)
                or (N,). Each value lies in [0, C - 1] or equals
                ``ignore_index``.
        Returns:
            torch.Tensor: The reduced loss times ``loss_weight``, or a zero
            tensor when no element is valid.
        """
        # [B, C, d_1, ..., d_k] -> [C, B, d_1, ..., d_k]
        pred = pred.transpose(0, 1)
        # [C, B, d_1, ..., d_k] -> [C, N]
        pred = pred.reshape(pred.size(0), -1)
        # [C, N] -> [N, C]
        pred = pred.transpose(0, 1).contiguous()
        # (B, d_1, ..., d_k) -> (N,); reshape also accepts non-contiguous targets.
        target = target.reshape(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"
        valid_mask = target != self.ignore_index
        target = target[valid_mask]
        pred = pred[valid_mask]
        if len(target) == 0:
            # Nothing to supervise. Return a zero *tensor* attached to pred's
            # graph; a plain Python 0.0 breaks ``.backward()`` and ``.item()``.
            return pred.sum() * 0.0
        num_classes = pred.size(1)
        # F.one_hot only accepts int64 indices.
        target = F.one_hot(target.long(), num_classes=num_classes)
        alpha = self.alpha
        if isinstance(alpha, list):
            alpha = pred.new_tensor(alpha)
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # 1 - p_t: probability the model puts on the wrong side of each entry.
        one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * one_minus_pt.pow(
            self.gamma
        )
        loss = (
            F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            * focal_weight
        )
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            # Tensor has no ``total`` method; ``sum`` is the intended reduction.
            loss = loss.sum()
        return self.loss_weight * loss
@LOSSES.register_module()
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over classes.

    Proposed in `V-Net: Fully Convolutional Neural Networks for Volumetric
    Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Args:
        smooth: Constant added to the numerator and denominator of every
            per-class Dice ratio.
        exponent: Power applied to probabilities and one-hot targets in the
            denominator.
        loss_weight: Multiplier applied to the final loss.
        ignore_index: Elements whose target equals this value are dropped. If
            it is also a valid class index, that class is skipped in the
            per-class sum, but the sum is still divided by the full class count.
    """

    def __init__(self, smooth=1, exponent=2, loss_weight=1.0, ignore_index=-1):
        super().__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self, pred, target, **kwargs):
        """Compute the class-averaged Dice loss.

        Args:
            pred (torch.Tensor): Logits of shape (B, C, d_1, ..., d_k).
            target (torch.Tensor): Class indices of shape (B, d_1, ..., d_k).
        Returns:
            The mean per-class ``1 - dice`` multiplied by ``loss_weight``.
        """
        # Bring the class axis last and flatten the rest:
        # [B, C, d_1, ..., d_k] -> [C, B, ...] -> [C, N] -> [N, C]
        channels_first = pred.transpose(0, 1)
        flat = channels_first.reshape(channels_first.size(0), -1)
        pred = flat.transpose(0, 1).contiguous()
        # (B, d_1, ..., d_k) -> (N,)
        target = target.view(-1).contiguous()
        assert pred.size(0) == target.size(
            0
        ), "The shape of pred doesn't match the shape of target"

        keep = target != self.ignore_index
        probs = F.softmax(pred[keep], dim=1)
        target = target[keep]
        n_cls = probs.shape[1]
        # Clamp guards F.one_hot against labels outside [0, n_cls - 1].
        onehot = F.one_hot(
            torch.clamp(target.long(), 0, n_cls - 1), num_classes=n_cls
        )

        total_loss = 0
        for cls in range(n_cls):
            if cls == self.ignore_index:
                continue
            p = probs[:, cls]
            t = onehot[:, cls]
            numerator = torch.sum(torch.mul(p, t)) * 2 + self.smooth
            denominator = (
                torch.sum(p.pow(self.exponent) + t.pow(self.exponent)) + self.smooth
            )
            total_loss += 1 - numerator / denominator
        return self.loss_weight * (total_loss / n_cls)