123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import List, Tuple
- import torch
- from mmcv.ops import batched_nms
- from mmengine.model import BaseTTAModel
- from mmengine.registry import MODELS
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.structures import DetDataSample
- from mmdet.structures.bbox import bbox_flip
@MODELS.register_module()
class DetTTAModel(BaseTTAModel):
    """Merge detection results predicted on augmented views of one image.

    Only bboxes with their scores and labels, predicted under flipping and
    multi-scale resizing, can be merged for now; predictions that carry
    instance masks are rejected.

    Args:
        tta_cfg: Merge settings. It is read through attribute access:
            ``tta_cfg.nms`` is handed to ``batched_nms`` and
            ``tta_cfg.max_per_img`` caps the number of kept detections.
            Defaults to None, in which case merging non-empty predictions
            fails on the attribute lookup.
        **kwargs: Forwarded unchanged to ``BaseTTAModel.__init__``.

    Examples:
        >>> tta_model = dict(
        >>>     type='DetTTAModel',
        >>>     tta_cfg=dict(nms=dict(
        >>>                     type='nms',
        >>>                     iou_threshold=0.5),
        >>>                     max_per_img=100))
        >>>
        >>> tta_pipeline = [
        >>>     dict(type='LoadImageFromFile',
        >>>          backend_args=None),
        >>>     dict(
        >>>         type='TestTimeAug',
        >>>         transforms=[[
        >>>             dict(type='Resize',
        >>>                  scale=(1333, 800),
        >>>                  keep_ratio=True),
        >>>         ], [
        >>>             dict(type='RandomFlip', prob=1.),
        >>>             dict(type='RandomFlip', prob=0.)
        >>>         ], [
        >>>             dict(
        >>>                 type='PackDetInputs',
        >>>                 meta_keys=('img_id', 'img_path', 'ori_shape',
        >>>                            'img_shape', 'scale_factor', 'flip',
        >>>                            'flip_direction'))
        >>>         ]])]
    """

    def __init__(self, tta_cfg=None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.tta_cfg = tta_cfg

    def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
                         aug_scores: List[Tensor],
                         img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """Merge augmented detection bboxes and scores.

        Boxes predicted on a flipped view are flipped back (using the view's
        ``ori_shape`` and ``flip_direction``) before all views are
        concatenated along dim 0.

        Args:
            aug_bboxes (list[Tensor]): Boxes of each view, one tensor per
                view. Presumably (n, 4) in (tl_x, tl_y, br_x, br_y) form and
                already expressed in original-image coordinates -- TODO
                confirm against the detector's post-processing.
            aug_scores (list[Tensor] or None): Scores of each view, aligned
                with ``aug_bboxes``.
            img_metas (list[dict]): Meta information of each view, aligned
                with ``aug_bboxes``. A view whose meta has no ``'flip'``
                entry is treated as not flipped; ``'ori_shape'`` and
                ``'flip_direction'`` are only required for flipped views.

        Returns:
            tuple[Tensor]: The concatenated ``bboxes`` and ``scores``. When
            ``aug_scores`` is None only the ``bboxes`` tensor is returned
            (not wrapped in a tuple).
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            # Metas of views produced without a flip transform may not carry
            # the flip keys at all, so only touch them when un-flipping.
            if img_info.get('flip', False):
                bboxes = bbox_flip(
                    bboxes=bboxes,
                    img_shape=img_info['ori_shape'],
                    direction=img_info['flip_direction'])
            recovered_bboxes.append(bboxes)
        bboxes = torch.cat(recovered_bboxes, dim=0)

        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores

    def merge_preds(self, data_samples_list: List[List[DetDataSample]]):
        """Merge batch predictions of enhanced data.

        Args:
            data_samples_list (List[List[DetDataSample]]): List of predictions
                of all enhanced data. The outer list indicates images, and the
                inner list corresponds to the different views of one image.
                Each element of the inner list is a ``DetDataSample``.

        Returns:
            List[DetDataSample]: Merged batch prediction, one entry per image
            in the same order as ``data_samples_list``.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(
            self, data_samples: List[DetDataSample]) -> DetDataSample:
        """Merge predictions which come from the different views of one image
        to one prediction.

        Args:
            data_samples (List[DetDataSample]): List of predictions
                of enhanced data which come from one image.

        Returns:
            DetDataSample: Merged prediction. This is ``data_samples[0]``
            itself: it is returned untouched when no view predicted any box,
            otherwise its ``pred_instances`` is replaced by the merged result.
        """
        # TODO: support instance segmentation TTA
        assert data_samples[0].pred_instances.get('masks', None) is None, \
            'TTA of instance segmentation does not support now.'

        predictions = [sample.pred_instances for sample in data_samples]
        aug_bboxes = [pred.bboxes for pred in predictions]
        aug_scores = [pred.scores for pred in predictions]
        aug_labels = [pred.labels for pred in predictions]
        img_metas = [sample.metainfo for sample in data_samples]

        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_labels = torch.cat(aug_labels, dim=0)

        # Nothing was detected in any view: skip NMS entirely.
        if merged_bboxes.numel() == 0:
            return data_samples[0]

        det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
                                            merged_labels, self.tta_cfg.nms)

        # Keep at most ``max_per_img`` detections, with labels kept aligned
        # to the boxes that survived NMS.
        det_bboxes = det_bboxes[:self.tta_cfg.max_per_img]
        det_labels = merged_labels[keep_idxs][:self.tta_cfg.max_per_img]

        results = InstanceData()
        _det_bboxes = det_bboxes.clone()
        # The last column of the NMS output is the score; the leading columns
        # are the box coordinates.
        results.bboxes = _det_bboxes[:, :-1]
        results.scores = _det_bboxes[:, -1]
        results.labels = det_labels

        det_results = data_samples[0]
        det_results.pred_instances = results
        return det_results
|