123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- import warnings
- from typing import List, Optional, Union
- import numpy as np
- import torch
- from mmcv.ops import nms
- from mmengine.config import ConfigDict
- from torch import Tensor
- from mmdet.structures.bbox import bbox_mapping_back
- # TODO remove this, never be used in mmdet
- def merge_aug_proposals(aug_proposals, img_metas, cfg):
- """Merge augmented proposals (multiscale, flip, etc.)
- Args:
- aug_proposals (list[Tensor]): proposals from different testing
- schemes, shape (n, 5). Note that they are not rescaled to the
- original image size.
- img_metas (list[dict]): list of image info dict where each dict has:
- 'img_shape', 'scale_factor', 'flip', and may also contain
- 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
- For details on the values of these keys see
- `mmdet/datasets/pipelines/formatting.py:Collect`.
- cfg (dict): rpn test config.
- Returns:
- Tensor: shape (n, 4), proposals corresponding to original image scale.
- """
- cfg = copy.deepcopy(cfg)
- # deprecate arguments warning
- if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
- warnings.warn(
- 'In rpn_proposal or test_cfg, '
- 'nms_thr has been moved to a dict named nms as '
- 'iou_threshold, max_num has been renamed as max_per_img, '
- 'name of original arguments and the way to specify '
- 'iou_threshold of NMS will be deprecated.')
- if 'nms' not in cfg:
- cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
- if 'max_num' in cfg:
- if 'max_per_img' in cfg:
- assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \
- f'max_per_img at the same time, but get {cfg.max_num} ' \
- f'and {cfg.max_per_img} respectively' \
- f'Please delete max_num which will be deprecated.'
- else:
- cfg.max_per_img = cfg.max_num
- if 'nms_thr' in cfg:
- assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
- f'iou_threshold in nms and ' \
- f'nms_thr at the same time, but get ' \
- f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
- f' respectively. Please delete the nms_thr ' \
- f'which will be deprecated.'
- recovered_proposals = []
- for proposals, img_info in zip(aug_proposals, img_metas):
- img_shape = img_info['img_shape']
- scale_factor = img_info['scale_factor']
- flip = img_info['flip']
- flip_direction = img_info['flip_direction']
- _proposals = proposals.clone()
- _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
- scale_factor, flip,
- flip_direction)
- recovered_proposals.append(_proposals)
- aug_proposals = torch.cat(recovered_proposals, dim=0)
- merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
- aug_proposals[:, -1].contiguous(),
- cfg.nms.iou_threshold)
- scores = merged_proposals[:, 4]
- _, order = scores.sort(0, descending=True)
- num = min(cfg.max_per_img, merged_proposals.shape[0])
- order = order[:num]
- merged_proposals = merged_proposals[order, :]
- return merged_proposals
- # TODO remove this, never be used in mmdet
- def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
- """Merge augmented detection bboxes and scores.
- Args:
- aug_bboxes (list[Tensor]): shape (n, 4*#class)
- aug_scores (list[Tensor] or None): shape (n, #class)
- img_shapes (list[Tensor]): shape (3, ).
- rcnn_test_cfg (dict): rcnn test config.
- Returns:
- tuple: (bboxes, scores)
- """
- recovered_bboxes = []
- for bboxes, img_info in zip(aug_bboxes, img_metas):
- img_shape = img_info[0]['img_shape']
- scale_factor = img_info[0]['scale_factor']
- flip = img_info[0]['flip']
- flip_direction = img_info[0]['flip_direction']
- bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
- flip_direction)
- recovered_bboxes.append(bboxes)
- bboxes = torch.stack(recovered_bboxes).mean(dim=0)
- if aug_scores is None:
- return bboxes
- else:
- scores = torch.stack(aug_scores).mean(dim=0)
- return bboxes, scores
- def merge_aug_results(aug_batch_results, aug_batch_img_metas):
- """Merge augmented detection results, only bboxes corresponding score under
- flipping and multi-scale resizing can be processed now.
- Args:
- aug_batch_results (list[list[[obj:`InstanceData`]]):
- Detection results of multiple images with
- different augmentations.
- The outer list indicate the augmentation . The inter
- list indicate the batch dimension.
- Each item usually contains the following keys.
- - scores (Tensor): Classification scores, in shape
- (num_instance,)
- - labels (Tensor): Labels of bboxes, in shape
- (num_instances,).
- - bboxes (Tensor): In shape (num_instances, 4),
- the last dimension 4 arrange as (x1, y1, x2, y2).
- aug_batch_img_metas (list[list[dict]]): The outer list
- indicates test-time augs (multiscale, flip, etc.)
- and the inner list indicates
- images in a batch. Each dict in the list contains
- information of an image in the batch.
- Returns:
- batch_results (list[obj:`InstanceData`]): Same with
- the input `aug_results` except that all bboxes have
- been mapped to the original scale.
- """
- num_augs = len(aug_batch_results)
- num_imgs = len(aug_batch_results[0])
- batch_results = []
- aug_batch_results = copy.deepcopy(aug_batch_results)
- for img_id in range(num_imgs):
- aug_results = []
- for aug_id in range(num_augs):
- img_metas = aug_batch_img_metas[aug_id][img_id]
- results = aug_batch_results[aug_id][img_id]
- img_shape = img_metas['img_shape']
- scale_factor = img_metas['scale_factor']
- flip = img_metas['flip']
- flip_direction = img_metas['flip_direction']
- bboxes = bbox_mapping_back(results.bboxes, img_shape, scale_factor,
- flip, flip_direction)
- results.bboxes = bboxes
- aug_results.append(results)
- merged_aug_results = results.cat(aug_results)
- batch_results.append(merged_aug_results)
- return batch_results
- def merge_aug_scores(aug_scores):
- """Merge augmented bbox scores."""
- if isinstance(aug_scores[0], torch.Tensor):
- return torch.mean(torch.stack(aug_scores), dim=0)
- else:
- return np.mean(aug_scores, axis=0)
- def merge_aug_masks(aug_masks: List[Tensor],
- img_metas: dict,
- weights: Optional[Union[list, Tensor]] = None) -> Tensor:
- """Merge augmented mask prediction.
- Args:
- aug_masks (list[Tensor]): each has shape
- (n, c, h, w).
- img_metas (dict): Image information.
- weights (list or Tensor): Weight of each aug_masks,
- the length should be n.
- Returns:
- Tensor: has shape (n, c, h, w)
- """
- recovered_masks = []
- for i, mask in enumerate(aug_masks):
- if weights is not None:
- assert len(weights) == len(aug_masks)
- weight = weights[i]
- else:
- weight = 1
- flip = img_metas.get('filp', False)
- if flip:
- flip_direction = img_metas['flip_direction']
- if flip_direction == 'horizontal':
- mask = mask[:, :, :, ::-1]
- elif flip_direction == 'vertical':
- mask = mask[:, :, ::-1, :]
- elif flip_direction == 'diagonal':
- mask = mask[:, :, :, ::-1]
- mask = mask[:, :, ::-1, :]
- else:
- raise ValueError(
- f"Invalid flipping direction '{flip_direction}'")
- recovered_masks.append(mask[None, :] * weight)
- merged_masks = torch.cat(recovered_masks, 0).mean(dim=0)
- if weights is not None:
- merged_masks = merged_masks * len(weights) / sum(weights)
- return merged_masks
|