dice_loss.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from mmdet.registry import MODELS
  5. from .utils import weight_reduce_loss
  6. def dice_loss(pred,
  7. target,
  8. weight=None,
  9. eps=1e-3,
  10. reduction='mean',
  11. naive_dice=False,
  12. avg_factor=None):
  13. """Calculate dice loss, there are two forms of dice loss is supported:
  14. - the one proposed in `V-Net: Fully Convolutional Neural
  15. Networks for Volumetric Medical Image Segmentation
  16. <https://arxiv.org/abs/1606.04797>`_.
  17. - the dice loss in which the power of the number in the
  18. denominator is the first power instead of the second
  19. power.
  20. Args:
  21. pred (torch.Tensor): The prediction, has a shape (n, *)
  22. target (torch.Tensor): The learning label of the prediction,
  23. shape (n, *), same shape of pred.
  24. weight (torch.Tensor, optional): The weight of loss for each
  25. prediction, has a shape (n,). Defaults to None.
  26. eps (float): Avoid dividing by zero. Default: 1e-3.
  27. reduction (str, optional): The method used to reduce the loss into
  28. a scalar. Defaults to 'mean'.
  29. Options are "none", "mean" and "sum".
  30. naive_dice (bool, optional): If false, use the dice
  31. loss defined in the V-Net paper, otherwise, use the
  32. naive dice loss in which the power of the number in the
  33. denominator is the first power instead of the second
  34. power.Defaults to False.
  35. avg_factor (int, optional): Average factor that is used to average
  36. the loss. Defaults to None.
  37. """
  38. input = pred.flatten(1)
  39. target = target.flatten(1).float()
  40. a = torch.sum(input * target, 1)
  41. if naive_dice:
  42. b = torch.sum(input, 1)
  43. c = torch.sum(target, 1)
  44. d = (2 * a + eps) / (b + c + eps)
  45. else:
  46. b = torch.sum(input * input, 1) + eps
  47. c = torch.sum(target * target, 1) + eps
  48. d = (2 * a) / (b + c)
  49. loss = 1 - d
  50. if weight is not None:
  51. assert weight.ndim == loss.ndim
  52. assert len(weight) == len(pred)
  53. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  54. return loss
  55. @MODELS.register_module()
  56. class DiceLoss(nn.Module):
  57. def __init__(self,
  58. use_sigmoid=True,
  59. activate=True,
  60. reduction='mean',
  61. naive_dice=False,
  62. loss_weight=1.0,
  63. eps=1e-3):
  64. """Compute dice loss.
  65. Args:
  66. use_sigmoid (bool, optional): Whether to the prediction is
  67. used for sigmoid or softmax. Defaults to True.
  68. activate (bool): Whether to activate the predictions inside,
  69. this will disable the inside sigmoid operation.
  70. Defaults to True.
  71. reduction (str, optional): The method used
  72. to reduce the loss. Options are "none",
  73. "mean" and "sum". Defaults to 'mean'.
  74. naive_dice (bool, optional): If false, use the dice
  75. loss defined in the V-Net paper, otherwise, use the
  76. naive dice loss in which the power of the number in the
  77. denominator is the first power instead of the second
  78. power. Defaults to False.
  79. loss_weight (float, optional): Weight of loss. Defaults to 1.0.
  80. eps (float): Avoid dividing by zero. Defaults to 1e-3.
  81. """
  82. super(DiceLoss, self).__init__()
  83. self.use_sigmoid = use_sigmoid
  84. self.reduction = reduction
  85. self.naive_dice = naive_dice
  86. self.loss_weight = loss_weight
  87. self.eps = eps
  88. self.activate = activate
  89. def forward(self,
  90. pred,
  91. target,
  92. weight=None,
  93. reduction_override=None,
  94. avg_factor=None):
  95. """Forward function.
  96. Args:
  97. pred (torch.Tensor): The prediction, has a shape (n, *).
  98. target (torch.Tensor): The label of the prediction,
  99. shape (n, *), same shape of pred.
  100. weight (torch.Tensor, optional): The weight of loss for each
  101. prediction, has a shape (n,). Defaults to None.
  102. avg_factor (int, optional): Average factor that is used to average
  103. the loss. Defaults to None.
  104. reduction_override (str, optional): The reduction method used to
  105. override the original reduction method of the loss.
  106. Options are "none", "mean" and "sum".
  107. Returns:
  108. torch.Tensor: The calculated loss
  109. """
  110. assert reduction_override in (None, 'none', 'mean', 'sum')
  111. reduction = (
  112. reduction_override if reduction_override else self.reduction)
  113. if self.activate:
  114. if self.use_sigmoid:
  115. pred = pred.sigmoid()
  116. else:
  117. raise NotImplementedError
  118. loss = self.loss_weight * dice_loss(
  119. pred,
  120. target,
  121. weight,
  122. eps=self.eps,
  123. reduction=reduction,
  124. naive_dice=self.naive_dice,
  125. avg_factor=avg_factor)
  126. return loss