# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import nms
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList

from .guided_anchor_head import GuidedAnchorHead


@MODELS.register_module()
class GARPNHead(GuidedAnchorHead):
    """Guided-Anchor-based RPN head."""

    def __init__(self,
                 in_channels: int,
                 num_classes: int = 1,
                 init_cfg: MultiConfig = dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_loc',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs) -> None:
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            init_cfg=init_cfg,
            **kwargs)
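
    # A minimal construction sketch (illustrative; ``GuidedAnchorHead``
    # supplies defaults for the anchor generators, box coders and losses, so
    # only ``in_channels`` is strictly required):
    #
    #   head = GARPNHead(in_channels=256)
    #
    # In configs the head is normally built through the registry, e.g.
    # ``MODELS.build(dict(type='GARPNHead', in_channels=256, ...))``.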

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        super(GARPNHead, self)._init_layers()
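
    # Note: ``rpn_conv`` is an extra shared 3x3 conv applied before the layers
    # created by ``GuidedAnchorHead._init_layers()`` (the location, shape and
    # feature-adaption branches plus the classification/regression convs).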

    def forward_single(self, x: Tensor) -> Tuple[Tensor]:
        """Forward feature of a single scale level."""
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        (cls_score, bbox_pred, shape_pred,
         loc_pred) = super().forward_single(x)
        return cls_score, bbox_pred, shape_pred, loc_pred
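
    # For each level, the four outputs above have shapes
    # (N, num_anchors * num_classes, H, W), (N, num_anchors * 4, H, W),
    # (N, num_anchors * 2, H, W) and (N, 1, H, W) respectively (see the
    # ``loss_by_feat`` docstring below).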

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            shape_preds: List[Tensor],
            loc_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            shape_preds (list[Tensor]): Shape predictions for each scale
                level with shape (N, num_anchors * 2, H, W).
            loc_preds (list[Tensor]): Location predictions for each scale
                level with shape (N, 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        losses = super().loss_by_feat(
            cls_scores,
            bbox_preds,
            shape_preds,
            loc_preds,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'],
            loss_rpn_bbox=losses['loss_bbox'],
            loss_anchor_shape=losses['loss_shape'],
            loss_anchor_loc=losses['loss_loc'])
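
    # This simply renames the keys returned by
    # ``GuidedAnchorHead.loss_by_feat`` (``loss_cls`` -> ``loss_rpn_cls`` and
    # so on) so the RPN terms are clearly distinguishable in training logs.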

    def _predict_by_feat_single(self,
                                cls_scores: List[Tensor],
                                bbox_preds: List[Tensor],
                                mlvl_anchors: List[Tensor],
                                mlvl_masks: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigType,
                                rescale: bool = False) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            mlvl_anchors (list[Tensor]): Each element in the list is
                the anchors of a single level in the feature pyramid. It has
                shape (num_priors, 4).
            mlvl_masks (list[Tensor]): Each element in the list is the
                location masks of a single level.
            img_meta (dict): Image meta info.
            cfg (:obj:`ConfigDict` or dict): Test / postprocessing
                configuration. If None, ``test_cfg`` is used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            :obj:`InstanceData`: Detection results of the image after
            post-processing, usually with the following keys:

            - scores (Tensor): Classification scores with shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes with shape (num_instances, ).
            - bboxes (Tensor): Bboxes with shape (num_instances, 4), where
              the last dimension is arranged as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        assert cfg.nms.get('type', 'nms') == 'nms', \
            'GARPNHead only supports naive nms.'
        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            anchors = mlvl_anchors[idx]
            mask = mlvl_masks[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0; BG cat_id: num_class. Take the FG
                # column so ``scores`` stays 1-D, matching the sigmoid branch.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1,
                                                                   4)[mask, :]
            if scores.dim() == 0:
                rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            # get proposals w.r.t. anchors and rpn_bbox_pred
            proposals = self.bbox_coder.decode(
                anchors, rpn_bbox_pred, max_shape=img_meta['img_shape'])
            # filter out too small bboxes
            if cfg.min_bbox_size >= 0:
                w = proposals[:, 2] - proposals[:, 0]
                h = proposals[:, 3] - proposals[:, 1]
                valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
                if not valid_mask.all():
                    proposals = proposals[valid_mask]
                    scores = scores[valid_mask]
            # NMS in current level
            proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.get('nms_across_levels', False):
            # NMS across multi levels
            proposals, _ = nms(proposals[:, :4], proposals[:, -1],
                               cfg.nms.iou_threshold)
            proposals = proposals[:cfg.max_per_img, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_per_img, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]

        bboxes = proposals[:, :-1]
        scores = proposals[:, -1]
        if rescale:
            assert img_meta.get('scale_factor') is not None
            bboxes /= bboxes.new_tensor(img_meta['scale_factor']).repeat(
                (1, 2))

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = scores
        results.labels = scores.new_zeros(scores.size(0), dtype=torch.long)
        return results
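

# A rough usage sketch (illustrative only; batch size, channel count and the
# number of pyramid levels are assumptions, not taken from this file):
#
#   import torch
#   head = GARPNHead(in_channels=256)
#   sizes = (64, 32, 16, 8, 4)
#   feats = [torch.rand(2, 256, s, s) for s in sizes]
#   cls_scores, bbox_preds, shape_preds, loc_preds = head(feats)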