standard_roi_head.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from torch import Tensor
  5. from mmdet.registry import MODELS, TASK_UTILS
  6. from mmdet.structures import DetDataSample, SampleList
  7. from mmdet.structures.bbox import bbox2roi
  8. from mmdet.utils import ConfigType, InstanceList
  9. from ..task_modules.samplers import SamplingResult
  10. from ..utils import empty_instances, unpack_gt_instances
  11. from .base_roi_head import BaseRoIHead
  12. @MODELS.register_module()
  13. class StandardRoIHead(BaseRoIHead):
  14. """Simplest base roi head including one bbox head and one mask head."""
  15. def init_assigner_sampler(self) -> None:
  16. """Initialize assigner and sampler."""
  17. self.bbox_assigner = None
  18. self.bbox_sampler = None
  19. if self.train_cfg:
  20. self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner)
  21. self.bbox_sampler = TASK_UTILS.build(
  22. self.train_cfg.sampler, default_args=dict(context=self))
  23. def init_bbox_head(self, bbox_roi_extractor: ConfigType,
  24. bbox_head: ConfigType) -> None:
  25. """Initialize box head and box roi extractor.
  26. Args:
  27. bbox_roi_extractor (dict or ConfigDict): Config of box
  28. roi extractor.
  29. bbox_head (dict or ConfigDict): Config of box in box head.
  30. """
  31. self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor)
  32. self.bbox_head = MODELS.build(bbox_head)
  33. def init_mask_head(self, mask_roi_extractor: ConfigType,
  34. mask_head: ConfigType) -> None:
  35. """Initialize mask head and mask roi extractor.
  36. Args:
  37. mask_roi_extractor (dict or ConfigDict): Config of mask roi
  38. extractor.
  39. mask_head (dict or ConfigDict): Config of mask in mask head.
  40. """
  41. if mask_roi_extractor is not None:
  42. self.mask_roi_extractor = MODELS.build(mask_roi_extractor)
  43. self.share_roi_extractor = False
  44. else:
  45. self.share_roi_extractor = True
  46. self.mask_roi_extractor = self.bbox_roi_extractor
  47. self.mask_head = MODELS.build(mask_head)
  48. # TODO: Need to refactor later
  49. def forward(self,
  50. x: Tuple[Tensor],
  51. rpn_results_list: InstanceList,
  52. batch_data_samples: SampleList = None) -> tuple:
  53. """Network forward process. Usually includes backbone, neck and head
  54. forward without any post-processing.
  55. Args:
  56. x (List[Tensor]): Multi-level features that may have different
  57. resolutions.
  58. rpn_results_list (list[:obj:`InstanceData`]): List of region
  59. proposals.
  60. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
  61. the meta information of each image and corresponding
  62. annotations.
  63. Returns
  64. tuple: A tuple of features from ``bbox_head`` and ``mask_head``
  65. forward.
  66. """
  67. results = ()
  68. proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
  69. rois = bbox2roi(proposals)
  70. # bbox head
  71. if self.with_bbox:
  72. bbox_results = self._bbox_forward(x, rois)
  73. results = results + (bbox_results['cls_score'],
  74. bbox_results['bbox_pred'])
  75. # mask head
  76. if self.with_mask:
  77. mask_rois = rois[:100]
  78. mask_results = self._mask_forward(x, mask_rois)
  79. results = results + (mask_results['mask_preds'], )
  80. return results
  81. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  82. batch_data_samples: List[DetDataSample]) -> dict:
  83. """Perform forward propagation and loss calculation of the detection
  84. roi on the features of the upstream network.
  85. Args:
  86. x (tuple[Tensor]): List of multi-level img features.
  87. rpn_results_list (list[:obj:`InstanceData`]): List of region
  88. proposals.
  89. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  90. data samples. It usually includes information such
  91. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  92. Returns:
  93. dict[str, Tensor]: A dictionary of loss components
  94. """
  95. assert len(rpn_results_list) == len(batch_data_samples)
  96. outputs = unpack_gt_instances(batch_data_samples)
  97. batch_gt_instances, batch_gt_instances_ignore, _ = outputs
  98. # assign gts and sample proposals
  99. num_imgs = len(batch_data_samples)
  100. sampling_results = []
  101. for i in range(num_imgs):
  102. # rename rpn_results.bboxes to rpn_results.priors
  103. rpn_results = rpn_results_list[i]
  104. rpn_results.priors = rpn_results.pop('bboxes')
  105. assign_result = self.bbox_assigner.assign(
  106. rpn_results, batch_gt_instances[i],
  107. batch_gt_instances_ignore[i])
  108. sampling_result = self.bbox_sampler.sample(
  109. assign_result,
  110. rpn_results,
  111. batch_gt_instances[i],
  112. feats=[lvl_feat[i][None] for lvl_feat in x])
  113. sampling_results.append(sampling_result)
  114. losses = dict()
  115. # bbox head loss
  116. if self.with_bbox:
  117. bbox_results = self.bbox_loss(x, sampling_results)
  118. losses.update(bbox_results['loss_bbox'])
  119. # mask head forward and loss
  120. if self.with_mask:
  121. mask_results = self.mask_loss(x, sampling_results,
  122. bbox_results['bbox_feats'],
  123. batch_gt_instances)
  124. losses.update(mask_results['loss_mask'])
  125. return losses
  126. def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
  127. """Box head forward function used in both training and testing.
  128. Args:
  129. x (tuple[Tensor]): List of multi-level img features.
  130. rois (Tensor): RoIs with the shape (n, 5) where the first
  131. column indicates batch id of each RoI.
  132. Returns:
  133. dict[str, Tensor]: Usually returns a dictionary with keys:
  134. - `cls_score` (Tensor): Classification scores.
  135. - `bbox_pred` (Tensor): Box energies / deltas.
  136. - `bbox_feats` (Tensor): Extract bbox RoI features.
  137. """
  138. # TODO: a more flexible way to decide which feature maps to use
  139. bbox_feats = self.bbox_roi_extractor(
  140. x[:self.bbox_roi_extractor.num_inputs], rois)
  141. if self.with_shared_head:
  142. bbox_feats = self.shared_head(bbox_feats)
  143. cls_score, bbox_pred = self.bbox_head(bbox_feats)
  144. bbox_results = dict(
  145. cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
  146. return bbox_results
  147. def bbox_loss(self, x: Tuple[Tensor],
  148. sampling_results: List[SamplingResult]) -> dict:
  149. """Perform forward propagation and loss calculation of the bbox head on
  150. the features of the upstream network.
  151. Args:
  152. x (tuple[Tensor]): List of multi-level img features.
  153. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  154. Returns:
  155. dict[str, Tensor]: Usually returns a dictionary with keys:
  156. - `cls_score` (Tensor): Classification scores.
  157. - `bbox_pred` (Tensor): Box energies / deltas.
  158. - `bbox_feats` (Tensor): Extract bbox RoI features.
  159. - `loss_bbox` (dict): A dictionary of bbox loss components.
  160. """
  161. rois = bbox2roi([res.priors for res in sampling_results])
  162. bbox_results = self._bbox_forward(x, rois)
  163. bbox_loss_and_target = self.bbox_head.loss_and_target(
  164. cls_score=bbox_results['cls_score'],
  165. bbox_pred=bbox_results['bbox_pred'],
  166. rois=rois,
  167. sampling_results=sampling_results,
  168. rcnn_train_cfg=self.train_cfg)
  169. bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox'])
  170. return bbox_results
  171. def mask_loss(self, x: Tuple[Tensor],
  172. sampling_results: List[SamplingResult], bbox_feats: Tensor,
  173. batch_gt_instances: InstanceList) -> dict:
  174. """Perform forward propagation and loss calculation of the mask head on
  175. the features of the upstream network.
  176. Args:
  177. x (tuple[Tensor]): Tuple of multi-level img features.
  178. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  179. bbox_feats (Tensor): Extract bbox RoI features.
  180. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  181. gt_instance. It usually includes ``bboxes``, ``labels``, and
  182. ``masks`` attributes.
  183. Returns:
  184. dict: Usually returns a dictionary with keys:
  185. - `mask_preds` (Tensor): Mask prediction.
  186. - `mask_feats` (Tensor): Extract mask RoI features.
  187. - `mask_targets` (Tensor): Mask target of each positive\
  188. proposals in the image.
  189. - `loss_mask` (dict): A dictionary of mask loss components.
  190. """
  191. if not self.share_roi_extractor:
  192. pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
  193. mask_results = self._mask_forward(x, pos_rois)
  194. else:
  195. pos_inds = []
  196. device = bbox_feats.device
  197. for res in sampling_results:
  198. pos_inds.append(
  199. torch.ones(
  200. res.pos_priors.shape[0],
  201. device=device,
  202. dtype=torch.uint8))
  203. pos_inds.append(
  204. torch.zeros(
  205. res.neg_priors.shape[0],
  206. device=device,
  207. dtype=torch.uint8))
  208. pos_inds = torch.cat(pos_inds)
  209. mask_results = self._mask_forward(
  210. x, pos_inds=pos_inds, bbox_feats=bbox_feats)
  211. mask_loss_and_target = self.mask_head.loss_and_target(
  212. mask_preds=mask_results['mask_preds'],
  213. sampling_results=sampling_results,
  214. batch_gt_instances=batch_gt_instances,
  215. rcnn_train_cfg=self.train_cfg)
  216. mask_results.update(loss_mask=mask_loss_and_target['loss_mask'])
  217. return mask_results
  218. def _mask_forward(self,
  219. x: Tuple[Tensor],
  220. rois: Tensor = None,
  221. pos_inds: Optional[Tensor] = None,
  222. bbox_feats: Optional[Tensor] = None) -> dict:
  223. """Mask head forward function used in both training and testing.
  224. Args:
  225. x (tuple[Tensor]): Tuple of multi-level img features.
  226. rois (Tensor): RoIs with the shape (n, 5) where the first
  227. column indicates batch id of each RoI.
  228. pos_inds (Tensor, optional): Indices of positive samples.
  229. Defaults to None.
  230. bbox_feats (Tensor): Extract bbox RoI features. Defaults to None.
  231. Returns:
  232. dict[str, Tensor]: Usually returns a dictionary with keys:
  233. - `mask_preds` (Tensor): Mask prediction.
  234. - `mask_feats` (Tensor): Extract mask RoI features.
  235. """
  236. assert ((rois is not None) ^
  237. (pos_inds is not None and bbox_feats is not None))
  238. if rois is not None:
  239. mask_feats = self.mask_roi_extractor(
  240. x[:self.mask_roi_extractor.num_inputs], rois)
  241. if self.with_shared_head:
  242. mask_feats = self.shared_head(mask_feats)
  243. else:
  244. assert bbox_feats is not None
  245. mask_feats = bbox_feats[pos_inds]
  246. mask_preds = self.mask_head(mask_feats)
  247. mask_results = dict(mask_preds=mask_preds, mask_feats=mask_feats)
  248. return mask_results
  249. def predict_bbox(self,
  250. x: Tuple[Tensor],
  251. batch_img_metas: List[dict],
  252. rpn_results_list: InstanceList,
  253. rcnn_test_cfg: ConfigType,
  254. rescale: bool = False) -> InstanceList:
  255. """Perform forward propagation of the bbox head and predict detection
  256. results on the features of the upstream network.
  257. Args:
  258. x (tuple[Tensor]): Feature maps of all scale level.
  259. batch_img_metas (list[dict]): List of image information.
  260. rpn_results_list (list[:obj:`InstanceData`]): List of region
  261. proposals.
  262. rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
  263. rescale (bool): If True, return boxes in original image space.
  264. Defaults to False.
  265. Returns:
  266. list[:obj:`InstanceData`]: Detection results of each image
  267. after the post process.
  268. Each item usually contains following keys.
  269. - scores (Tensor): Classification scores, has a shape
  270. (num_instance, )
  271. - labels (Tensor): Labels of bboxes, has a shape
  272. (num_instances, ).
  273. - bboxes (Tensor): Has a shape (num_instances, 4),
  274. the last dimension 4 arrange as (x1, y1, x2, y2).
  275. """
  276. proposals = [res.bboxes for res in rpn_results_list]
  277. rois = bbox2roi(proposals)
  278. if rois.shape[0] == 0:
  279. return empty_instances(
  280. batch_img_metas,
  281. rois.device,
  282. task_type='bbox',
  283. box_type=self.bbox_head.predict_box_type,
  284. num_classes=self.bbox_head.num_classes,
  285. score_per_cls=rcnn_test_cfg is None)
  286. bbox_results = self._bbox_forward(x, rois)
  287. # split batch bbox prediction back to each image
  288. cls_scores = bbox_results['cls_score']
  289. bbox_preds = bbox_results['bbox_pred']
  290. num_proposals_per_img = tuple(len(p) for p in proposals)
  291. rois = rois.split(num_proposals_per_img, 0)
  292. cls_scores = cls_scores.split(num_proposals_per_img, 0)
  293. # some detector with_reg is False, bbox_preds will be None
  294. if bbox_preds is not None:
  295. # TODO move this to a sabl_roi_head
  296. # the bbox prediction of some detectors like SABL is not Tensor
  297. if isinstance(bbox_preds, torch.Tensor):
  298. bbox_preds = bbox_preds.split(num_proposals_per_img, 0)
  299. else:
  300. bbox_preds = self.bbox_head.bbox_pred_split(
  301. bbox_preds, num_proposals_per_img)
  302. else:
  303. bbox_preds = (None, ) * len(proposals)
  304. result_list = self.bbox_head.predict_by_feat(
  305. rois=rois,
  306. cls_scores=cls_scores,
  307. bbox_preds=bbox_preds,
  308. batch_img_metas=batch_img_metas,
  309. rcnn_test_cfg=rcnn_test_cfg,
  310. rescale=rescale)
  311. return result_list
  312. def predict_mask(self,
  313. x: Tuple[Tensor],
  314. batch_img_metas: List[dict],
  315. results_list: InstanceList,
  316. rescale: bool = False) -> InstanceList:
  317. """Perform forward propagation of the mask head and predict detection
  318. results on the features of the upstream network.
  319. Args:
  320. x (tuple[Tensor]): Feature maps of all scale level.
  321. batch_img_metas (list[dict]): List of image information.
  322. results_list (list[:obj:`InstanceData`]): Detection results of
  323. each image.
  324. rescale (bool): If True, return boxes in original image space.
  325. Defaults to False.
  326. Returns:
  327. list[:obj:`InstanceData`]: Detection results of each image
  328. after the post process.
  329. Each item usually contains following keys.
  330. - scores (Tensor): Classification scores, has a shape
  331. (num_instance, )
  332. - labels (Tensor): Labels of bboxes, has a shape
  333. (num_instances, ).
  334. - bboxes (Tensor): Has a shape (num_instances, 4),
  335. the last dimension 4 arrange as (x1, y1, x2, y2).
  336. - masks (Tensor): Has a shape (num_instances, H, W).
  337. """
  338. # don't need to consider aug_test.
  339. bboxes = [res.bboxes for res in results_list]
  340. mask_rois = bbox2roi(bboxes)
  341. if mask_rois.shape[0] == 0:
  342. results_list = empty_instances(
  343. batch_img_metas,
  344. mask_rois.device,
  345. task_type='mask',
  346. instance_results=results_list,
  347. mask_thr_binary=self.test_cfg.mask_thr_binary)
  348. return results_list
  349. mask_results = self._mask_forward(x, mask_rois)
  350. mask_preds = mask_results['mask_preds']
  351. # split batch mask prediction back to each image
  352. num_mask_rois_per_img = [len(res) for res in results_list]
  353. mask_preds = mask_preds.split(num_mask_rois_per_img, 0)
  354. # TODO: Handle the case where rescale is false
  355. results_list = self.mask_head.predict_by_feat(
  356. mask_preds=mask_preds,
  357. results_list=results_list,
  358. batch_img_metas=batch_img_metas,
  359. rcnn_test_cfg=self.test_cfg,
  360. rescale=rescale)
  361. return results_list