embedding_rpn_head.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch
  4. import torch.nn as nn
  5. from mmengine.model import BaseModule
  6. from mmengine.structures import InstanceData
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
  10. from mmdet.structures.det_data_sample import SampleList
  11. from mmdet.utils import InstanceList, OptConfigType
  12. @MODELS.register_module()
  13. class EmbeddingRPNHead(BaseModule):
  14. """RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .
  15. Unlike traditional RPNHead, this module does not need FPN input, but just
  16. decode `init_proposal_bboxes` and expand the first dimension of
  17. `init_proposal_bboxes` and `init_proposal_features` to the batch_size.
  18. Args:
  19. num_proposals (int): Number of init_proposals. Defaults to 100.
  20. proposal_feature_channel (int): Channel number of
  21. init_proposal_feature. Defaults to 256.
  22. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  23. dict]): Initialization config dict. Defaults to None.
  24. """
  25. def __init__(self,
  26. num_proposals: int = 100,
  27. proposal_feature_channel: int = 256,
  28. init_cfg: OptConfigType = None,
  29. **kwargs) -> None:
  30. # `**kwargs` is necessary to avoid some potential error.
  31. assert init_cfg is None, 'To prevent abnormal initialization ' \
  32. 'behavior, init_cfg is not allowed to be set'
  33. super().__init__(init_cfg=init_cfg)
  34. self.num_proposals = num_proposals
  35. self.proposal_feature_channel = proposal_feature_channel
  36. self._init_layers()
  37. def _init_layers(self) -> None:
  38. """Initialize a sparse set of proposal boxes and proposal features."""
  39. self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
  40. self.init_proposal_features = nn.Embedding(
  41. self.num_proposals, self.proposal_feature_channel)
  42. def init_weights(self) -> None:
  43. """Initialize the init_proposal_bboxes as normalized.
  44. [c_x, c_y, w, h], and we initialize it to the size of the entire
  45. image.
  46. """
  47. super().init_weights()
  48. nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
  49. nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)
  50. def _decode_init_proposals(self, x: List[Tensor],
  51. batch_data_samples: SampleList) -> InstanceList:
  52. """Decode init_proposal_bboxes according to the size of images and
  53. expand dimension of init_proposal_features to batch_size.
  54. Args:
  55. x (list[Tensor]): List of FPN features.
  56. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  57. Samples. It usually includes information such as
  58. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  59. Returns:
  60. List[:obj:`InstanceData`:] Detection results of each image.
  61. Each item usually contains following keys.
  62. - proposals: Decoded proposal bboxes,
  63. has shape (num_proposals, 4).
  64. - features: init_proposal_features, expanded proposal
  65. features, has shape
  66. (num_proposals, proposal_feature_channel).
  67. - imgs_whwh: Tensor with shape
  68. (num_proposals, 4), the dimension means
  69. [img_width, img_height, img_width, img_height].
  70. """
  71. batch_img_metas = []
  72. for data_sample in batch_data_samples:
  73. batch_img_metas.append(data_sample.metainfo)
  74. proposals = self.init_proposal_bboxes.weight.clone()
  75. proposals = bbox_cxcywh_to_xyxy(proposals)
  76. imgs_whwh = []
  77. for meta in batch_img_metas:
  78. h, w = meta['img_shape'][:2]
  79. imgs_whwh.append(x[0].new_tensor([[w, h, w, h]]))
  80. imgs_whwh = torch.cat(imgs_whwh, dim=0)
  81. imgs_whwh = imgs_whwh[:, None, :]
  82. proposals = proposals * imgs_whwh
  83. rpn_results_list = []
  84. for idx in range(len(batch_img_metas)):
  85. rpn_results = InstanceData()
  86. rpn_results.bboxes = proposals[idx]
  87. rpn_results.imgs_whwh = imgs_whwh[idx].repeat(
  88. self.num_proposals, 1)
  89. rpn_results.features = self.init_proposal_features.weight.clone()
  90. rpn_results_list.append(rpn_results)
  91. return rpn_results_list
  92. def loss(self, *args, **kwargs):
  93. """Perform forward propagation and loss calculation of the detection
  94. head on the features of the upstream network."""
  95. raise NotImplementedError(
  96. 'EmbeddingRPNHead does not have `loss`, please use '
  97. '`predict` or `loss_and_predict` instead.')
  98. def predict(self, x: List[Tensor], batch_data_samples: SampleList,
  99. **kwargs) -> InstanceList:
  100. """Perform forward propagation of the detection head and predict
  101. detection results on the features of the upstream network."""
  102. # `**kwargs` is necessary to avoid some potential error.
  103. return self._decode_init_proposals(
  104. x=x, batch_data_samples=batch_data_samples)
  105. def loss_and_predict(self, x: List[Tensor], batch_data_samples: SampleList,
  106. **kwargs) -> tuple:
  107. """Perform forward propagation of the head, then calculate loss and
  108. predictions from the features and data samples."""
  109. # `**kwargs` is necessary to avoid some potential error.
  110. predictions = self._decode_init_proposals(
  111. x=x, batch_data_samples=batch_data_samples)
  112. return dict(), predictions