dense_test_mixins.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import sys
  3. import warnings
  4. from inspect import signature
  5. import torch
  6. from mmcv.ops import batched_nms
  7. from mmengine.structures import InstanceData
  8. from mmdet.structures.bbox import bbox_mapping_back
  9. from ..test_time_augs import merge_aug_proposals
  10. if sys.version_info >= (3, 7):
  11. from mmdet.utils.contextmanagers import completed
  12. class BBoxTestMixin(object):
  13. """Mixin class for testing det bboxes via DenseHead."""
  14. def simple_test_bboxes(self, feats, img_metas, rescale=False):
  15. """Test det bboxes without test-time augmentation, can be applied in
  16. DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
  17. etc.
  18. Args:
  19. feats (tuple[torch.Tensor]): Multi-level features from the
  20. upstream network, each is a 4D-tensor.
  21. img_metas (list[dict]): List of image information.
  22. rescale (bool, optional): Whether to rescale the results.
  23. Defaults to False.
  24. Returns:
  25. list[obj:`InstanceData`]: Detection results of each
  26. image after the post process. \
  27. Each item usually contains following keys. \
  28. - scores (Tensor): Classification scores, has a shape
  29. (num_instance,)
  30. - labels (Tensor): Labels of bboxes, has a shape
  31. (num_instances,).
  32. - bboxes (Tensor): Has a shape (num_instances, 4),
  33. the last dimension 4 arrange as (x1, y1, x2, y2).
  34. """
  35. warnings.warn('You are calling `simple_test_bboxes` in '
  36. '`dense_test_mixins`, but the `dense_test_mixins`'
  37. 'will be deprecated soon. Please use '
  38. '`simple_test` instead.')
  39. outs = self.forward(feats)
  40. results_list = self.get_results(
  41. *outs, img_metas=img_metas, rescale=rescale)
  42. return results_list
  43. def aug_test_bboxes(self, feats, img_metas, rescale=False):
  44. """Test det bboxes with test time augmentation, can be applied in
  45. DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
  46. etc.
  47. Args:
  48. feats (list[Tensor]): the outer list indicates test-time
  49. augmentations and inner Tensor should have a shape NxCxHxW,
  50. which contains features for all images in the batch.
  51. img_metas (list[list[dict]]): the outer list indicates test-time
  52. augs (multiscale, flip, etc.) and the inner list indicates
  53. images in a batch. each dict has image information.
  54. rescale (bool, optional): Whether to rescale the results.
  55. Defaults to False.
  56. Returns:
  57. list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
  58. The first item is ``bboxes`` with shape (n, 5),
  59. where 5 represent (tl_x, tl_y, br_x, br_y, score).
  60. The shape of the second tensor in the tuple is ``labels``
  61. with shape (n,). The length of list should always be 1.
  62. """
  63. warnings.warn('You are calling `aug_test_bboxes` in '
  64. '`dense_test_mixins`, but the `dense_test_mixins`'
  65. 'will be deprecated soon. Please use '
  66. '`aug_test` instead.')
  67. # check with_nms argument
  68. gb_sig = signature(self.get_results)
  69. gb_args = [p.name for p in gb_sig.parameters.values()]
  70. gbs_sig = signature(self._get_results_single)
  71. gbs_args = [p.name for p in gbs_sig.parameters.values()]
  72. assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
  73. f'{self.__class__.__name__}' \
  74. ' does not support test-time augmentation'
  75. aug_bboxes = []
  76. aug_scores = []
  77. aug_labels = []
  78. for x, img_meta in zip(feats, img_metas):
  79. # only one image in the batch
  80. outs = self.forward(x)
  81. bbox_outputs = self.get_results(
  82. *outs,
  83. img_metas=img_meta,
  84. cfg=self.test_cfg,
  85. rescale=False,
  86. with_nms=False)[0]
  87. aug_bboxes.append(bbox_outputs.bboxes)
  88. aug_scores.append(bbox_outputs.scores)
  89. if len(bbox_outputs) >= 3:
  90. aug_labels.append(bbox_outputs.labels)
  91. # after merging, bboxes will be rescaled to the original image size
  92. merged_bboxes, merged_scores = self.merge_aug_bboxes(
  93. aug_bboxes, aug_scores, img_metas)
  94. merged_labels = torch.cat(aug_labels, dim=0) if aug_labels else None
  95. if merged_bboxes.numel() == 0:
  96. det_bboxes = torch.cat([merged_bboxes, merged_scores[:, None]], -1)
  97. return [
  98. (det_bboxes, merged_labels),
  99. ]
  100. det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
  101. merged_labels, self.test_cfg.nms)
  102. det_bboxes = det_bboxes[:self.test_cfg.max_per_img]
  103. det_labels = merged_labels[keep_idxs][:self.test_cfg.max_per_img]
  104. if rescale:
  105. _det_bboxes = det_bboxes
  106. else:
  107. _det_bboxes = det_bboxes.clone()
  108. _det_bboxes[:, :4] *= det_bboxes.new_tensor(
  109. img_metas[0][0]['scale_factor'])
  110. results = InstanceData()
  111. results.bboxes = _det_bboxes[:, :4]
  112. results.scores = _det_bboxes[:, 4]
  113. results.labels = det_labels
  114. return [results]
  115. def aug_test_rpn(self, feats, img_metas):
  116. """Test with augmentation for only for ``RPNHead`` and its variants,
  117. e.g., ``GARPNHead``, etc.
  118. Args:
  119. feats (tuple[Tensor]): Features from the upstream network, each is
  120. a 4D-tensor.
  121. img_metas (list[dict]): Meta info of each image.
  122. Returns:
  123. list[Tensor]: Proposals of each image, each item has shape (n, 5),
  124. where 5 represent (tl_x, tl_y, br_x, br_y, score).
  125. """
  126. samples_per_gpu = len(img_metas[0])
  127. aug_proposals = [[] for _ in range(samples_per_gpu)]
  128. for x, img_meta in zip(feats, img_metas):
  129. results_list = self.simple_test_rpn(x, img_meta)
  130. for i, results in enumerate(results_list):
  131. proposals = torch.cat(
  132. [results.bboxes, results.scores[:, None]], dim=-1)
  133. aug_proposals[i].append(proposals)
  134. # reorganize the order of 'img_metas' to match the dimensions
  135. # of 'aug_proposals'
  136. aug_img_metas = []
  137. for i in range(samples_per_gpu):
  138. aug_img_meta = []
  139. for j in range(len(img_metas)):
  140. aug_img_meta.append(img_metas[j][i])
  141. aug_img_metas.append(aug_img_meta)
  142. # after merging, proposals will be rescaled to the original image size
  143. merged_proposals = []
  144. for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas):
  145. merged_proposal = merge_aug_proposals(proposals, aug_img_meta,
  146. self.test_cfg)
  147. results = InstanceData()
  148. results.bboxes = merged_proposal[:, :4]
  149. results.scores = merged_proposal[:, 4]
  150. merged_proposals.append(results)
  151. return merged_proposals
  152. if sys.version_info >= (3, 7):
  153. async def async_simple_test_rpn(self, x, img_metas):
  154. sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025)
  155. async with completed(
  156. __name__, 'rpn_head_forward',
  157. sleep_interval=sleep_interval):
  158. rpn_outs = self(x)
  159. proposal_list = self.get_results(*rpn_outs, img_metas=img_metas)
  160. return proposal_list
  161. def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
  162. """Merge augmented detection bboxes and scores.
  163. Args:
  164. aug_bboxes (list[Tensor]): shape (n, 4*#class)
  165. aug_scores (list[Tensor] or None): shape (n, #class)
  166. img_shapes (list[Tensor]): shape (3, ).
  167. Returns:
  168. tuple[Tensor]: ``bboxes`` with shape (n,4), where
  169. 4 represent (tl_x, tl_y, br_x, br_y)
  170. and ``scores`` with shape (n,).
  171. """
  172. recovered_bboxes = []
  173. for bboxes, img_info in zip(aug_bboxes, img_metas):
  174. img_shape = img_info[0]['img_shape']
  175. scale_factor = img_info[0]['scale_factor']
  176. flip = img_info[0]['flip']
  177. flip_direction = img_info[0]['flip_direction']
  178. bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
  179. flip_direction)
  180. recovered_bboxes.append(bboxes)
  181. bboxes = torch.cat(recovered_bboxes, dim=0)
  182. if aug_scores is None:
  183. return bboxes
  184. else:
  185. scores = torch.cat(aug_scores, dim=0)
  186. return bboxes, scores