123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- from mmdet.registry import MODELS
- from .utils import weight_reduce_loss
def dice_loss(pred,
              target,
              weight=None,
              eps=1e-3,
              reduction='mean',
              naive_dice=False,
              avg_factor=None):
    """Calculate the dice loss.

    Two forms of dice loss are supported:

    - the one proposed in `V-Net: Fully Convolutional Neural
      Networks for Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_, whose denominator sums the
      squared prediction and the squared target;
    - the "naive" dice loss, whose denominator uses the first power
      instead of the second power.

    Both inputs are flattened from dim 1 onward, so one dice score is
    computed per sample (per row of ``pred``) before weighting/reduction.

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape as pred. It is cast to float.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Defaults to 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        naive_dice (bool, optional): If False, use the dice loss defined
            in the V-Net paper; otherwise use the naive dice loss in which
            the power of the number in the denominator is the first power
            instead of the second power. Defaults to False.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The loss returned by ``weight_reduce_loss`` for the
        per-sample values ``1 - dice``.
    """
    # Flatten everything but the batch dim so each sample yields one score.
    # ``pred_flat`` (not ``input``) avoids shadowing the builtin.
    pred_flat = pred.flatten(1)
    target_flat = target.flatten(1).float()

    intersection = torch.sum(pred_flat * target_flat, 1)
    if naive_dice:
        pred_sum = torch.sum(pred_flat, 1)
        target_sum = torch.sum(target_flat, 1)
        # eps in both numerator and denominator: an all-zero pred/target
        # pair gets dice == 1 (loss 0).
        dice = (2 * intersection + eps) / (pred_sum + target_sum + eps)
    else:
        # V-Net form: eps only guards the denominator, so an all-zero
        # pred/target pair gets dice == 0 (loss 1).
        pred_sum = torch.sum(pred_flat * pred_flat, 1) + eps
        target_sum = torch.sum(target_flat * target_flat, 1) + eps
        dice = (2 * intersection) / (pred_sum + target_sum)

    loss = 1 - dice
    if weight is not None:
        # The per-sample loss is 1-D, so the weight must be one value per
        # sample.
        assert weight.ndim == loss.ndim, (
            f'weight must have {loss.ndim} dim(s), got {weight.ndim}')
        assert len(weight) == len(pred), (
            f'weight length {len(weight)} does not match the number of '
            f'predictions {len(pred)}')
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
@MODELS.register_module()
class DiceLoss(nn.Module):
    """Dice loss module wrapping :func:`dice_loss`.

    Optionally applies a sigmoid to the raw predictions before computing
    the loss, then scales the result by ``loss_weight``.
    """

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 eps=1e-3):
        """Compute dice loss.

        Args:
            use_sigmoid (bool, optional): Whether the predictions are
                activated with sigmoid (True) or softmax (False). Only the
                sigmoid path is implemented; ``use_sigmoid=False`` together
                with ``activate=True`` raises ``NotImplementedError`` in
                ``forward``. Defaults to True.
            activate (bool): Whether to activate the predictions inside
                ``forward``. Set it to False when ``pred`` is already
                activated; this skips the internal sigmoid.
                Defaults to True.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If False, use the dice loss
                defined in the V-Net paper; otherwise use the naive dice
                loss in which the power of the number in the denominator
                is the first power instead of the second power.
                Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
        """
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps
        self.activate = activate

    def forward(self,
                pred,
                target,
                weight=None,
                reduction_override=None,
                avg_factor=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape as pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum". Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.

        Raises:
            NotImplementedError: If ``activate`` is True but
                ``use_sigmoid`` is False (softmax is not supported).
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the configured reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.activate:
            if not self.use_sigmoid:
                raise NotImplementedError(
                    'DiceLoss only supports sigmoid activation; set '
                    'use_sigmoid=True, or pass activate=False with '
                    'already-activated predictions.')
            pred = pred.sigmoid()

        loss = self.loss_weight * dice_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            naive_dice=self.naive_dice,
            avg_factor=avg_factor)
        return loss
|