detic_bbox_head.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Optional, Union
  3. from mmengine.config import ConfigDict
  4. from mmengine.structures import InstanceData
  5. from torch import Tensor
  6. from mmdet.models.layers import multiclass_nms
  7. from mmdet.models.roi_heads.bbox_heads import Shared2FCBBoxHead
  8. from mmdet.models.utils import empty_instances
  9. from mmdet.registry import MODELS
  10. from mmdet.structures.bbox import get_box_tensor, scale_boxes
  # force=True lets this class replace a module already registered under the
  # same name instead of raising (the original note only said "avoid bug").
  @MODELS.register_module(force=True)
  class DeticBBoxHead(Shared2FCBBoxHead):
      """Shared-2FC bbox head with a rebuilt classifier and raw-score decoding.

      Two changes relative to a plain ``Shared2FCBBoxHead`` configuration:

      * ``fc_cls`` is rebuilt from ``cls_predictor_cfg`` with
        ``out_features=num_classes``.
      * ``_predict_by_feat_single`` uses ``cls_score`` directly as the
        per-class score; no softmax or sigmoid is applied in this class.
      """

      def __init__(self,
                   *args,
                   init_cfg: Optional[Union[dict, ConfigDict]] = None,
                   **kwargs) -> None:
          """Build the parent head, then replace ``fc_cls``.

          Args:
              *args: Positional arguments forwarded to ``Shared2FCBBoxHead``.
              init_cfg (dict or ConfigDict, optional): Initialization config
                  forwarded to the parent. Defaults to None.
              **kwargs: Keyword arguments forwarded to the parent.
          """
          super().__init__(*args, init_cfg=init_cfg, **kwargs)
          # Only fc_cls is rebuilt; fc_reg stays as the parent built it.
          # The classifier gets exactly ``num_classes`` outputs.
          # NOTE(review): the predict docstring expects num_classes + 1 score
          # columns, so the configured predictor presumably appends its own
          # background column -- confirm against the predictor config.
          assert self.with_cls, 'DeticBBoxHead requires with_cls=True'
          # Copy so the stored cls_predictor_cfg is not mutated by update().
          cls_predictor_cfg_ = self.cls_predictor_cfg.copy()
          cls_predictor_cfg_.update(
              in_features=self.cls_last_dim, out_features=self.num_classes)
          self.fc_cls = MODELS.build(cls_predictor_cfg_)

      def _predict_by_feat_single(
              self,
              roi: Tensor,
              cls_score: Tensor,
              bbox_pred: Optional[Tensor],
              img_meta: dict,
              rescale: bool = False,
              rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
          """Transform a single image's head outputs into bbox results.

          ``cls_score`` is used as the final score exactly as given. No
          activation is applied here, so callers must pass already-activated
          scores. NOTE(review): confirm against the calling RoI head.

          Args:
              roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
                  The last dimension 5 is arranged as
                  (batch_index, x1, y1, x2, y2).
              cls_score (Tensor): Box scores, has shape
                  (num_boxes, num_classes + 1).
              bbox_pred (Tensor, optional): Box energies / deltas, has shape
                  (num_boxes, num_classes * 4). If None (regression
                  disabled), the RoI boxes themselves are used, clipped to
                  ``img_shape``.
              img_meta (dict): Image information. ``img_shape`` is always
                  read. ``scale_factor`` must be present when ``rescale``
                  is True.
              rescale (bool): If True, return boxes in original image space.
                  Defaults to False.
              rcnn_test_cfg (:obj:`ConfigDict`, optional): ``test_cfg`` of
                  the bbox head. If None (aug test), raw per-class boxes and
                  scores are returned without NMS. Defaults to None.

          Returns:
              :obj:`InstanceData`: Detection results of the image.
              With ``rcnn_test_cfg`` set, it contains:

              - scores (Tensor): Classification scores, has a shape
                (num_instances, ).
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
              - bboxes (Tensor): Has a shape (num_instances, 4); the last
                dimension 4 is arranged as (x1, y1, x2, y2).

              With ``rcnn_test_cfg`` None, only ``bboxes`` (one row per RoI)
              and ``scores`` (``cls_score`` unchanged) are set.
          """
          results = InstanceData()
          if roi.shape[0] == 0:
              return empty_instances([img_meta],
                                     roi.device,
                                     task_type='bbox',
                                     instance_results=[results],
                                     box_type=self.predict_box_type,
                                     use_box_type=False,
                                     num_classes=self.num_classes,
                                     score_per_cls=rcnn_test_cfg is None)[0]

          # No activation: scores are used verbatim (see docstring).
          scores = cls_score
          img_shape = img_meta['img_shape']
          num_rois = roi.size(0)

          if bbox_pred is not None:
              # One decoded box per class unless regression is class-agnostic,
              # so each RoI is repeated to line up with the flattened deltas.
              num_classes = 1 if self.reg_class_agnostic else self.num_classes
              roi = roi.repeat_interleave(num_classes, dim=0)
              bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size)
              bboxes = self.bbox_coder.decode(
                  roi[..., 1:], bbox_pred, max_shape=img_shape)
          else:
              # Regression disabled: fall back to the RoI boxes themselves.
              # clone() keeps the caller's ``roi`` tensor untouched.
              bboxes = roi[:, 1:].clone()
              if img_shape is not None and bboxes.size(-1) == 4:
                  # img_shape assumed (h, w), as for max_shape above.
                  # Strided slices are views, so clamp_ edits bboxes in place.
                  bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
                  bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])

          if rescale and bboxes.size(0) > 0:
              assert img_meta.get('scale_factor') is not None
              scale_factor = [1 / s for s in img_meta['scale_factor']]
              bboxes = scale_boxes(bboxes, scale_factor)

          # Get the inside tensor when `bboxes` is a box type
          bboxes = get_box_tensor(bboxes)
          box_dim = bboxes.size(-1)
          bboxes = bboxes.view(num_rois, -1)

          if rcnn_test_cfg is None:
              # This means that it is aug test.
              # It needs to return the raw results without nms.
              results.bboxes = bboxes
              results.scores = scores
          else:
              det_bboxes, det_labels = multiclass_nms(
                  bboxes,
                  scores,
                  rcnn_test_cfg.score_thr,
                  rcnn_test_cfg.nms,
                  rcnn_test_cfg.max_per_img,
                  box_dim=box_dim)
              # The last column of det_bboxes holds the score.
              results.bboxes = det_bboxes[:, :-1]
              results.scores = det_bboxes[:, -1]
              results.labels = det_labels
          return results