merge_augs.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Optional, Union

import numpy as np
import torch
from mmcv.ops import nms
from mmengine.config import ConfigDict
from torch import Tensor

from mmdet.structures.bbox import bbox_mapping_back


# TODO: remove this, it is never used in mmdet
def merge_aug_proposals(aug_proposals, img_metas, cfg):
    """Merge augmented proposals (multiscale, flip, etc.).

    Args:
        aug_proposals (list[Tensor]): Proposals from different testing
            schemes, each of shape (n, 5). Note that they are not rescaled
            to the original image size.
        img_metas (list[dict]): List of image info dicts where each dict
            has: 'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            For details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.
        cfg (dict): RPN test config.

    Returns:
        Tensor: Shape (n, 5), merged proposals (x1, y1, x2, y2, score)
            corresponding to the original image scale.
    """
    cfg = copy.deepcopy(cfg)

    # warn about deprecated arguments
    if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
        warnings.warn(
            'In rpn_proposal or test_cfg, '
            'nms_thr has been moved to a dict named nms as '
            'iou_threshold, max_num has been renamed as max_per_img, '
            'name of original arguments and the way to specify '
            'iou_threshold of NMS will be deprecated.')
    if 'nms' not in cfg:
        cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
    if 'max_num' in cfg:
        if 'max_per_img' in cfg:
            assert cfg.max_num == cfg.max_per_img, \
                f'You set max_num and max_per_img at the same time, ' \
                f'but get {cfg.max_num} and {cfg.max_per_img} respectively. ' \
                'Please delete max_num, which will be deprecated.'
        else:
            cfg.max_per_img = cfg.max_num
    if 'nms_thr' in cfg:
        assert cfg.nms.iou_threshold == cfg.nms_thr, \
            'You set iou_threshold in nms and nms_thr at the same time, ' \
            f'but get {cfg.nms.iou_threshold} and {cfg.nms_thr} ' \
            'respectively. Please delete nms_thr, which will be deprecated.'

    # map each set of proposals back to the original image scale
    recovered_proposals = []
    for proposals, img_info in zip(aug_proposals, img_metas):
        img_shape = img_info['img_shape']
        scale_factor = img_info['scale_factor']
        flip = img_info['flip']
        flip_direction = img_info['flip_direction']
        _proposals = proposals.clone()
        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
                                              scale_factor, flip,
                                              flip_direction)
        recovered_proposals.append(_proposals)
    # merge all proposals with NMS and keep the top-scoring ones
    aug_proposals = torch.cat(recovered_proposals, dim=0)
    merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
                              aug_proposals[:, -1].contiguous(),
                              cfg.nms.iou_threshold)
    scores = merged_proposals[:, 4]
    _, order = scores.sort(0, descending=True)
    num = min(cfg.max_per_img, merged_proposals.shape[0])
    order = order[:num]
    merged_proposals = merged_proposals[order, :]
    return merged_proposals
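

# A minimal, hypothetical usage sketch for `merge_aug_proposals`: two
# test-time views of the same 100x100 image, one of them horizontally
# flipped. The proposal values and the `cfg` fields below are assumptions
# chosen for demonstration only.
def _demo_merge_aug_proposals():
    cfg = ConfigDict(
        dict(nms=dict(type='nms', iou_threshold=0.7), max_per_img=100))
    img_metas = [
        dict(img_shape=(100, 100), scale_factor=1.0, flip=False,
             flip_direction=None),
        dict(img_shape=(100, 100), scale_factor=1.0, flip=True,
             flip_direction='horizontal'),
    ]
    aug_proposals = [
        # (x1, y1, x2, y2, score) in each view's own coordinates
        torch.tensor([[10., 10., 50., 50., 0.9]]),
        torch.tensor([[50., 10., 90., 50., 0.8]]),  # same box, flipped view
    ]
    # The two boxes coincide after mapping back, so NMS keeps a single
    # proposal carrying the higher score.
    return merge_aug_proposals(aug_proposals, img_metas, cfg)
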

# TODO: remove this, it is never used in mmdet
def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
    """Merge augmented detection bboxes and scores.

    Args:
        aug_bboxes (list[Tensor]): Bboxes from different testing schemes,
            each of shape (n, 4*#class).
        aug_scores (list[Tensor] or None): Scores from different testing
            schemes, each of shape (n, #class).
        img_metas (list[list[dict]]): Image info, one inner list per
            augmentation.
        rcnn_test_cfg (dict): rcnn test config.

    Returns:
        tuple: (bboxes, scores)
    """
    recovered_bboxes = []
    for bboxes, img_info in zip(aug_bboxes, img_metas):
        img_shape = img_info[0]['img_shape']
        scale_factor = img_info[0]['scale_factor']
        flip = img_info[0]['flip']
        flip_direction = img_info[0]['flip_direction']
        bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
                                   flip_direction)
        recovered_bboxes.append(bboxes)
    # average boxes (and scores) over augmentations
    bboxes = torch.stack(recovered_bboxes).mean(dim=0)
    if aug_scores is None:
        return bboxes
    else:
        scores = torch.stack(aug_scores).mean(dim=0)
        return bboxes, scores
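

# A minimal, hypothetical usage sketch for `merge_aug_bboxes`: the same
# detection seen in an unflipped and a horizontally flipped view of a
# 100x100 image. The box/score values are assumptions; `rcnn_test_cfg` is
# not read inside the function, so `None` is passed here.
def _demo_merge_aug_bboxes():
    img_metas = [
        [dict(img_shape=(100, 100), scale_factor=1.0, flip=False,
              flip_direction=None)],
        [dict(img_shape=(100, 100), scale_factor=1.0, flip=True,
              flip_direction='horizontal')],
    ]
    aug_bboxes = [
        torch.tensor([[10., 10., 50., 50.]]),
        torch.tensor([[50., 10., 90., 50.]]),  # same box, flipped view
    ]
    aug_scores = [
        torch.tensor([[0.9, 0.1]]),
        torch.tensor([[0.7, 0.3]]),
    ]
    # returns the averaged box [10, 10, 50, 50] and scores [0.8, 0.2]
    return merge_aug_bboxes(aug_bboxes, aug_scores, img_metas,
                            rcnn_test_cfg=None)
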

def merge_aug_results(aug_batch_results, aug_batch_img_metas):
    """Merge augmented detection results. Only bboxes and their
    corresponding scores under flipping and multi-scale resizing can be
    processed now.

    Args:
        aug_batch_results (list[list[obj:`InstanceData`]]):
            Detection results of multiple images with different
            augmentations. The outer list indicates the augmentations and
            the inner list indicates the batch dimension. Each item
            usually contains the following keys.

            - scores (Tensor): Classification scores, in shape
              (num_instances,)
            - labels (Tensor): Labels of bboxes, in shape
              (num_instances,).
            - bboxes (Tensor): In shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        aug_batch_img_metas (list[list[dict]]): The outer list indicates
            test-time augs (multiscale, flip, etc.) and the inner list
            indicates images in a batch. Each dict in the list contains
            information of an image in the batch.

    Returns:
        batch_results (list[obj:`InstanceData`]): Same as the input
            `aug_batch_results` except that all bboxes have been mapped
            to the original image scale.
    """
    num_augs = len(aug_batch_results)
    num_imgs = len(aug_batch_results[0])

    batch_results = []
    aug_batch_results = copy.deepcopy(aug_batch_results)
    for img_id in range(num_imgs):
        aug_results = []
        for aug_id in range(num_augs):
            img_metas = aug_batch_img_metas[aug_id][img_id]
            results = aug_batch_results[aug_id][img_id]

            img_shape = img_metas['img_shape']
            scale_factor = img_metas['scale_factor']
            flip = img_metas['flip']
            flip_direction = img_metas['flip_direction']
            # map the predicted bboxes back to the original image scale
            bboxes = bbox_mapping_back(results.bboxes, img_shape,
                                       scale_factor, flip, flip_direction)
            results.bboxes = bboxes
            aug_results.append(results)
        merged_aug_results = results.cat(aug_results)
        batch_results.append(merged_aug_results)

    return batch_results
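

# A minimal, hypothetical usage sketch for `merge_aug_results`: a batch of
# one image predicted under two augmentations, the second one horizontally
# flipped. The `InstanceData` contents are assumptions chosen for
# demonstration only.
def _demo_merge_aug_results():
    from mmengine.structures import InstanceData

    res_plain = InstanceData(
        bboxes=torch.tensor([[10., 10., 50., 50.]]),
        scores=torch.tensor([0.9]),
        labels=torch.tensor([0]))
    res_flipped = InstanceData(
        bboxes=torch.tensor([[50., 10., 90., 50.]]),  # same box, flipped view
        scores=torch.tensor([0.8]),
        labels=torch.tensor([0]))
    aug_batch_results = [[res_plain], [res_flipped]]  # [aug][image]
    aug_batch_img_metas = [
        [dict(img_shape=(100, 100), scale_factor=1.0, flip=False,
              flip_direction=None)],
        [dict(img_shape=(100, 100), scale_factor=1.0, flip=True,
              flip_direction='horizontal')],
    ]
    # returns one InstanceData per image containing all augmented
    # predictions, with every bbox mapped back to the original scale
    return merge_aug_results(aug_batch_results, aug_batch_img_metas)
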

def merge_aug_scores(aug_scores):
    """Merge augmented bbox scores."""
    if isinstance(aug_scores[0], torch.Tensor):
        return torch.mean(torch.stack(aug_scores), dim=0)
    else:
        return np.mean(aug_scores, axis=0)
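

# A minimal, hypothetical usage sketch for `merge_aug_scores`: class scores
# for one instance under two augmentations are simply averaged. The values
# below are assumptions.
def _demo_merge_aug_scores():
    aug_scores = [torch.tensor([0.9, 0.1]), torch.tensor([0.7, 0.3])]
    return merge_aug_scores(aug_scores)  # tensor([0.8, 0.2])
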

def merge_aug_masks(aug_masks: List[Tensor],
                    img_metas: dict,
                    weights: Optional[Union[list, Tensor]] = None) -> Tensor:
    """Merge augmented mask prediction.

    Args:
        aug_masks (list[Tensor]): Each has shape (n, c, h, w).
        img_metas (dict): Image information.
        weights (list or Tensor): Weight of each aug_masks; the length
            should equal the number of aug_masks.

    Returns:
        Tensor: Has shape (n, c, h, w).
    """
    recovered_masks = []
    for i, mask in enumerate(aug_masks):
        if weights is not None:
            assert len(weights) == len(aug_masks)
            weight = weights[i]
        else:
            weight = 1
        flip = img_metas.get('flip', False)
        if flip:
            # undo the test-time flip before averaging
            flip_direction = img_metas['flip_direction']
            if flip_direction == 'horizontal':
                mask = mask.flip(-1)
            elif flip_direction == 'vertical':
                mask = mask.flip(-2)
            elif flip_direction == 'diagonal':
                mask = mask.flip(-1).flip(-2)
            else:
                raise ValueError(
                    f"Invalid flipping direction '{flip_direction}'")
        recovered_masks.append(mask[None, :] * weight)

    # weighted average over augmentations
    merged_masks = torch.cat(recovered_masks, 0).mean(dim=0)
    if weights is not None:
        merged_masks = merged_masks * len(weights) / sum(weights)
    return merged_masks
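

# A minimal, hypothetical usage sketch for `merge_aug_masks`: two mask
# predictions for the same unflipped view are averaged with weights 2:1.
# The tensor values and weights below are assumptions chosen for
# demonstration only.
def _demo_merge_aug_masks():
    mask_a = torch.rand(1, 1, 4, 4)  # (n, c, h, w) mask logits
    mask_b = torch.rand(1, 1, 4, 4)
    img_metas = dict(flip=False)
    # returns the weighted average (2 * mask_a + mask_b) / 3,
    # shape (1, 1, 4, 4)
    return merge_aug_masks([mask_a, mask_b], img_metas, weights=[2.0, 1.0])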