# ghm_loss.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.registry import MODELS
from .utils import weight_reduce_loss
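
# Background on the method (see the paper linked in the docstrings below):
# GHM re-weights each sample by the density of its "gradient length" g. For a
# sigmoid output on logit z with binary target t, the BCE loss is
#     L = -t * log(sigmoid(z)) - (1 - t) * log(1 - sigmoid(z)),
# whose gradient is dL/dz = sigmoid(z) - t, so g = |sigmoid(z) - t| in GHMC.
# The unit interval of g is split into `bins` regions; samples landing in a
# crowded bin (very easy examples, or outliers with g near 1) receive
# proportionally smaller weights.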


def _expand_onehot_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(
        (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights
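
# As a worked example (hypothetical values): with label_channels=3,
# labels = [0, 2, 3] and unit label_weights, this returns
#   bin_labels        = [[1, 0, 0], [0, 0, 1], [0, 0, 0]]
#   bin_label_weights = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
# The label 3 falls outside [0, label_channels) (e.g. the background index),
# so its row stays all-zero while its weight row is kept.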


# TODO: code refactoring to make it consistent with other losses
@MODELS.register_module()
class GHMC(nn.Module):
    """GHM Classification Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean".
    """

    def __init__(self,
                 bins=10,
                 momentum=0,
                 use_sigmoid=True,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] += 1e-6
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                label_weight,
                reduction_override=None,
                **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                The value is 1 if the sample is valid and 0 if ignored.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # the target should be a binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_onehot_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length: |sigmoid(pred) - target| is the magnitude of the
        # BCE gradient w.r.t. the logits
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n: the number of valid (non-empty) bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # exponential moving average of the bin count
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none')
        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight
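
# A minimal usage sketch for GHMC (shapes and values are illustrative):
#
#   >>> loss_fn = GHMC(bins=10, momentum=0.75)
#   >>> pred = torch.randn(8, 4)            # raw logits, [batch_num, class_num]
#   >>> labels = torch.randint(0, 4, (8,))  # integer class labels
#   >>> label_weight = torch.ones(8)        # every sample is valid
#   >>> loss = loss_fn(pred, labels, label_weight)
#
# Since `labels` has fewer dims than `pred`, forward() expands it to one-hot
# targets via _expand_onehot_labels before computing the weighted BCE.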


# TODO: code refactoring to make it consistent with other losses
@MODELS.register_module()
class GHMR(nn.Module):
    """GHM Regression Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Args:
        mu (float): The parameter for the Authentic Smooth L1 loss.
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        loss_weight (float): The weight of the total GHM-R loss.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean".
    """

    def __init__(self,
                 mu=0.02,
                 bins=10,
                 momentum=0,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMR, self).__init__()
        self.mu = mu
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.loss_weight = loss_weight
        self.reduction = reduction

    # TODO: support reduction parameter
    def forward(self,
                pred,
                target,
                label_weight,
                avg_factor=None,
                reduction_override=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can be
                4 or 4 * class_num depending on whether it is class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size as pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
            avg_factor (int, optional): Unused here; the sum of the label
                weights is used as the average factor instead.
                Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 (Authentic Smooth L1) loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length: the ASL1 gradient d(loss)/d(diff) lies in [0, 1)
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: the number of valid (non-empty) bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    # exponential moving average of the bin count
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        if n > 0:
            weights /= n

        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight
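
# A minimal usage sketch for GHMR (shapes and values are illustrative):
#
#   >>> loss_fn = GHMR(mu=0.02, bins=10)
#   >>> pred = torch.randn(8, 4)         # predicted box regression values
#   >>> target = torch.randn(8, 4)       # regression targets, same shape
#   >>> label_weight = torch.ones(8, 4)  # every element is valid
#   >>> loss = loss_fn(pred, target, label_weight)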