centernet_rpn_head.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from typing import List, Sequence, Tuple
  4. import torch
  5. import torch.nn as nn
  6. from mmcv.cnn import Scale
  7. from mmengine import ConfigDict
  8. from mmengine.structures import InstanceData
  9. from torch import Tensor
  10. from mmdet.models.dense_heads import CenterNetUpdateHead
  11. from mmdet.models.utils import multi_apply
  12. from mmdet.registry import MODELS
  13. INF = 1000000000
  14. RangeType = Sequence[Tuple[int, int]]
  15. @MODELS.register_module(force=True) # avoid bug
  16. class CenterNetRPNHead(CenterNetUpdateHead):
  17. """CenterNetUpdateHead is an improved version of CenterNet in CenterNet2.
  18. Paper link `<https://arxiv.org/abs/2103.07461>`_.
  19. """
  20. def _init_layers(self) -> None:
  21. """Initialize layers of the head."""
  22. self._init_reg_convs()
  23. self._init_predictor()
  24. def _init_predictor(self) -> None:
  25. """Initialize predictor layers of the head."""
  26. self.conv_cls = nn.Conv2d(
  27. self.feat_channels, self.num_classes, 3, padding=1)
  28. self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
  29. def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
  30. """Forward features from the upstream network.
  31. Args:
  32. x (tuple[Tensor]): Features from the upstream network, each is
  33. a 4D-tensor.
  34. Returns:
  35. tuple: A tuple of each level outputs.
  36. - cls_scores (list[Tensor]): Box scores for each scale level, \
  37. each is a 4D-tensor, the channel number is num_classes.
  38. - bbox_preds (list[Tensor]): Box energies / deltas for each \
  39. scale level, each is a 4D-tensor, the channel number is 4.
  40. """
  41. res = multi_apply(self.forward_single, x, self.scales, self.strides)
  42. return res
  43. def forward_single(self, x: Tensor, scale: Scale,
  44. stride: int) -> Tuple[Tensor, Tensor]:
  45. """Forward features of a single scale level.
  46. Args:
  47. x (Tensor): FPN feature maps of the specified stride.
  48. scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
  49. the bbox prediction.
  50. stride (int): The corresponding stride for feature maps.
  51. Returns:
  52. tuple: scores for each class, bbox predictions of
  53. input feature maps.
  54. """
  55. for m in self.reg_convs:
  56. x = m(x)
  57. cls_score = self.conv_cls(x)
  58. bbox_pred = self.conv_reg(x)
  59. # scale the bbox_pred of different level
  60. # float to avoid overflow when enabling FP16
  61. bbox_pred = scale(bbox_pred).float()
  62. # bbox_pred needed for gradient computation has been modified
  63. # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
  64. # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
  65. bbox_pred = bbox_pred.clamp(min=0)
  66. if not self.training:
  67. bbox_pred *= stride
  68. return cls_score, bbox_pred # score aligned, box larger
  69. def _predict_by_feat_single(self,
  70. cls_score_list: List[Tensor],
  71. bbox_pred_list: List[Tensor],
  72. score_factor_list: List[Tensor],
  73. mlvl_priors: List[Tensor],
  74. img_meta: dict,
  75. cfg: ConfigDict,
  76. rescale: bool = False,
  77. with_nms: bool = True) -> InstanceData:
  78. """Transform a single image's features extracted from the head into
  79. bbox results.
  80. Args:
  81. cls_score_list (list[Tensor]): Box scores from all scale
  82. levels of a single image, each item has shape
  83. (num_priors * num_classes, H, W).
  84. bbox_pred_list (list[Tensor]): Box energies / deltas from
  85. all scale levels of a single image, each item has shape
  86. (num_priors * 4, H, W).
  87. score_factor_list (list[Tensor]): Score factor from all scale
  88. levels of a single image, each item has shape
  89. (num_priors * 1, H, W).
  90. mlvl_priors (list[Tensor]): Each element in the list is
  91. the priors of a single level in feature pyramid. In all
  92. anchor-based methods, it has shape (num_priors, 4). In
  93. all anchor-free methods, it has shape (num_priors, 2)
  94. when `with_stride=True`, otherwise it still has shape
  95. (num_priors, 4).
  96. img_meta (dict): Image meta info.
  97. cfg (mmengine.Config): Test / postprocessing configuration,
  98. if None, test_cfg would be used.
  99. rescale (bool): If True, return boxes in original image space.
  100. Defaults to False.
  101. with_nms (bool): If True, do nms before return boxes.
  102. Defaults to True.
  103. Returns:
  104. :obj:`InstanceData`: Detection results of each image
  105. after the post process.
  106. Each item usually contains following keys.
  107. - scores (Tensor): Classification scores, has a shape
  108. (num_instance, )
  109. - labels (Tensor): Labels of bboxes, has a shape
  110. (num_instances, ).
  111. - bboxes (Tensor): Has a shape (num_instances, 4),
  112. the last dimension 4 arrange as (x1, y1, x2, y2).
  113. """
  114. cfg = self.test_cfg if cfg is None else cfg
  115. cfg = copy.deepcopy(cfg)
  116. nms_pre = cfg.get('nms_pre', -1)
  117. mlvl_bbox_preds = []
  118. mlvl_valid_priors = []
  119. mlvl_scores = []
  120. mlvl_labels = []
  121. for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
  122. enumerate(zip(cls_score_list, bbox_pred_list,
  123. score_factor_list, mlvl_priors)):
  124. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  125. dim = self.bbox_coder.encode_size
  126. bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
  127. cls_score = cls_score.permute(1, 2,
  128. 0).reshape(-1, self.cls_out_channels)
  129. heatmap = cls_score.sigmoid()
  130. score_thr = cfg.get('score_thr', 0)
  131. candidate_inds = heatmap > score_thr # 0.05
  132. pre_nms_top_n = candidate_inds.sum() # N
  133. pre_nms_top_n = pre_nms_top_n.clamp(max=nms_pre) # N
  134. heatmap = heatmap[candidate_inds] # n
  135. candidate_nonzeros = candidate_inds.nonzero() # n
  136. box_loc = candidate_nonzeros[:, 0] # n
  137. labels = candidate_nonzeros[:, 1] # n
  138. bbox_pred = bbox_pred[box_loc] # n x 4
  139. per_grids = priors[box_loc] # n x 2
  140. if candidate_inds.sum().item() > pre_nms_top_n.item():
  141. heatmap, top_k_indices = \
  142. heatmap.topk(pre_nms_top_n, sorted=False)
  143. labels = labels[top_k_indices]
  144. bbox_pred = bbox_pred[top_k_indices]
  145. per_grids = per_grids[top_k_indices]
  146. bboxes = self.bbox_coder.decode(per_grids, bbox_pred)
  147. # avoid invalid boxes in RoI heads
  148. bboxes[:, 2] = torch.max(bboxes[:, 2], bboxes[:, 0] + 0.01)
  149. bboxes[:, 3] = torch.max(bboxes[:, 3], bboxes[:, 1] + 0.01)
  150. mlvl_bbox_preds.append(bboxes)
  151. mlvl_valid_priors.append(priors)
  152. mlvl_scores.append(torch.sqrt(heatmap))
  153. mlvl_labels.append(labels)
  154. results = InstanceData()
  155. results.bboxes = torch.cat(mlvl_bbox_preds)
  156. results.scores = torch.cat(mlvl_scores)
  157. results.labels = torch.cat(mlvl_labels)
  158. return self._bbox_post_process(
  159. results=results,
  160. cfg=cfg,
  161. rescale=rescale,
  162. with_nms=with_nms,
  163. img_meta=img_meta)