# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.losses.utils import weighted_loss
from mmdet.registry import MODELS


@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
    r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number
            of classes.
        target (tuple([torch.Tensor])): Target category label with shape (N,)
            and target quality label with shape (N,).
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
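
    Example:
        A minimal, illustrative call; the shapes and values below are
        assumptions for demonstration, not taken from any config. With 4
        samples and 3 foreground classes, background is encoded as
        category id 3 (i.e. ``num_classes``):

        >>> pred = torch.randn(4, 3)  # classification-quality logits
        >>> label = torch.tensor([0, 3, 3, 2])  # 3 marks background here
        >>> score = torch.tensor([0.8, 0.0, 0.0, 0.5])  # IoU quality targets
        >>> loss = quality_focal_loss(pred, (label, score))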
  23. """
    assert len(target) == 2, """target for QFL must be a tuple of two elements,
        including category label and quality label, respectively"""
    # label denotes the category id, score denotes the quality score
    label, score = target

    # negatives are supervised by 0 quality score
    pred_sigmoid = pred.sigmoid()
    scale_factor = pred_sigmoid
    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = F.binary_cross_entropy_with_logits(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = pred.size(1)
    pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
    pos_label = label[pos].long()
    # positives are supervised by bbox quality (IoU) score
    scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
    loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
        pred[pos, pos_label], score[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss


@weighted_loss
def quality_focal_loss_tensor_target(pred, target, beta=2.0, activated=False):
    """`Quality Focal Loss <https://arxiv.org/abs/2008.13367>`_

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of
            classes.
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.
        activated (bool): Whether the input is activated.
            If True, it means the input has been activated and can be
            treated as probabilities. Else, it should be treated as logits.
            Defaults to False.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
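
    Example:
        A minimal sketch with assumed shapes and values: the target is an
        (N, C) one-hot style tensor carrying soft IoU weights:

        >>> pred = torch.randn(2, 3)  # logits (activated=False)
        >>> target = torch.tensor([[0.0, 0.9, 0.0], [0.0, 0.0, 0.6]])
        >>> loss = quality_focal_loss_tensor_target(pred, target)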
  59. """
    # pred and target should be of the same size
    assert pred.size() == target.size()
    if activated:
        pred_sigmoid = pred
        loss_function = F.binary_cross_entropy
    else:
        pred_sigmoid = pred.sigmoid()
        loss_function = F.binary_cross_entropy_with_logits

    scale_factor = pred_sigmoid
    target = target.type_as(pred)

    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = loss_function(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    pos = (target != 0)
    scale_factor = target[pos] - pred_sigmoid[pos]
    loss[pos] = loss_function(
        pred[pos], target[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss


@weighted_loss
def quality_focal_loss_with_prob(pred, target, beta=2.0):
    r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Different from `quality_focal_loss`, this function accepts probability
    as input.

    Args:
        pred (torch.Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number
            of classes.
        target (tuple([torch.Tensor])): Target category label with shape (N,)
            and target quality label with shape (N,).
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
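
    Example:
        A minimal sketch with assumed values; unlike `quality_focal_loss`,
        ``pred`` must already be probabilities (e.g. after sigmoid):

        >>> pred = torch.rand(4, 3)  # already-activated scores in [0, 1]
        >>> label = torch.tensor([0, 3, 3, 2])  # 3 marks background here
        >>> score = torch.tensor([0.8, 0.0, 0.0, 0.5])
        >>> loss = quality_focal_loss_with_prob(pred, (label, score))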
  97. """
    assert len(target) == 2, """target for QFL must be a tuple of two elements,
        including category label and quality label, respectively"""
    # label denotes the category id, score denotes the quality score
    label, score = target

    # negatives are supervised by 0 quality score
    pred_sigmoid = pred
    scale_factor = pred_sigmoid
    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = F.binary_cross_entropy(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = pred.size(1)
    pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
    pos_label = label[pos].long()
    # positives are supervised by bbox quality (IoU) score
    scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
    loss[pos, pos_label] = F.binary_cross_entropy(
        pred[pos, pos_label], score[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss


@weighted_loss
def distribution_focal_loss(pred, label):
    r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (torch.Tensor): Target distance label for bounding boxes with
            shape (N,).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
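
    Example:
        A minimal sketch with assumed shapes: 2 samples over the integral
        set {0, ..., 7}, so ``pred`` holds 8 logits per sample and each
        continuous label is split across its two nearest integer bins:

        >>> pred = torch.randn(2, 8)
        >>> label = torch.tensor([2.3, 5.7])
        >>> loss = distribution_focal_loss(pred, label)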
  132. """
    dis_left = label.long()
    dis_right = dis_left + 1
    weight_left = dis_right.float() - label
    weight_right = label - dis_left.float()
    loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
        + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
    return loss


@MODELS.register_module()
class QualityFocalLoss(nn.Module):
    r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        use_sigmoid (bool): Whether sigmoid operation is conducted in QFL.
            Defaults to True.
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str): Options are "none", "mean" and "sum". Defaults
            to "mean".
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
        activated (bool, optional): Whether the input is activated.
            If True, it means the input has been activated and can be
            treated as probabilities. Else, it should be treated as logits.
            Defaults to False.
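
    Example:
        A minimal usage sketch; the shapes and values are illustrative
        assumptions only:

        >>> loss_cls = QualityFocalLoss(beta=2.0, loss_weight=1.0)
        >>> pred = torch.randn(4, 3)
        >>> label = torch.tensor([0, 3, 3, 2])  # 3 marks background here
        >>> score = torch.tensor([0.8, 0.0, 0.0, 0.5])
        >>> loss = loss_cls(pred, (label, score))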
  156. """

    def __init__(self,
                 use_sigmoid=True,
                 beta=2.0,
                 reduction='mean',
                 loss_weight=1.0,
                 activated=False):
        super(QualityFocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
        self.use_sigmoid = use_sigmoid
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.activated = activated

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted joint representation of
                classification and quality (IoU) estimation with shape (N, C),
                C is the number of classes.
            target (Union[tuple([torch.Tensor]), torch.Tensor]): If the
                target is a tuple, it should contain the target category
                label with shape (N,) and the target quality label with
                shape (N,). If the target is a torch.Tensor, it should be
                in one-hot form with soft weights.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if self.activated:
                calculate_loss_func = quality_focal_loss_with_prob
            else:
                calculate_loss_func = quality_focal_loss
            if isinstance(target, torch.Tensor):
                # A tensor target with shape (N, C) or (N, C, ...) means
                # the target is in one-hot form with soft weights.
                calculate_loss_func = partial(
                    quality_focal_loss_tensor_target, activated=self.activated)

            loss_cls = self.loss_weight * calculate_loss_func(
                pred,
                target,
                weight,
                beta=self.beta,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls


@MODELS.register_module()
class DistributionFocalLoss(nn.Module):
    r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
            Defaults to 'mean'.
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
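
    Example:
        A minimal usage sketch; shapes and the loss weight are illustrative
        assumptions only:

        >>> loss_dfl = DistributionFocalLoss(loss_weight=0.25)
        >>> pred = torch.randn(2, 8)  # logits over the integral set {0..7}
        >>> target = torch.tensor([2.3, 5.7])
        >>> loss = loss_dfl(pred, target)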
  225. """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(DistributionFocalLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted general distribution of bounding
                boxes (before softmax) with shape (N, n+1), n is the max value
                of the integral set `{0, ..., n}` in paper.
            target (torch.Tensor): Target distance label for bounding boxes
                with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_cls = self.loss_weight * distribution_focal_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_cls