rpn_head.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from typing import List, Optional, Tuple
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from mmcv.cnn import ConvModule
  8. from mmcv.ops import batched_nms
  9. from mmengine.config import ConfigDict
  10. from mmengine.structures import InstanceData
  11. from torch import Tensor
  12. from mmdet.registry import MODELS
  13. from mmdet.structures.bbox import (cat_boxes, empty_box_as, get_box_tensor,
  14. get_box_wh, scale_boxes)
  15. from mmdet.utils import InstanceList, MultiConfig, OptInstanceList
  16. from .anchor_head import AnchorHead
  @MODELS.register_module()
  class RPNHead(AnchorHead):
      """Implementation of RPN head.

      Args:
          in_channels (int): Number of channels in the input feature map.
          num_classes (int): Number of categories excluding the background
              category. Must be 1. Defaults to 1.
          init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or \
              list[dict]): Initialization config dict. Defaults to a
              ``Normal`` initialization (std=0.01) of ``Conv2d`` layers.
          num_convs (int): Number of convolution layers in the head.
              Defaults to 1.
      """  # noqa: W605

      def __init__(self,
                   in_channels: int,
                   num_classes: int = 1,
                   init_cfg: MultiConfig = dict(
                       type='Normal', layer='Conv2d', std=0.01),
                   num_convs: int = 1,
                   **kwargs) -> None:
          # Stored before ``super().__init__()``: presumably the parent
          # constructor calls ``_init_layers``, which reads this attribute
          # (confirm against ``AnchorHead``).
          self.num_convs = num_convs
          assert num_classes == 1
          super().__init__(
              num_classes=num_classes,
              in_channels=in_channels,
              init_cfg=init_cfg,
              **kwargs)

      def _init_layers(self) -> None:
          """Initialize layers of the head.

          Builds the shared 3x3 ``rpn_conv`` (a stack of ``ConvModule``
          when ``num_convs > 1``, otherwise a single ``nn.Conv2d``) and the
          two 1x1 prediction convs ``rpn_cls`` and ``rpn_reg``.
          """
          if self.num_convs > 1:
              rpn_convs = []
              for i in range(self.num_convs):
                  # Only the first conv consumes the input feature channels.
                  in_channels = (
                      self.in_channels if i == 0 else self.feat_channels)
                  # use ``inplace=False`` to avoid error: one of the
                  # variables needed for gradient computation has been
                  # modified by an inplace operation.
                  rpn_convs.append(
                      ConvModule(
                          in_channels,
                          self.feat_channels,
                          3,
                          padding=1,
                          inplace=False))
              self.rpn_conv = nn.Sequential(*rpn_convs)
          else:
              self.rpn_conv = nn.Conv2d(
                  self.in_channels, self.feat_channels, 3, padding=1)
          self.rpn_cls = nn.Conv2d(
              self.feat_channels,
              self.num_base_priors * self.cls_out_channels, 1)
          reg_dim = self.bbox_coder.encode_size
          self.rpn_reg = nn.Conv2d(
              self.feat_channels, self.num_base_priors * reg_dim, 1)

      def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
          """Forward feature of a single scale level.

          Args:
              x (Tensor): Features of a single scale level.

          Returns:
              tuple:
                  cls_score (Tensor): Cls scores for a single scale level,
                      the channels number is
                      num_base_priors * cls_out_channels.
                  bbox_pred (Tensor): Box energies / deltas for a single
                      scale level, the channels number is
                      num_base_priors * bbox_coder.encode_size.
          """
          x = self.rpn_conv(x)
          x = F.relu(x)
          rpn_cls_score = self.rpn_cls(x)
          rpn_bbox_pred = self.rpn_reg(x)
          return rpn_cls_score, rpn_bbox_pred

      def loss_by_feat(self,
                       cls_scores: List[Tensor],
                       bbox_preds: List[Tensor],
                       batch_gt_instances: InstanceList,
                       batch_img_metas: List[dict],
                       batch_gt_instances_ignore: OptInstanceList = None) \
              -> dict:
          """Calculate the loss based on the features extracted by the
          detection head.

          Delegates to the parent implementation and renames its
          ``loss_cls`` / ``loss_bbox`` entries with an ``rpn`` prefix.

          Args:
              cls_scores (list[Tensor]): Box scores for each scale level,
                  has shape (N, num_anchors * num_classes, H, W).
              bbox_preds (list[Tensor]): Box energies / deltas for each
                  scale level with shape (N, num_anchors * 4, H, W).
              batch_gt_instances (list[obj:InstanceData]): Batch of
                  gt_instance. It usually includes ``bboxes`` and
                  ``labels`` attributes.
              batch_img_metas (list[dict]): Meta information of each image,
                  e.g., image size, scaling factor, etc.
              batch_gt_instances_ignore (list[obj:InstanceData], Optional):
                  Batch of gt_instances_ignore. It includes ``bboxes``
                  attribute data that is ignored during training and
                  testing.

          Returns:
              dict[str, Tensor]: A dictionary of loss components with keys
              ``loss_rpn_cls`` and ``loss_rpn_bbox``.
          """
          losses = super().loss_by_feat(
              cls_scores,
              bbox_preds,
              batch_gt_instances,
              batch_img_metas,
              batch_gt_instances_ignore=batch_gt_instances_ignore)
          return dict(
              loss_rpn_cls=losses['loss_cls'],
              loss_rpn_bbox=losses['loss_bbox'])

      def _predict_by_feat_single(self,
                                  cls_score_list: List[Tensor],
                                  bbox_pred_list: List[Tensor],
                                  score_factor_list: List[Tensor],
                                  mlvl_priors: List[Tensor],
                                  img_meta: dict,
                                  cfg: ConfigDict,
                                  rescale: bool = False,
                                  with_nms: bool = True) -> InstanceData:
          """Transform a single image's features extracted from the head
          into bbox results.

          Args:
              cls_score_list (list[Tensor]): Box scores from all scale
                  levels of a single image, each item has shape
                  (num_priors * num_classes, H, W).
              bbox_pred_list (list[Tensor]): Box energies / deltas from
                  all scale levels of a single image, each item has shape
                  (num_priors * 4, H, W).
              score_factor_list (list[Tensor]): Be compatible with
                  BaseDenseHead. Not used in RPNHead.
              mlvl_priors (list[Tensor]): Each element in the list is
                  the priors of a single level in feature pyramid. In all
                  anchor-based methods, it has shape (num_priors, 4). In
                  all anchor-free methods, it has shape (num_priors, 2)
                  when `with_stride=True`, otherwise it still has shape
                  (num_priors, 4).
              img_meta (dict): Image meta info. ``img_shape`` is read here.
              cfg (ConfigDict, optional): Test / postprocessing
                  configuration, if None, test_cfg would be used.
              rescale (bool): If True, return boxes in original image
                  space. Defaults to False.
              with_nms (bool): Accepted for interface compatibility only.
                  It is not forwarded to ``_bbox_post_process``, so NMS is
                  always applied. Defaults to True.

          Returns:
              :obj:`InstanceData`: Detection results of each image
              after the post process.
              Each item usually contains following keys.

              - scores (Tensor): Classification scores, has a shape
                (num_instance, )
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
              - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
          """
          cfg = self.test_cfg if cfg is None else cfg
          cfg = copy.deepcopy(cfg)
          img_shape = img_meta['img_shape']
          nms_pre = cfg.get('nms_pre', -1)
          # Loop-invariant: size of one encoded box.
          reg_dim = self.bbox_coder.encode_size

          mlvl_bbox_preds = []
          mlvl_valid_priors = []
          mlvl_scores = []
          level_ids = []
          for level_idx, (cls_score, bbox_pred, priors) in \
                  enumerate(zip(cls_score_list, bbox_pred_list,
                                mlvl_priors)):
              assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

              # (C, H, W) -> (H * W * num_priors, per-prior channels)
              bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, reg_dim)
              cls_score = cls_score.permute(1, 2, 0).reshape(
                  -1, self.cls_out_channels)
              if self.use_sigmoid_cls:
                  scores = cls_score.sigmoid()
              else:
                  # remind that we set FG labels to [0] since mmdet v2.0
                  # BG cat_id: 1
                  scores = cls_score.softmax(-1)[:, :-1]

              # Drop only the trailing channel axis. A bare
              # ``torch.squeeze`` would also drop the prior axis when a
              # level holds a single prior, leaving a 0-d tensor that
              # breaks ``scores.shape[0]`` below.
              scores = scores.squeeze(-1)
              if 0 < nms_pre < scores.shape[0]:
                  # sort is faster than topk
                  # _, topk_inds = scores.topk(cfg.nms_pre)
                  ranked_scores, rank_inds = scores.sort(descending=True)
                  topk_inds = rank_inds[:nms_pre]
                  scores = ranked_scores[:nms_pre]
                  bbox_pred = bbox_pred[topk_inds, :]
                  priors = priors[topk_inds]

              mlvl_bbox_preds.append(bbox_pred)
              mlvl_valid_priors.append(priors)
              mlvl_scores.append(scores)

              # use level id to implement the separate level nms
              level_ids.append(
                  scores.new_full((scores.size(0), ),
                                  level_idx,
                                  dtype=torch.long))

          bbox_pred = torch.cat(mlvl_bbox_preds)
          priors = cat_boxes(mlvl_valid_priors)
          bboxes = self.bbox_coder.decode(
              priors, bbox_pred, max_shape=img_shape)

          results = InstanceData()
          results.bboxes = bboxes
          results.scores = torch.cat(mlvl_scores)
          results.level_ids = torch.cat(level_ids)

          return self._bbox_post_process(
              results=results, cfg=cfg, rescale=rescale, img_meta=img_meta)

      def _bbox_post_process(self,
                             results: InstanceData,
                             cfg: ConfigDict,
                             rescale: bool = False,
                             with_nms: bool = True,
                             img_meta: Optional[dict] = None
                             ) -> InstanceData:
          """bbox post-processing method.

          The boxes are optionally rescaled to the original image scale,
          filtered by minimum size, and passed through per-level NMS.

          Args:
              results (:obj:`InstanceData`): Detection instance results,
                  each item has shape (num_bboxes, ). ``bboxes``,
                  ``scores`` and ``level_ids`` are read here.
              cfg (ConfigDict): Test / postprocessing configuration.
                  ``min_bbox_size`` (optional), ``nms`` and
                  ``max_per_img`` are read here.
              rescale (bool): If True, return boxes in original image
                  space. Defaults to False.
              with_nms (bool): If True, do nms before return boxes.
                  Must be True in RPNHead. Defaults to True.
              img_meta (dict, optional): Image meta info. Must provide
                  ``scale_factor`` when ``rescale`` is True.
                  Defaults to None.

          Returns:
              :obj:`InstanceData`: Detection results of each image
              after the post process.
              Each item usually contains following keys.

              - scores (Tensor): Classification scores, has a shape
                (num_instance, )
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
              - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
          """
          assert with_nms, '`with_nms` must be True in RPNHead'
          if rescale:
              assert img_meta.get('scale_factor') is not None
              # Invert the resize factors to map back to the original
              # image.
              scale_factor = [1 / s for s in img_meta['scale_factor']]
              results.bboxes = scale_boxes(results.bboxes, scale_factor)

          # filter small size bboxes
          if cfg.get('min_bbox_size', -1) >= 0:
              w, h = get_box_wh(results.bboxes)
              valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
              if not valid_mask.all():
                  results = results[valid_mask]

          if results.bboxes.numel() > 0:
              bboxes = get_box_tensor(results.bboxes)
              # ``level_ids`` act as the group ids, so boxes from
              # different levels never suppress each other.
              det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
                                                  results.level_ids,
                                                  cfg.nms)
              results = results[keep_idxs]
              # some nms would reweight the score, such as softnms
              results.scores = det_bboxes[:, -1]
              results = results[:cfg.max_per_img]
              # TODO: This would unreasonably show the 0th class label
              #  in visualization
              results.labels = results.scores.new_zeros(
                  len(results), dtype=torch.long)
              del results.level_ids
          else:
              # To avoid some potential error
              results_ = InstanceData()
              results_.bboxes = empty_box_as(results.bboxes)
              results_.scores = results.scores.new_zeros(0)
              # Same integer dtype as the labels of the non-empty branch.
              results_.labels = results.scores.new_zeros(
                  0, dtype=torch.long)
              results = results_
          return results