123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- # Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Union
- from mmengine.config import ConfigDict
- from mmengine.structures import InstanceData
- from torch import Tensor
- from mmdet.models.layers import multiclass_nms
- from mmdet.models.roi_heads.bbox_heads import Shared2FCBBoxHead
- from mmdet.models.utils import empty_instances
- from mmdet.registry import MODELS
- from mmdet.structures.bbox import get_box_tensor, scale_boxes
# ``force=True`` lets this class replace any class already registered in
# MODELS under the same name instead of raising a duplicate-key error.
@MODELS.register_module(force=True)
class DeticBBoxHead(Shared2FCBBoxHead):
    """Shared-2FC bbox head with a classification layer rebuilt for Detic.

    After the parent constructor runs, ``fc_cls`` is rebuilt from
    ``cls_predictor_cfg`` with ``in_features=self.cls_last_dim`` and
    ``out_features=self.num_classes``. At inference time the incoming
    ``cls_score`` is used directly as the per-class scores; no activation
    is applied by this head (presumably the caller has already activated
    them -- confirm against the RoI head that calls this method).
    """

    def __init__(self,
                 *args,
                 init_cfg: Optional[Union[dict, ConfigDict]] = None,
                 **kwargs) -> None:
        super().__init__(*args, init_cfg=init_cfg, **kwargs)
        # Rebuild only fc_cls (fc_reg is kept as built by the parent) so the
        # classifier comes from ``cls_predictor_cfg`` with ``num_classes``
        # output features.
        assert self.with_cls
        cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
        cls_predictor_cfg_.update(
            in_features=self.cls_last_dim, out_features=self.num_classes)
        self.fc_cls = MODELS.build(cls_predictor_cfg_)

    def _predict_by_feat_single(
            self,
            roi: Tensor,
            cls_score: Tensor,
            bbox_pred: Optional[Tensor],
            img_meta: dict,
            rescale: bool = False,
            rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
        """Transform a single image's head outputs into bbox results.

        Args:
            roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5);
                the last dimension is arranged as
                (batch_index, x1, y1, x2, y2).
            cls_score (Tensor): Box scores, has shape
                (num_boxes, num_classes + 1). Used as-is; no activation is
                applied here.
            bbox_pred (Tensor, optional): Box energies / deltas, has shape
                (num_boxes, num_classes * encode_size), or
                (num_boxes, encode_size) when ``reg_class_agnostic`` is
                True. If None (head without a regression branch), the
                proposal boxes themselves are used, clipped to the image.
            img_meta (dict): Image information. ``img_shape`` is read, and
                ``scale_factor`` is required when ``rescale`` is True.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            rcnn_test_cfg (:obj:`ConfigDict`, optional): ``test_cfg`` of the
                bbox head. If None, raw (pre-NMS) boxes and scores are
                returned. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of the image, with keys:

            - scores (Tensor): Classification scores. Shape
              (num_instances, ) after NMS, or the raw ``cls_score`` when
              ``rcnn_test_cfg`` is None.
            - labels (Tensor): Labels of bboxes, shape (num_instances, ).
              Only set when ``rcnn_test_cfg`` is given.
            - bboxes (Tensor): Shape (num_instances, 4) after NMS, with the
              last dimension arranged as (x1, y1, x2, y2); or
              (num_boxes, num_reg_classes * box_dim) when ``rcnn_test_cfg``
              is None.
        """
        results = InstanceData()
        if roi.shape[0] == 0:
            return empty_instances([img_meta],
                                   roi.device,
                                   task_type='bbox',
                                   instance_results=[results],
                                   box_type=self.predict_box_type,
                                   use_box_type=False,
                                   num_classes=self.num_classes,
                                   score_per_cls=rcnn_test_cfg is None)[0]

        # Scores are taken verbatim from the caller; see class docstring.
        scores = cls_score
        img_shape = img_meta['img_shape']
        num_rois = roi.size(0)

        if bbox_pred is not None:
            # One roi copy per regressed class so every delta row has a
            # matching reference box.
            num_classes = 1 if self.reg_class_agnostic else self.num_classes
            roi = roi.repeat_interleave(num_classes, dim=0)
            bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size)
            bboxes = self.bbox_coder.decode(
                roi[..., 1:], bbox_pred, max_shape=img_shape)
        else:
            # No regression branch: fall back to the proposals, clipped to
            # the image (assumes img_shape is (h, w) -- TODO confirm).
            # Step slices are views, so clamp_ really modifies ``bboxes``
            # (an index list like [:, [0, 2]] would clamp a throwaway copy).
            bboxes = roi[:, 1:].clone()
            if img_shape is not None and bboxes.size(-1) == 4:
                bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
                bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])

        if rescale and bboxes.size(0) > 0:
            assert img_meta.get('scale_factor') is not None
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            bboxes = scale_boxes(bboxes, scale_factor)

        # Get the inside tensor when `bboxes` is a box type
        bboxes = get_box_tensor(bboxes)
        box_dim = bboxes.size(-1)
        bboxes = bboxes.view(num_rois, -1)

        if rcnn_test_cfg is None:
            # Aug test: return the raw results without NMS.
            results.bboxes = bboxes
            results.scores = scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                rcnn_test_cfg.score_thr,
                rcnn_test_cfg.nms,
                rcnn_test_cfg.max_per_img,
                box_dim=box_dim)
            # det_bboxes packs the score as its last column.
            results.bboxes = det_bboxes[:, :-1]
            results.scores = det_bboxes[:, -1]
            results.labels = det_labels
        return results
|