grid_roi_head.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from torch import Tensor
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import SampleList
  7. from mmdet.structures.bbox import bbox2roi
  8. from mmdet.utils import ConfigType, InstanceList
  9. from ..task_modules.samplers import SamplingResult
  10. from ..utils.misc import unpack_gt_instances
  11. from .standard_roi_head import StandardRoIHead
  12. @MODELS.register_module()
  13. class GridRoIHead(StandardRoIHead):
  14. """Implementation of `Grid RoI Head <https://arxiv.org/abs/1811.12030>`_
  15. Args:
  16. grid_roi_extractor (:obj:`ConfigDict` or dict): Config of
  17. roi extractor.
  18. grid_head (:obj:`ConfigDict` or dict): Config of grid head
  19. """
  20. def __init__(self, grid_roi_extractor: ConfigType, grid_head: ConfigType,
  21. **kwargs) -> None:
  22. assert grid_head is not None
  23. super().__init__(**kwargs)
  24. if grid_roi_extractor is not None:
  25. self.grid_roi_extractor = MODELS.build(grid_roi_extractor)
  26. self.share_roi_extractor = False
  27. else:
  28. self.share_roi_extractor = True
  29. self.grid_roi_extractor = self.bbox_roi_extractor
  30. self.grid_head = MODELS.build(grid_head)
  31. def _random_jitter(self,
  32. sampling_results: List[SamplingResult],
  33. batch_img_metas: List[dict],
  34. amplitude: float = 0.15) -> List[SamplingResult]:
  35. """Ramdom jitter positive proposals for training.
  36. Args:
  37. sampling_results (List[obj:SamplingResult]): Assign results of
  38. all images in a batch after sampling.
  39. batch_img_metas (list[dict]): List of image information.
  40. amplitude (float): Amplitude of random offset. Defaults to 0.15.
  41. Returns:
  42. list[obj:SamplingResult]: SamplingResults after random jittering.
  43. """
  44. for sampling_result, img_meta in zip(sampling_results,
  45. batch_img_metas):
  46. bboxes = sampling_result.pos_priors
  47. random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
  48. -amplitude, amplitude)
  49. # before jittering
  50. cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
  51. wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
  52. # after jittering
  53. new_cxcy = cxcy + wh * random_offsets[:, :2]
  54. new_wh = wh * (1 + random_offsets[:, 2:])
  55. # xywh to xyxy
  56. new_x1y1 = (new_cxcy - new_wh / 2)
  57. new_x2y2 = (new_cxcy + new_wh / 2)
  58. new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
  59. # clip bboxes
  60. max_shape = img_meta['img_shape']
  61. if max_shape is not None:
  62. new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
  63. new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
  64. sampling_result.pos_priors = new_bboxes
  65. return sampling_results
  66. # TODO: Forward is incorrect and need to refactor.
  67. def forward(self,
  68. x: Tuple[Tensor],
  69. rpn_results_list: InstanceList,
  70. batch_data_samples: SampleList = None) -> tuple:
  71. """Network forward process. Usually includes backbone, neck and head
  72. forward without any post-processing.
  73. Args:
  74. x (Tuple[Tensor]): Multi-level features that may have different
  75. resolutions.
  76. rpn_results_list (list[:obj:`InstanceData`]): List of region
  77. proposals.
  78. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
  79. the meta information of each image and corresponding
  80. annotations.
  81. Returns
  82. tuple: A tuple of features from ``bbox_head`` and ``mask_head``
  83. forward.
  84. """
  85. results = ()
  86. proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
  87. rois = bbox2roi(proposals)
  88. # bbox head
  89. if self.with_bbox:
  90. bbox_results = self._bbox_forward(x, rois)
  91. results = results + (bbox_results['cls_score'], )
  92. if self.bbox_head.with_reg:
  93. results = results + (bbox_results['bbox_pred'], )
  94. # grid head
  95. grid_rois = rois[:100]
  96. grid_feats = self.grid_roi_extractor(
  97. x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
  98. if self.with_shared_head:
  99. grid_feats = self.shared_head(grid_feats)
  100. self.grid_head.test_mode = True
  101. grid_preds = self.grid_head(grid_feats)
  102. results = results + (grid_preds, )
  103. # mask head
  104. if self.with_mask:
  105. mask_rois = rois[:100]
  106. mask_results = self._mask_forward(x, mask_rois)
  107. results = results + (mask_results['mask_preds'], )
  108. return results
  109. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  110. batch_data_samples: SampleList, **kwargs) -> dict:
  111. """Perform forward propagation and loss calculation of the detection
  112. roi on the features of the upstream network.
  113. Args:
  114. x (tuple[Tensor]): List of multi-level img features.
  115. rpn_results_list (list[:obj:`InstanceData`]): List of region
  116. proposals.
  117. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  118. data samples. It usually includes information such
  119. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  120. Returns:
  121. dict[str, Tensor]: A dictionary of loss components
  122. """
  123. assert len(rpn_results_list) == len(batch_data_samples)
  124. outputs = unpack_gt_instances(batch_data_samples)
  125. (batch_gt_instances, batch_gt_instances_ignore,
  126. batch_img_metas) = outputs
  127. # assign gts and sample proposals
  128. num_imgs = len(batch_data_samples)
  129. sampling_results = []
  130. for i in range(num_imgs):
  131. # rename rpn_results.bboxes to rpn_results.priors
  132. rpn_results = rpn_results_list[i]
  133. rpn_results.priors = rpn_results.pop('bboxes')
  134. assign_result = self.bbox_assigner.assign(
  135. rpn_results, batch_gt_instances[i],
  136. batch_gt_instances_ignore[i])
  137. sampling_result = self.bbox_sampler.sample(
  138. assign_result,
  139. rpn_results,
  140. batch_gt_instances[i],
  141. feats=[lvl_feat[i][None] for lvl_feat in x])
  142. sampling_results.append(sampling_result)
  143. losses = dict()
  144. # bbox head loss
  145. if self.with_bbox:
  146. bbox_results = self.bbox_loss(x, sampling_results, batch_img_metas)
  147. losses.update(bbox_results['loss_bbox'])
  148. # mask head forward and loss
  149. if self.with_mask:
  150. mask_results = self.mask_loss(x, sampling_results,
  151. bbox_results['bbox_feats'],
  152. batch_gt_instances)
  153. losses.update(mask_results['loss_mask'])
  154. return losses
  155. def bbox_loss(self,
  156. x: Tuple[Tensor],
  157. sampling_results: List[SamplingResult],
  158. batch_img_metas: Optional[List[dict]] = None) -> dict:
  159. """Perform forward propagation and loss calculation of the bbox head on
  160. the features of the upstream network.
  161. Args:
  162. x (tuple[Tensor]): List of multi-level img features.
  163. sampling_results (list[:obj:`SamplingResult`]): Sampling results.
  164. batch_img_metas (list[dict], optional): Meta information of each
  165. image, e.g., image size, scaling factor, etc.
  166. Returns:
  167. dict[str, Tensor]: Usually returns a dictionary with keys:
  168. - `cls_score` (Tensor): Classification scores.
  169. - `bbox_pred` (Tensor): Box energies / deltas.
  170. - `bbox_feats` (Tensor): Extract bbox RoI features.
  171. - `loss_bbox` (dict): A dictionary of bbox loss components.
  172. """
  173. assert batch_img_metas is not None
  174. bbox_results = super().bbox_loss(x, sampling_results)
  175. # Grid head forward and loss
  176. sampling_results = self._random_jitter(sampling_results,
  177. batch_img_metas)
  178. pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
  179. # GN in head does not support zero shape input
  180. if pos_rois.shape[0] == 0:
  181. return bbox_results
  182. grid_feats = self.grid_roi_extractor(
  183. x[:self.grid_roi_extractor.num_inputs], pos_rois)
  184. if self.with_shared_head:
  185. grid_feats = self.shared_head(grid_feats)
  186. # Accelerate training
  187. max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
  188. sample_idx = torch.randperm(
  189. grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid
  190. )]
  191. grid_feats = grid_feats[sample_idx]
  192. grid_pred = self.grid_head(grid_feats)
  193. loss_grid = self.grid_head.loss(grid_pred, sample_idx,
  194. sampling_results, self.train_cfg)
  195. bbox_results['loss_bbox'].update(loss_grid)
  196. return bbox_results
  197. def predict_bbox(self,
  198. x: Tuple[Tensor],
  199. batch_img_metas: List[dict],
  200. rpn_results_list: InstanceList,
  201. rcnn_test_cfg: ConfigType,
  202. rescale: bool = False) -> InstanceList:
  203. """Perform forward propagation of the bbox head and predict detection
  204. results on the features of the upstream network.
  205. Args:
  206. x (tuple[Tensor]): Feature maps of all scale level.
  207. batch_img_metas (list[dict]): List of image information.
  208. rpn_results_list (list[:obj:`InstanceData`]): List of region
  209. proposals.
  210. rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of R-CNN.
  211. rescale (bool): If True, return boxes in original image space.
  212. Defaults to False.
  213. Returns:
  214. list[:obj:`InstanceData`]: Detection results of each image
  215. after the post process.
  216. Each item usually contains following keys.
  217. - scores (Tensor): Classification scores, has a shape \
  218. (num_instance, )
  219. - labels (Tensor): Labels of bboxes, has a shape (num_instances, ).
  220. - bboxes (Tensor): Has a shape (num_instances, 4), the last \
  221. dimension 4 arrange as (x1, y1, x2, y2).
  222. """
  223. results_list = super().predict_bbox(
  224. x,
  225. batch_img_metas=batch_img_metas,
  226. rpn_results_list=rpn_results_list,
  227. rcnn_test_cfg=rcnn_test_cfg,
  228. rescale=False)
  229. grid_rois = bbox2roi([res.bboxes for res in results_list])
  230. if grid_rois.shape[0] != 0:
  231. grid_feats = self.grid_roi_extractor(
  232. x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
  233. if self.with_shared_head:
  234. grid_feats = self.shared_head(grid_feats)
  235. self.grid_head.test_mode = True
  236. grid_preds = self.grid_head(grid_feats)
  237. results_list = self.grid_head.predict_by_feat(
  238. grid_preds=grid_preds,
  239. results_list=results_list,
  240. batch_img_metas=batch_img_metas,
  241. rescale=rescale)
  242. return results_list