# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS

from .utils import weighted_loss
  8. @weighted_loss
  9. def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
  10. """Smooth L1 loss.
  11. Args:
  12. pred (Tensor): The prediction.
  13. target (Tensor): The learning target of the prediction.
  14. beta (float, optional): The threshold in the piecewise function.
  15. Defaults to 1.0.
  16. Returns:
  17. Tensor: Calculated loss
  18. """
  19. assert beta > 0
  20. if target.numel() == 0:
  21. return pred.sum() * 0
  22. assert pred.size() == target.size()
  23. diff = torch.abs(pred - target)
  24. loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
  25. diff - 0.5 * beta)
  26. return loss
  27. @weighted_loss
  28. def l1_loss(pred: Tensor, target: Tensor) -> Tensor:
  29. """L1 loss.
  30. Args:
  31. pred (Tensor): The prediction.
  32. target (Tensor): The learning target of the prediction.
  33. Returns:
  34. Tensor: Calculated loss
  35. """
  36. if target.numel() == 0:
  37. return pred.sum() * 0
  38. assert pred.size() == target.size()
  39. loss = torch.abs(pred - target)
  40. return loss
  41. @MODELS.register_module()
  42. class SmoothL1Loss(nn.Module):
  43. """Smooth L1 loss.
  44. Args:
  45. beta (float, optional): The threshold in the piecewise function.
  46. Defaults to 1.0.
  47. reduction (str, optional): The method to reduce the loss.
  48. Options are "none", "mean" and "sum". Defaults to "mean".
  49. loss_weight (float, optional): The weight of loss.
  50. """
  51. def __init__(self,
  52. beta: float = 1.0,
  53. reduction: str = 'mean',
  54. loss_weight: float = 1.0) -> None:
  55. super().__init__()
  56. self.beta = beta
  57. self.reduction = reduction
  58. self.loss_weight = loss_weight
  59. def forward(self,
  60. pred: Tensor,
  61. target: Tensor,
  62. weight: Optional[Tensor] = None,
  63. avg_factor: Optional[int] = None,
  64. reduction_override: Optional[str] = None,
  65. **kwargs) -> Tensor:
  66. """Forward function.
  67. Args:
  68. pred (Tensor): The prediction.
  69. target (Tensor): The learning target of the prediction.
  70. weight (Tensor, optional): The weight of loss for each
  71. prediction. Defaults to None.
  72. avg_factor (int, optional): Average factor that is used to average
  73. the loss. Defaults to None.
  74. reduction_override (str, optional): The reduction method used to
  75. override the original reduction method of the loss.
  76. Defaults to None.
  77. Returns:
  78. Tensor: Calculated loss
  79. """
  80. assert reduction_override in (None, 'none', 'mean', 'sum')
  81. reduction = (
  82. reduction_override if reduction_override else self.reduction)
  83. loss_bbox = self.loss_weight * smooth_l1_loss(
  84. pred,
  85. target,
  86. weight,
  87. beta=self.beta,
  88. reduction=reduction,
  89. avg_factor=avg_factor,
  90. **kwargs)
  91. return loss_bbox
  92. @MODELS.register_module()
  93. class L1Loss(nn.Module):
  94. """L1 loss.
  95. Args:
  96. reduction (str, optional): The method to reduce the loss.
  97. Options are "none", "mean" and "sum".
  98. loss_weight (float, optional): The weight of loss.
  99. """
  100. def __init__(self,
  101. reduction: str = 'mean',
  102. loss_weight: float = 1.0) -> None:
  103. super().__init__()
  104. self.reduction = reduction
  105. self.loss_weight = loss_weight
  106. def forward(self,
  107. pred: Tensor,
  108. target: Tensor,
  109. weight: Optional[Tensor] = None,
  110. avg_factor: Optional[int] = None,
  111. reduction_override: Optional[str] = None) -> Tensor:
  112. """Forward function.
  113. Args:
  114. pred (Tensor): The prediction.
  115. target (Tensor): The learning target of the prediction.
  116. weight (Tensor, optional): The weight of loss for each
  117. prediction. Defaults to None.
  118. avg_factor (int, optional): Average factor that is used to average
  119. the loss. Defaults to None.
  120. reduction_override (str, optional): The reduction method used to
  121. override the original reduction method of the loss.
  122. Defaults to None.
  123. Returns:
  124. Tensor: Calculated loss
  125. """
  126. assert reduction_override in (None, 'none', 'mean', 'sum')
  127. reduction = (
  128. reduction_override if reduction_override else self.reduction)
  129. loss_bbox = self.loss_weight * l1_loss(
  130. pred, target, weight, reduction=reduction, avg_factor=avg_factor)
  131. return loss_bbox