balanced_l1_loss.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. from mmdet.registry import MODELS
  6. from .utils import weighted_loss
  7. @weighted_loss
  8. def balanced_l1_loss(pred,
  9. target,
  10. beta=1.0,
  11. alpha=0.5,
  12. gamma=1.5,
  13. reduction='mean'):
  14. """Calculate balanced L1 loss.
  15. Please see the `Libra R-CNN <https://arxiv.org/pdf/1904.02701.pdf>`_
  16. Args:
  17. pred (torch.Tensor): The prediction with shape (N, 4).
  18. target (torch.Tensor): The learning target of the prediction with
  19. shape (N, 4).
  20. beta (float): The loss is a piecewise function of prediction and target
  21. and ``beta`` serves as a threshold for the difference between the
  22. prediction and target. Defaults to 1.0.
  23. alpha (float): The denominator ``alpha`` in the balanced L1 loss.
  24. Defaults to 0.5.
  25. gamma (float): The ``gamma`` in the balanced L1 loss.
  26. Defaults to 1.5.
  27. reduction (str, optional): The method that reduces the loss to a
  28. scalar. Options are "none", "mean" and "sum".
  29. Returns:
  30. torch.Tensor: The calculated loss
  31. """
  32. assert beta > 0
  33. if target.numel() == 0:
  34. return pred.sum() * 0
  35. assert pred.size() == target.size()
  36. diff = torch.abs(pred - target)
  37. b = np.e**(gamma / alpha) - 1
  38. loss = torch.where(
  39. diff < beta, alpha / b *
  40. (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
  41. gamma * diff + gamma / b - alpha * beta)
  42. return loss
  43. @MODELS.register_module()
  44. class BalancedL1Loss(nn.Module):
  45. """Balanced L1 Loss.
  46. arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
  47. Args:
  48. alpha (float): The denominator ``alpha`` in the balanced L1 loss.
  49. Defaults to 0.5.
  50. gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
  51. beta (float, optional): The loss is a piecewise function of prediction
  52. and target. ``beta`` serves as a threshold for the difference
  53. between the prediction and target. Defaults to 1.0.
  54. reduction (str, optional): The method that reduces the loss to a
  55. scalar. Options are "none", "mean" and "sum".
  56. loss_weight (float, optional): The weight of the loss. Defaults to 1.0
  57. """
  58. def __init__(self,
  59. alpha=0.5,
  60. gamma=1.5,
  61. beta=1.0,
  62. reduction='mean',
  63. loss_weight=1.0):
  64. super(BalancedL1Loss, self).__init__()
  65. self.alpha = alpha
  66. self.gamma = gamma
  67. self.beta = beta
  68. self.reduction = reduction
  69. self.loss_weight = loss_weight
  70. def forward(self,
  71. pred,
  72. target,
  73. weight=None,
  74. avg_factor=None,
  75. reduction_override=None,
  76. **kwargs):
  77. """Forward function of loss.
  78. Args:
  79. pred (torch.Tensor): The prediction with shape (N, 4).
  80. target (torch.Tensor): The learning target of the prediction with
  81. shape (N, 4).
  82. weight (torch.Tensor, optional): Sample-wise loss weight with
  83. shape (N, ).
  84. avg_factor (int, optional): Average factor that is used to average
  85. the loss. Defaults to None.
  86. reduction_override (str, optional): The reduction method used to
  87. override the original reduction method of the loss.
  88. Options are "none", "mean" and "sum".
  89. Returns:
  90. torch.Tensor: The calculated loss
  91. """
  92. assert reduction_override in (None, 'none', 'mean', 'sum')
  93. reduction = (
  94. reduction_override if reduction_override else self.reduction)
  95. loss_bbox = self.loss_weight * balanced_l1_loss(
  96. pred,
  97. target,
  98. weight,
  99. alpha=self.alpha,
  100. gamma=self.gamma,
  101. beta=self.beta,
  102. reduction=reduction,
  103. avg_factor=avg_factor,
  104. **kwargs)
  105. return loss_bbox