det_tta.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmcv.ops import batched_nms
from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox_flip
  11. @MODELS.register_module()
  12. class DetTTAModel(BaseTTAModel):
  13. """Merge augmented detection results, only bboxes corresponding score under
  14. flipping and multi-scale resizing can be processed now.
  15. Examples:
  16. >>> tta_model = dict(
  17. >>> type='DetTTAModel',
  18. >>> tta_cfg=dict(nms=dict(
  19. >>> type='nms',
  20. >>> iou_threshold=0.5),
  21. >>> max_per_img=100))
  22. >>>
  23. >>> tta_pipeline = [
  24. >>> dict(type='LoadImageFromFile',
  25. >>> backend_args=None),
  26. >>> dict(
  27. >>> type='TestTimeAug',
  28. >>> transforms=[[
  29. >>> dict(type='Resize',
  30. >>> scale=(1333, 800),
  31. >>> keep_ratio=True),
  32. >>> ], [
  33. >>> dict(type='RandomFlip', prob=1.),
  34. >>> dict(type='RandomFlip', prob=0.)
  35. >>> ], [
  36. >>> dict(
  37. >>> type='PackDetInputs',
  38. >>> meta_keys=('img_id', 'img_path', 'ori_shape',
  39. >>> 'img_shape', 'scale_factor', 'flip',
  40. >>> 'flip_direction'))
  41. >>> ]])]
  42. """
  43. def __init__(self, tta_cfg=None, **kwargs):
  44. super().__init__(**kwargs)
  45. self.tta_cfg = tta_cfg
  46. def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
  47. aug_scores: List[Tensor],
  48. img_metas: List[str]) -> Tuple[Tensor, Tensor]:
  49. """Merge augmented detection bboxes and scores.
  50. Args:
  51. aug_bboxes (list[Tensor]): shape (n, 4*#class)
  52. aug_scores (list[Tensor] or None): shape (n, #class)
  53. Returns:
  54. tuple[Tensor]: ``bboxes`` with shape (n,4), where
  55. 4 represent (tl_x, tl_y, br_x, br_y)
  56. and ``scores`` with shape (n,).
  57. """
  58. recovered_bboxes = []
  59. for bboxes, img_info in zip(aug_bboxes, img_metas):
  60. ori_shape = img_info['ori_shape']
  61. flip = img_info['flip']
  62. flip_direction = img_info['flip_direction']
  63. if flip:
  64. bboxes = bbox_flip(
  65. bboxes=bboxes,
  66. img_shape=ori_shape,
  67. direction=flip_direction)
  68. recovered_bboxes.append(bboxes)
  69. bboxes = torch.cat(recovered_bboxes, dim=0)
  70. if aug_scores is None:
  71. return bboxes
  72. else:
  73. scores = torch.cat(aug_scores, dim=0)
  74. return bboxes, scores
  75. def merge_preds(self, data_samples_list: List[List[DetDataSample]]):
  76. """Merge batch predictions of enhanced data.
  77. Args:
  78. data_samples_list (List[List[DetDataSample]]): List of predictions
  79. of all enhanced data. The outer list indicates images, and the
  80. inner list corresponds to the different views of one image.
  81. Each element of the inner list is a ``DetDataSample``.
  82. Returns:
  83. List[DetDataSample]: Merged batch prediction.
  84. """
  85. merged_data_samples = []
  86. for data_samples in data_samples_list:
  87. merged_data_samples.append(self._merge_single_sample(data_samples))
  88. return merged_data_samples
  89. def _merge_single_sample(
  90. self, data_samples: List[DetDataSample]) -> DetDataSample:
  91. """Merge predictions which come form the different views of one image
  92. to one prediction.
  93. Args:
  94. data_samples (List[DetDataSample]): List of predictions
  95. of enhanced data which come form one image.
  96. Returns:
  97. List[DetDataSample]: Merged prediction.
  98. """
  99. aug_bboxes = []
  100. aug_scores = []
  101. aug_labels = []
  102. img_metas = []
  103. # TODO: support instance segmentation TTA
  104. assert data_samples[0].pred_instances.get('masks', None) is None, \
  105. 'TTA of instance segmentation does not support now.'
  106. for data_sample in data_samples:
  107. aug_bboxes.append(data_sample.pred_instances.bboxes)
  108. aug_scores.append(data_sample.pred_instances.scores)
  109. aug_labels.append(data_sample.pred_instances.labels)
  110. img_metas.append(data_sample.metainfo)
  111. merged_bboxes, merged_scores = self.merge_aug_bboxes(
  112. aug_bboxes, aug_scores, img_metas)
  113. merged_labels = torch.cat(aug_labels, dim=0)
  114. if merged_bboxes.numel() == 0:
  115. return data_samples[0]
  116. det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
  117. merged_labels, self.tta_cfg.nms)
  118. det_bboxes = det_bboxes[:self.tta_cfg.max_per_img]
  119. det_labels = merged_labels[keep_idxs][:self.tta_cfg.max_per_img]
  120. results = InstanceData()
  121. _det_bboxes = det_bboxes.clone()
  122. results.bboxes = _det_bboxes[:, :-1]
  123. results.scores = _det_bboxes[:, -1]
  124. results.labels = det_labels
  125. det_results = data_samples[0]
  126. det_results.pred_instances = results
  127. return det_results