detic_roi_head.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.roi_heads import CascadeRoIHead
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.test_time_augs import merge_aug_masks
from mmdet.models.utils.misc import empty_instances
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox2roi, get_box_tensor
from mmdet.utils import ConfigType, InstanceList, MultiConfig


  14. @MODELS.register_module(force=True) # avoid bug
  15. class DeticRoIHead(CascadeRoIHead):
  16. def init_mask_head(self, mask_roi_extractor: MultiConfig,
  17. mask_head: MultiConfig) -> None:
  18. """Initialize mask head and mask roi extractor.
  19. Args:
  20. mask_head (dict): Config of mask in mask head.
  21. mask_roi_extractor (:obj:`ConfigDict`, dict or list):
  22. Config of mask roi extractor.
  23. """
  24. self.mask_head = MODELS.build(mask_head)
  25. if mask_roi_extractor is not None:
  26. self.share_roi_extractor = False
  27. self.mask_roi_extractor = MODELS.build(mask_roi_extractor)
  28. else:
  29. self.share_roi_extractor = True
  30. self.mask_roi_extractor = self.bbox_roi_extractor
  31. def _refine_roi(self, x: Tuple[Tensor], rois: Tensor,
  32. batch_img_metas: List[dict],
  33. num_proposals_per_img: Sequence[int], **kwargs) -> tuple:
  34. """Multi-stage refinement of RoI.
  35. Args:
  36. x (tuple[Tensor]): List of multi-level img features.
  37. rois (Tensor): shape (n, 5), [batch_ind, x1, y1, x2, y2]
  38. batch_img_metas (list[dict]): List of image information.
  39. num_proposals_per_img (sequence[int]): number of proposals
  40. in each image.
  41. Returns:
  42. tuple:
  43. - rois (Tensor): Refined RoI.
  44. - cls_scores (list[Tensor]): Average predicted
  45. cls score per image.
  46. - bbox_preds (list[Tensor]): Bbox branch predictions
  47. for the last stage of per image.
  48. """
  49. # "ms" in variable names means multi-stage
  50. ms_scores = []
  51. for stage in range(self.num_stages):
  52. bbox_results = self._bbox_forward(
  53. stage=stage, x=x, rois=rois, **kwargs)
  54. # split batch bbox prediction back to each image
  55. cls_scores = bbox_results['cls_score'].sigmoid()
  56. bbox_preds = bbox_results['bbox_pred']
  57. rois = rois.split(num_proposals_per_img, 0)
  58. cls_scores = cls_scores.split(num_proposals_per_img, 0)
  59. ms_scores.append(cls_scores)
  60. bbox_preds = bbox_preds.split(num_proposals_per_img, 0)
  61. if stage < self.num_stages - 1:
  62. bbox_head = self.bbox_head[stage]
  63. refine_rois_list = []
  64. for i in range(len(batch_img_metas)):
  65. if rois[i].shape[0] > 0:
  66. bbox_label = cls_scores[i][:, :-1].argmax(dim=1)
  67. # Refactor `bbox_head.regress_by_class` to only accept
  68. # box tensor without img_idx concatenated.
  69. refined_bboxes = bbox_head.regress_by_class(
  70. rois[i][:, 1:], bbox_label, bbox_preds[i],
  71. batch_img_metas[i])
  72. refined_bboxes = get_box_tensor(refined_bboxes)
  73. refined_rois = torch.cat(
  74. [rois[i][:, [0]], refined_bboxes], dim=1)
  75. refine_rois_list.append(refined_rois)
  76. rois = torch.cat(refine_rois_list)
  77. # ms_scores aligned
  78. # average scores of each image by stages
  79. cls_scores = [
  80. sum([score[i] for score in ms_scores]) / float(len(ms_scores))
  81. for i in range(len(batch_img_metas))
  82. ] # aligned
  83. return rois, cls_scores, bbox_preds
  84. def _bbox_forward(self, stage: int, x: Tuple[Tensor],
  85. rois: Tensor) -> dict:
  86. """Box head forward function used in both training and testing.
  87. Args:
  88. stage (int): The current stage in Cascade RoI Head.
  89. x (tuple[Tensor]): List of multi-level img features.
  90. rois (Tensor): RoIs with the shape (n, 5) where the first
  91. column indicates batch id of each RoI.
  92. Returns:
  93. dict[str, Tensor]: Usually returns a dictionary with keys:
  94. - `cls_score` (Tensor): Classification scores.
  95. - `bbox_pred` (Tensor): Box energies / deltas.
  96. - `bbox_feats` (Tensor): Extract bbox RoI features.
  97. """
  98. bbox_roi_extractor = self.bbox_roi_extractor[stage]
  99. bbox_head = self.bbox_head[stage]
  100. bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
  101. rois)
  102. # do not support caffe_c4 model anymore
  103. cls_score, bbox_pred = bbox_head(bbox_feats)
  104. bbox_results = dict(
  105. cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
  106. return bbox_results
  107. def predict_bbox(self,
  108. x: Tuple[Tensor],
  109. batch_img_metas: List[dict],
  110. rpn_results_list: InstanceList,
  111. rcnn_test_cfg: ConfigType,
  112. rescale: bool = False,
  113. **kwargs) -> InstanceList:
  114. """Perform forward propagation of the bbox head and predict detection
  115. results on the features of the upstream network.
  116. Args:
  117. x (tuple[Tensor]): Feature maps of all scale level.
  118. batch_img_metas (list[dict]): List of image information.
  119. rpn_results_list (list[:obj:`InstanceData`]): List of region
  120. proposals.
  121. rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
  122. rescale (bool): If True, return boxes in original image space.
  123. Defaults to False.
  124. Returns:
  125. list[:obj:`InstanceData`]: Detection results of each image
  126. after the post process.
  127. Each item usually contains following keys.
  128. - scores (Tensor): Classification scores, has a shape
  129. (num_instance, )
  130. - labels (Tensor): Labels of bboxes, has a shape
  131. (num_instances, ).
  132. - bboxes (Tensor): Has a shape (num_instances, 4),
  133. the last dimension 4 arrange as (x1, y1, x2, y2).
  134. """
  135. proposals = [res.bboxes for res in rpn_results_list]
  136. proposal_scores = [res.scores for res in rpn_results_list]
  137. num_proposals_per_img = tuple(len(p) for p in proposals)
  138. rois = bbox2roi(proposals)
  139. if rois.shape[0] == 0:
  140. return empty_instances(
  141. batch_img_metas,
  142. rois.device,
  143. task_type='bbox',
  144. box_type=self.bbox_head[-1].predict_box_type,
  145. num_classes=self.bbox_head[-1].num_classes,
  146. score_per_cls=rcnn_test_cfg is None)
  147. # rois aligned
  148. rois, cls_scores, bbox_preds = self._refine_roi(
  149. x=x,
  150. rois=rois,
  151. batch_img_metas=batch_img_metas,
  152. num_proposals_per_img=num_proposals_per_img,
  153. **kwargs)
  154. # score reweighting in centernet2
  155. cls_scores = [(s * ps[:, None])**0.5
  156. for s, ps in zip(cls_scores, proposal_scores)]
  157. cls_scores = [
  158. s * (s == s[:, :-1].max(dim=1)[0][:, None]).float()
  159. for s in cls_scores
  160. ]
  161. # fast_rcnn_inference
  162. results_list = self.bbox_head[-1].predict_by_feat(
  163. rois=rois,
  164. cls_scores=cls_scores,
  165. bbox_preds=bbox_preds,
  166. batch_img_metas=batch_img_metas,
  167. rescale=rescale,
  168. rcnn_test_cfg=rcnn_test_cfg)
  169. return results_list
  170. def _mask_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
  171. """Mask head forward function used in both training and testing.
  172. Args:
  173. stage (int): The current stage in Cascade RoI Head.
  174. x (tuple[Tensor]): Tuple of multi-level img features.
  175. rois (Tensor): RoIs with the shape (n, 5) where the first
  176. column indicates batch id of each RoI.
  177. Returns:
  178. dict: Usually returns a dictionary with keys:
  179. - `mask_preds` (Tensor): Mask prediction.
  180. """
  181. mask_feats = self.mask_roi_extractor(
  182. x[:self.mask_roi_extractor.num_inputs], rois)
  183. # do not support caffe_c4 model anymore
  184. mask_preds = self.mask_head(mask_feats)
  185. mask_results = dict(mask_preds=mask_preds)
  186. return mask_results
  187. def mask_loss(self, x, sampling_results: List[SamplingResult],
  188. batch_gt_instances: InstanceList) -> dict:
  189. """Run forward function and calculate loss for mask head in training.
  190. Args:
  191. x (tuple[Tensor]): Tuple of multi-level img features.
  192. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  193. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  194. gt_instance. It usually includes ``bboxes``, ``labels``, and
  195. ``masks`` attributes.
  196. Returns:
  197. dict: Usually returns a dictionary with keys:
  198. - `mask_preds` (Tensor): Mask prediction.
  199. - `loss_mask` (dict): A dictionary of mask loss components.
  200. """
  201. pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
  202. mask_results = self._mask_forward(x, pos_rois)
  203. mask_loss_and_target = self.mask_head.loss_and_target(
  204. mask_preds=mask_results['mask_preds'],
  205. sampling_results=sampling_results,
  206. batch_gt_instances=batch_gt_instances,
  207. rcnn_train_cfg=self.train_cfg[-1])
  208. mask_results.update(mask_loss_and_target)
  209. return mask_results
  210. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  211. batch_data_samples: SampleList) -> dict:
  212. """Perform forward propagation and loss calculation of the detection
  213. roi on the features of the upstream network.
  214. Args:
  215. x (tuple[Tensor]): List of multi-level img features.
  216. rpn_results_list (list[:obj:`InstanceData`]): List of region
  217. proposals.
  218. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  219. data samples. It usually includes information such
  220. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  221. Returns:
  222. dict[str, Tensor]: A dictionary of loss components
  223. """
  224. raise NotImplementedError
  225. def predict_mask(self,
  226. x: Tuple[Tensor],
  227. batch_img_metas: List[dict],
  228. results_list: List[InstanceData],
  229. rescale: bool = False) -> List[InstanceData]:
  230. """Perform forward propagation of the mask head and predict detection
  231. results on the features of the upstream network.
  232. Args:
  233. x (tuple[Tensor]): Feature maps of all scale level.
  234. batch_img_metas (list[dict]): List of image information.
  235. results_list (list[:obj:`InstanceData`]): Detection results of
  236. each image.
  237. rescale (bool): If True, return boxes in original image space.
  238. Defaults to False.
  239. Returns:
  240. list[:obj:`InstanceData`]: Detection results of each image
  241. after the post process.
  242. Each item usually contains following keys.
  243. - scores (Tensor): Classification scores, has a shape
  244. (num_instance, )
  245. - labels (Tensor): Labels of bboxes, has a shape
  246. (num_instances, ).
  247. - bboxes (Tensor): Has a shape (num_instances, 4),
  248. the last dimension 4 arrange as (x1, y1, x2, y2).
  249. - masks (Tensor): Has a shape (num_instances, H, W).
  250. """
  251. bboxes = [res.bboxes for res in results_list]
  252. mask_rois = bbox2roi(bboxes)
  253. if mask_rois.shape[0] == 0:
  254. results_list = empty_instances(
  255. batch_img_metas,
  256. mask_rois.device,
  257. task_type='mask',
  258. instance_results=results_list,
  259. mask_thr_binary=self.test_cfg.mask_thr_binary)
  260. return results_list
  261. num_mask_rois_per_img = [len(res) for res in results_list]
  262. aug_masks = []
  263. mask_results = self._mask_forward(x, mask_rois)
  264. mask_preds = mask_results['mask_preds']
  265. # split batch mask prediction back to each image
  266. mask_preds = mask_preds.split(num_mask_rois_per_img, 0)
  267. aug_masks.append([m.sigmoid().detach() for m in mask_preds])
  268. merged_masks = []
  269. for i in range(len(batch_img_metas)):
  270. aug_mask = [mask[i] for mask in aug_masks]
  271. merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i])
  272. merged_masks.append(merged_mask)
  273. results_list = self.mask_head.predict_by_feat(
  274. mask_preds=merged_masks,
  275. results_list=results_list,
  276. batch_img_metas=batch_img_metas,
  277. rcnn_test_cfg=self.test_cfg,
  278. rescale=rescale,
  279. activate_map=True)
  280. return results_list