# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmcv.ops import batched_nms
from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox_flip


@MODELS.register_module()
class DetTTAModel(BaseTTAModel):
    """Merge augmented detection results, only bboxes corresponding score under
    flipping and multi-scale resizing can be processed now.

    Examples:
        >>> tta_model = dict(
        >>>     type='DetTTAModel',
        >>>     tta_cfg=dict(nms=dict(
        >>>                     type='nms',
        >>>                     iou_threshold=0.5),
        >>>                     max_per_img=100))
        >>>
        >>> tta_pipeline = [
        >>>     dict(type='LoadImageFromFile',
        >>>          backend_args=None),
        >>>     dict(
        >>>         type='TestTimeAug',
        >>>         transforms=[[
        >>>             dict(type='Resize',
        >>>                  scale=(1333, 800),
        >>>                  keep_ratio=True),
        >>>         ], [
        >>>             dict(type='RandomFlip', prob=1.),
        >>>             dict(type='RandomFlip', prob=0.)
        >>>         ], [
        >>>             dict(
        >>>                type='PackDetInputs',
        >>>                meta_keys=('img_id', 'img_path', 'ori_shape',
        >>>                        'img_shape', 'scale_factor', 'flip',
        >>>                        'flip_direction'))
        >>>         ]])]
    """

    def __init__(self, tta_cfg=None, **kwargs):
        """Build the TTA wrapper.

        Args:
            tta_cfg (dict, optional): Merge settings. ``_merge_single_sample``
                reads the keys ``nms`` (forwarded to ``batched_nms``) and
                ``max_per_img``. It may be a plain ``dict`` or a
                ``ConfigDict``. Defaults to None.
            **kwargs: Forwarded unchanged to ``BaseTTAModel``.
        """
        super().__init__(**kwargs)
        self.tta_cfg = tta_cfg

    def merge_aug_bboxes(
            self, aug_bboxes: List[Tensor],
            aug_scores: Optional[List[Tensor]],
            img_metas: List[dict]) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Merge augmented detection bboxes and scores.

        Boxes predicted on flipped views are flipped back so that every view
        is expressed in the same (un-flipped) image frame. All views are then
        concatenated along the first dimension.

        Args:
            aug_bboxes (list[Tensor]): Boxes of each augmented view, each with
                shape (n_i, 4) in (tl_x, tl_y, br_x, br_y) format.
            aug_scores (list[Tensor] or None): Scores of each augmented view,
                each with shape (n_i,). If None, only boxes are returned.
            img_metas (list[dict]): Meta information of each view. Every dict
                must provide ``'ori_shape'``, ``'flip'`` and
                ``'flip_direction'``.

        Returns:
            Tensor | tuple[Tensor, Tensor]: If ``aug_scores`` is None, the
            merged ``bboxes`` with shape (n, 4). Otherwise a tuple of
            ``bboxes`` with shape (n, 4) and ``scores`` with shape (n,),
            where n is the total number of boxes over all views.
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            ori_shape = img_info['ori_shape']
            flip = img_info['flip']
            flip_direction = img_info['flip_direction']
            if flip:
                # Undo the test-time flip so boxes from all views are
                # comparable before NMS.
                bboxes = bbox_flip(
                    bboxes=bboxes,
                    img_shape=ori_shape,
                    direction=flip_direction)
            recovered_bboxes.append(bboxes)
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores

    def merge_preds(self, data_samples_list: List[List[DetDataSample]]):
        """Merge batch predictions of augmented data.

        Args:
            data_samples_list (List[List[DetDataSample]]): List of predictions
                of all augmented data. The outer list indicates images, and
                the inner list corresponds to the different views of one
                image. Each element of the inner list is a ``DetDataSample``.

        Returns:
            List[DetDataSample]: Merged batch prediction, one per image.
        """
        merged_data_samples = []
        for data_samples in data_samples_list:
            merged_data_samples.append(self._merge_single_sample(data_samples))
        return merged_data_samples

    def _merge_single_sample(
            self, data_samples: List[DetDataSample]) -> DetDataSample:
        """Merge predictions which come from the different views of one image
        into one prediction.

        Args:
            data_samples (List[DetDataSample]): List of predictions of the
                augmented views of one image.

        Returns:
            DetDataSample: Merged prediction. This is ``data_samples[0]``
            itself, updated in place. If no view produced any box, it is
            returned unchanged.
        """
        aug_bboxes = []
        aug_scores = []
        aug_labels = []
        img_metas = []
        # TODO: support instance segmentation TTA
        assert data_samples[0].pred_instances.get('masks', None) is None, \
            'TTA of instance segmentation does not support now.'
        for data_sample in data_samples:
            aug_bboxes.append(data_sample.pred_instances.bboxes)
            aug_scores.append(data_sample.pred_instances.scores)
            aug_labels.append(data_sample.pred_instances.labels)
            img_metas.append(data_sample.metainfo)

        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_labels = torch.cat(aug_labels, dim=0)

        # Nothing detected in any view: skip NMS entirely.
        if merged_bboxes.numel() == 0:
            return data_samples[0]

        # Subscript access works for both plain dicts and ConfigDicts.
        max_per_img = self.tta_cfg['max_per_img']
        # Per-class NMS over the pooled boxes of all views. The returned
        # ``det_bboxes`` has the score appended as its last column, and
        # ``keep_idxs`` indexes into the pooled tensors.
        det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
                                            merged_labels,
                                            self.tta_cfg['nms'])
        det_bboxes = det_bboxes[:max_per_img]
        det_labels = merged_labels[keep_idxs][:max_per_img]

        results = InstanceData()
        # Clone so the stored fields do not alias the NMS output buffer.
        _det_bboxes = det_bboxes.clone()
        results.bboxes = _det_bboxes[:, :-1]
        results.scores = _det_bboxes[:, -1]
        results.labels = det_labels
        det_results = data_samples[0]
        det_results.pred_instances = results
        return det_results