test_mixins.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # TODO: delete this file after refactor
  3. import sys
  4. import torch
  5. from mmdet.models.layers import multiclass_nms
  6. from mmdet.models.test_time_augs import merge_aug_bboxes, merge_aug_masks
  7. from mmdet.structures.bbox import bbox2roi, bbox_mapping
  8. if sys.version_info >= (3, 7):
  9. from mmdet.utils.contextmanagers import completed
  10. class BBoxTestMixin:
  11. if sys.version_info >= (3, 7):
  12. # TODO: Currently not supported
  13. async def async_test_bboxes(self,
  14. x,
  15. img_metas,
  16. proposals,
  17. rcnn_test_cfg,
  18. rescale=False,
  19. **kwargs):
  20. """Asynchronized test for box head without augmentation."""
  21. rois = bbox2roi(proposals)
  22. roi_feats = self.bbox_roi_extractor(
  23. x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
  24. if self.with_shared_head:
  25. roi_feats = self.shared_head(roi_feats)
  26. sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)
  27. async with completed(
  28. __name__, 'bbox_head_forward',
  29. sleep_interval=sleep_interval):
  30. cls_score, bbox_pred = self.bbox_head(roi_feats)
  31. img_shape = img_metas[0]['img_shape']
  32. scale_factor = img_metas[0]['scale_factor']
  33. det_bboxes, det_labels = self.bbox_head.get_bboxes(
  34. rois,
  35. cls_score,
  36. bbox_pred,
  37. img_shape,
  38. scale_factor,
  39. rescale=rescale,
  40. cfg=rcnn_test_cfg)
  41. return det_bboxes, det_labels
  42. # TODO: Currently not supported
  43. def aug_test_bboxes(self, feats, img_metas, rpn_results_list,
  44. rcnn_test_cfg):
  45. """Test det bboxes with test time augmentation."""
  46. aug_bboxes = []
  47. aug_scores = []
  48. for x, img_meta in zip(feats, img_metas):
  49. # only one image in the batch
  50. img_shape = img_meta[0]['img_shape']
  51. scale_factor = img_meta[0]['scale_factor']
  52. flip = img_meta[0]['flip']
  53. flip_direction = img_meta[0]['flip_direction']
  54. # TODO more flexible
  55. proposals = bbox_mapping(rpn_results_list[0][:, :4], img_shape,
  56. scale_factor, flip, flip_direction)
  57. rois = bbox2roi([proposals])
  58. bbox_results = self.bbox_forward(x, rois)
  59. bboxes, scores = self.bbox_head.get_bboxes(
  60. rois,
  61. bbox_results['cls_score'],
  62. bbox_results['bbox_pred'],
  63. img_shape,
  64. scale_factor,
  65. rescale=False,
  66. cfg=None)
  67. aug_bboxes.append(bboxes)
  68. aug_scores.append(scores)
  69. # after merging, bboxes will be rescaled to the original image size
  70. merged_bboxes, merged_scores = merge_aug_bboxes(
  71. aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
  72. if merged_bboxes.shape[0] == 0:
  73. # There is no proposal in the single image
  74. det_bboxes = merged_bboxes.new_zeros(0, 5)
  75. det_labels = merged_bboxes.new_zeros((0, ), dtype=torch.long)
  76. else:
  77. det_bboxes, det_labels = multiclass_nms(merged_bboxes,
  78. merged_scores,
  79. rcnn_test_cfg.score_thr,
  80. rcnn_test_cfg.nms,
  81. rcnn_test_cfg.max_per_img)
  82. return det_bboxes, det_labels
  83. class MaskTestMixin:
  84. if sys.version_info >= (3, 7):
  85. # TODO: Currently not supported
  86. async def async_test_mask(self,
  87. x,
  88. img_metas,
  89. det_bboxes,
  90. det_labels,
  91. rescale=False,
  92. mask_test_cfg=None):
  93. """Asynchronized test for mask head without augmentation."""
  94. # image shape of the first image in the batch (only one)
  95. ori_shape = img_metas[0]['ori_shape']
  96. scale_factor = img_metas[0]['scale_factor']
  97. if det_bboxes.shape[0] == 0:
  98. segm_result = [[] for _ in range(self.mask_head.num_classes)]
  99. else:
  100. if rescale and not isinstance(scale_factor,
  101. (float, torch.Tensor)):
  102. scale_factor = det_bboxes.new_tensor(scale_factor)
  103. _bboxes = (
  104. det_bboxes[:, :4] *
  105. scale_factor if rescale else det_bboxes)
  106. mask_rois = bbox2roi([_bboxes])
  107. mask_feats = self.mask_roi_extractor(
  108. x[:len(self.mask_roi_extractor.featmap_strides)],
  109. mask_rois)
  110. if self.with_shared_head:
  111. mask_feats = self.shared_head(mask_feats)
  112. if mask_test_cfg and \
  113. mask_test_cfg.get('async_sleep_interval'):
  114. sleep_interval = mask_test_cfg['async_sleep_interval']
  115. else:
  116. sleep_interval = 0.035
  117. async with completed(
  118. __name__,
  119. 'mask_head_forward',
  120. sleep_interval=sleep_interval):
  121. mask_pred = self.mask_head(mask_feats)
  122. segm_result = self.mask_head.get_results(
  123. mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape,
  124. scale_factor, rescale)
  125. return segm_result
  126. # TODO: Currently not supported
  127. def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
  128. """Test for mask head with test time augmentation."""
  129. if det_bboxes.shape[0] == 0:
  130. segm_result = [[] for _ in range(self.mask_head.num_classes)]
  131. else:
  132. aug_masks = []
  133. for x, img_meta in zip(feats, img_metas):
  134. img_shape = img_meta[0]['img_shape']
  135. scale_factor = img_meta[0]['scale_factor']
  136. flip = img_meta[0]['flip']
  137. flip_direction = img_meta[0]['flip_direction']
  138. _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
  139. scale_factor, flip, flip_direction)
  140. mask_rois = bbox2roi([_bboxes])
  141. mask_results = self._mask_forward(x, mask_rois)
  142. # convert to numpy array to save memory
  143. aug_masks.append(
  144. mask_results['mask_pred'].sigmoid().cpu().numpy())
  145. merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
  146. ori_shape = img_metas[0][0]['ori_shape']
  147. scale_factor = det_bboxes.new_ones(4)
  148. segm_result = self.mask_head.get_results(
  149. merged_masks,
  150. det_bboxes,
  151. det_labels,
  152. self.test_cfg,
  153. ori_shape,
  154. scale_factor=scale_factor,
  155. rescale=False)
  156. return segm_result