free_anchor_retina_head.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures.bbox import bbox_overlaps
  9. from mmdet.utils import InstanceList, OptConfigType, OptInstanceList
  10. from ..utils import multi_apply
  11. from .retina_head import RetinaHead
  12. EPS = 1e-12
  13. @MODELS.register_module()
  14. class FreeAnchorRetinaHead(RetinaHead):
  15. """FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466.
  16. Args:
  17. num_classes (int): Number of categories excluding the background
  18. category.
  19. in_channels (int): Number of channels in the input feature map.
  20. stacked_convs (int): Number of conv layers in cls and reg tower.
  21. Defaults to 4.
  22. conv_cfg (:obj:`ConfigDict` or dict, optional): dictionary to
  23. construct and config conv layer. Defaults to None.
  24. norm_cfg (:obj:`ConfigDict` or dict, optional): dictionary to
  25. construct and config norm layer. Defaults to
  26. norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
  27. pre_anchor_topk (int): Number of boxes that be token in each bag.
  28. Defaults to 50
  29. bbox_thr (float): The threshold of the saturated linear function.
  30. It is usually the same with the IoU threshold used in NMS.
  31. Defaults to 0.6.
  32. gamma (float): Gamma parameter in focal loss. Defaults to 2.0.
  33. alpha (float): Alpha parameter in focal loss. Defaults to 0.5.
  34. """
  35. def __init__(self,
  36. num_classes: int,
  37. in_channels: int,
  38. stacked_convs: int = 4,
  39. conv_cfg: OptConfigType = None,
  40. norm_cfg: OptConfigType = None,
  41. pre_anchor_topk: int = 50,
  42. bbox_thr: float = 0.6,
  43. gamma: float = 2.0,
  44. alpha: float = 0.5,
  45. **kwargs) -> None:
  46. super().__init__(
  47. num_classes=num_classes,
  48. in_channels=in_channels,
  49. stacked_convs=stacked_convs,
  50. conv_cfg=conv_cfg,
  51. norm_cfg=norm_cfg,
  52. **kwargs)
  53. self.pre_anchor_topk = pre_anchor_topk
  54. self.bbox_thr = bbox_thr
  55. self.gamma = gamma
  56. self.alpha = alpha
  57. def loss_by_feat(
  58. self,
  59. cls_scores: List[Tensor],
  60. bbox_preds: List[Tensor],
  61. batch_gt_instances: InstanceList,
  62. batch_img_metas: List[dict],
  63. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  64. """Calculate the loss based on the features extracted by the detection
  65. head.
  66. Args:
  67. cls_scores (list[Tensor]): Box scores for each scale level
  68. has shape (N, num_anchors * num_classes, H, W).
  69. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  70. level with shape (N, num_anchors * 4, H, W).
  71. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  72. gt_instance. It usually includes ``bboxes`` and ``labels``
  73. attributes.
  74. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  75. image size, scaling factor, etc.
  76. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  77. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  78. data that is ignored during training and testing.
  79. Defaults to None.
  80. Returns:
  81. dict: A dictionary of loss components.
  82. """
  83. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  84. assert len(featmap_sizes) == self.prior_generator.num_levels
  85. device = cls_scores[0].device
  86. anchor_list, _ = self.get_anchors(
  87. featmap_sizes=featmap_sizes,
  88. batch_img_metas=batch_img_metas,
  89. device=device)
  90. concat_anchor_list = [torch.cat(anchor) for anchor in anchor_list]
  91. # concatenate each level
  92. cls_scores = [
  93. cls.permute(0, 2, 3,
  94. 1).reshape(cls.size(0), -1, self.cls_out_channels)
  95. for cls in cls_scores
  96. ]
  97. bbox_preds = [
  98. bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
  99. for bbox_pred in bbox_preds
  100. ]
  101. cls_scores = torch.cat(cls_scores, dim=1)
  102. cls_probs = torch.sigmoid(cls_scores)
  103. bbox_preds = torch.cat(bbox_preds, dim=1)
  104. box_probs, positive_losses, num_pos_list = multi_apply(
  105. self.positive_loss_single, cls_probs, bbox_preds,
  106. concat_anchor_list, batch_gt_instances)
  107. num_pos = sum(num_pos_list)
  108. positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)
  109. # box_prob: P{a_{j} \in A_{+}}
  110. box_probs = torch.stack(box_probs, dim=0)
  111. # negative_loss:
  112. # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
  113. negative_loss = self.negative_bag_loss(cls_probs, box_probs).sum() / \
  114. max(1, num_pos * self.pre_anchor_topk)
  115. # avoid the absence of gradients in regression subnet
  116. # when no ground-truth in a batch
  117. if num_pos == 0:
  118. positive_loss = bbox_preds.sum() * 0
  119. losses = {
  120. 'positive_bag_loss': positive_loss,
  121. 'negative_bag_loss': negative_loss
  122. }
  123. return losses
  124. def positive_loss_single(self, cls_prob: Tensor, bbox_pred: Tensor,
  125. flat_anchors: Tensor,
  126. gt_instances: InstanceData) -> tuple:
  127. """Compute positive loss.
  128. Args:
  129. cls_prob (Tensor): Classification probability of shape
  130. (num_anchors, num_classes).
  131. bbox_pred (Tensor): Box probability of shape (num_anchors, 4).
  132. flat_anchors (Tensor): Multi-level anchors of the image, which are
  133. concatenated into a single tensor of shape (num_anchors, 4)
  134. gt_instances (:obj:`InstanceData`): Ground truth of instance
  135. annotations. It should includes ``bboxes`` and ``labels``
  136. attributes.
  137. Returns:
  138. tuple:
  139. - box_prob (Tensor): Box probability of shape (num_anchors, 4).
  140. - positive_loss (Tensor): Positive loss of shape (num_pos, ).
  141. - num_pos (int): positive samples indexes.
  142. """
  143. gt_bboxes = gt_instances.bboxes
  144. gt_labels = gt_instances.labels
  145. with torch.no_grad():
  146. if len(gt_bboxes) == 0:
  147. image_box_prob = torch.zeros(
  148. flat_anchors.size(0),
  149. self.cls_out_channels).type_as(bbox_pred)
  150. else:
  151. # box_localization: a_{j}^{loc}, shape: [j, 4]
  152. pred_boxes = self.bbox_coder.decode(flat_anchors, bbox_pred)
  153. # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
  154. object_box_iou = bbox_overlaps(gt_bboxes, pred_boxes)
  155. # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
  156. t1 = self.bbox_thr
  157. t2 = object_box_iou.max(
  158. dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
  159. object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
  160. min=0, max=1)
  161. # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
  162. num_obj = gt_labels.size(0)
  163. indices = torch.stack(
  164. [torch.arange(num_obj).type_as(gt_labels), gt_labels],
  165. dim=0)
  166. object_cls_box_prob = torch.sparse_coo_tensor(
  167. indices, object_box_prob)
  168. # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
  169. """
  170. from "start" to "end" implement:
  171. image_box_iou = torch.sparse.max(object_cls_box_prob,
  172. dim=0).t()
  173. """
  174. # start
  175. box_cls_prob = torch.sparse.sum(
  176. object_cls_box_prob, dim=0).to_dense()
  177. indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
  178. if indices.numel() == 0:
  179. image_box_prob = torch.zeros(
  180. flat_anchors.size(0),
  181. self.cls_out_channels).type_as(object_box_prob)
  182. else:
  183. nonzero_box_prob = torch.where(
  184. (gt_labels.unsqueeze(dim=-1) == indices[0]),
  185. object_box_prob[:, indices[1]],
  186. torch.tensor(
  187. [0]).type_as(object_box_prob)).max(dim=0).values
  188. # upmap to shape [j, c]
  189. image_box_prob = torch.sparse_coo_tensor(
  190. indices.flip([0]),
  191. nonzero_box_prob,
  192. size=(flat_anchors.size(0),
  193. self.cls_out_channels)).to_dense()
  194. # end
  195. box_prob = image_box_prob
  196. # construct bags for objects
  197. match_quality_matrix = bbox_overlaps(gt_bboxes, flat_anchors)
  198. _, matched = torch.topk(
  199. match_quality_matrix, self.pre_anchor_topk, dim=1, sorted=False)
  200. del match_quality_matrix
  201. # matched_cls_prob: P_{ij}^{cls}
  202. matched_cls_prob = torch.gather(
  203. cls_prob[matched], 2,
  204. gt_labels.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
  205. 1)).squeeze(2)
  206. # matched_box_prob: P_{ij}^{loc}
  207. matched_anchors = flat_anchors[matched]
  208. matched_object_targets = self.bbox_coder.encode(
  209. matched_anchors,
  210. gt_bboxes.unsqueeze(dim=1).expand_as(matched_anchors))
  211. loss_bbox = self.loss_bbox(
  212. bbox_pred[matched],
  213. matched_object_targets,
  214. reduction_override='none').sum(-1)
  215. matched_box_prob = torch.exp(-loss_bbox)
  216. # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
  217. num_pos = len(gt_bboxes)
  218. positive_loss = self.positive_bag_loss(matched_cls_prob,
  219. matched_box_prob)
  220. return box_prob, positive_loss, num_pos
  221. def positive_bag_loss(self, matched_cls_prob: Tensor,
  222. matched_box_prob: Tensor) -> Tensor:
  223. """Compute positive bag loss.
  224. :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`.
  225. :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of matched samples.
  226. :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched samples.
  227. Args:
  228. matched_cls_prob (Tensor): Classification probability of matched
  229. samples in shape (num_gt, pre_anchor_topk).
  230. matched_box_prob (Tensor): BBox probability of matched samples,
  231. in shape (num_gt, pre_anchor_topk).
  232. Returns:
  233. Tensor: Positive bag loss in shape (num_gt,).
  234. """ # noqa: E501, W605
  235. # bag_prob = Mean-max(matched_prob)
  236. matched_prob = matched_cls_prob * matched_box_prob
  237. weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
  238. weight /= weight.sum(dim=1).unsqueeze(dim=-1)
  239. bag_prob = (weight * matched_prob).sum(dim=1)
  240. # positive_bag_loss = -self.alpha * log(bag_prob)
  241. return self.alpha * F.binary_cross_entropy(
  242. bag_prob, torch.ones_like(bag_prob), reduction='none')
  243. def negative_bag_loss(self, cls_prob: Tensor, box_prob: Tensor) -> Tensor:
  244. """Compute negative bag loss.
  245. :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`.
  246. :math:`P_{a_{j} \in A_{+}}`: Box_probability of matched samples.
  247. :math:`P_{j}^{bg}`: Classification probability of negative samples.
  248. Args:
  249. cls_prob (Tensor): Classification probability, in shape
  250. (num_img, num_anchors, num_classes).
  251. box_prob (Tensor): Box probability, in shape
  252. (num_img, num_anchors, num_classes).
  253. Returns:
  254. Tensor: Negative bag loss in shape (num_img, num_anchors,
  255. num_classes).
  256. """ # noqa: E501, W605
  257. prob = cls_prob * (1 - box_prob)
  258. # There are some cases when neg_prob = 0.
  259. # This will cause the neg_prob.log() to be inf without clamp.
  260. prob = prob.clamp(min=EPS, max=1 - EPS)
  261. negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
  262. prob, torch.zeros_like(prob), reduction='none')
  263. return (1 - self.alpha) * negative_bag_loss