deformable_detr_head.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. from typing import Dict, List, Tuple
  4. import torch
  5. import torch.nn as nn
  6. from mmcv.cnn import Linear
  7. from mmengine.model import bias_init_with_prob, constant_init
  8. from torch import Tensor
  9. from mmdet.registry import MODELS
  10. from mmdet.structures import SampleList
  11. from mmdet.utils import InstanceList, OptInstanceList
  12. from ..layers import inverse_sigmoid
  13. from .detr_head import DETRHead
  14. @MODELS.register_module()
  15. class DeformableDETRHead(DETRHead):
  16. r"""Head of DeformDETR: Deformable DETR: Deformable Transformers for
  17. End-to-End Object Detection.
  18. Code is modified from the `official github repo
  19. <https://github.com/fundamentalvision/Deformable-DETR>`_.
  20. More details can be found in the `paper
  21. <https://arxiv.org/abs/2010.04159>`_ .
  22. Args:
  23. share_pred_layer (bool): Whether to share parameters for all the
  24. prediction layers. Defaults to `False`.
  25. num_pred_layer (int): The number of the prediction layers.
  26. Defaults to 6.
  27. as_two_stage (bool, optional): Whether to generate the proposal
  28. from the outputs of encoder. Defaults to `False`.
  29. """
  30. def __init__(self,
  31. *args,
  32. share_pred_layer: bool = False,
  33. num_pred_layer: int = 6,
  34. as_two_stage: bool = False,
  35. **kwargs) -> None:
  36. self.share_pred_layer = share_pred_layer
  37. self.num_pred_layer = num_pred_layer
  38. self.as_two_stage = as_two_stage
  39. super().__init__(*args, **kwargs)
  40. def _init_layers(self) -> None:
  41. """Initialize classification branch and regression branch of head."""
  42. fc_cls = Linear(self.embed_dims, self.cls_out_channels)
  43. reg_branch = []
  44. for _ in range(self.num_reg_fcs):
  45. reg_branch.append(Linear(self.embed_dims, self.embed_dims))
  46. reg_branch.append(nn.ReLU())
  47. reg_branch.append(Linear(self.embed_dims, 4))
  48. reg_branch = nn.Sequential(*reg_branch)
  49. if self.share_pred_layer:
  50. self.cls_branches = nn.ModuleList(
  51. [fc_cls for _ in range(self.num_pred_layer)])
  52. self.reg_branches = nn.ModuleList(
  53. [reg_branch for _ in range(self.num_pred_layer)])
  54. else:
  55. self.cls_branches = nn.ModuleList(
  56. [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
  57. self.reg_branches = nn.ModuleList([
  58. copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
  59. ])
  60. def init_weights(self) -> None:
  61. """Initialize weights of the Deformable DETR head."""
  62. if self.loss_cls.use_sigmoid:
  63. bias_init = bias_init_with_prob(0.01)
  64. for m in self.cls_branches:
  65. nn.init.constant_(m.bias, bias_init)
  66. for m in self.reg_branches:
  67. constant_init(m[-1], 0, bias=0)
  68. nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
  69. if self.as_two_stage:
  70. for m in self.reg_branches:
  71. nn.init.constant_(m[-1].bias.data[2:], 0.0)
  72. def forward(self, hidden_states: Tensor,
  73. references: List[Tensor]) -> Tuple[Tensor]:
  74. """Forward function.
  75. Args:
  76. hidden_states (Tensor): Hidden states output from each decoder
  77. layer, has shape (num_decoder_layers, bs, num_queries, dim).
  78. references (list[Tensor]): List of the reference from the decoder.
  79. The first reference is the `init_reference` (initial) and the
  80. other num_decoder_layers(6) references are `inter_references`
  81. (intermediate). The `init_reference` has shape (bs,
  82. num_queries, 4) when `as_two_stage` of the detector is `True`,
  83. otherwise (bs, num_queries, 2). Each `inter_reference` has
  84. shape (bs, num_queries, 4) when `with_box_refine` of the
  85. detector is `True`, otherwise (bs, num_queries, 2). The
  86. coordinates are arranged as (cx, cy) when the last dimension is
  87. 2, and (cx, cy, w, h) when it is 4.
  88. Returns:
  89. tuple[Tensor]: results of head containing the following tensor.
  90. - all_layers_outputs_classes (Tensor): Outputs from the
  91. classification head, has shape (num_decoder_layers, bs,
  92. num_queries, cls_out_channels).
  93. - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
  94. regression head with normalized coordinate format (cx, cy, w,
  95. h), has shape (num_decoder_layers, bs, num_queries, 4) with the
  96. last dimension arranged as (cx, cy, w, h).
  97. """
  98. all_layers_outputs_classes = []
  99. all_layers_outputs_coords = []
  100. for layer_id in range(hidden_states.shape[0]):
  101. reference = inverse_sigmoid(references[layer_id])
  102. # NOTE The last reference will not be used.
  103. hidden_state = hidden_states[layer_id]
  104. outputs_class = self.cls_branches[layer_id](hidden_state)
  105. tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
  106. if reference.shape[-1] == 4:
  107. # When `layer` is 0 and `as_two_stage` of the detector
  108. # is `True`, or when `layer` is greater than 0 and
  109. # `with_box_refine` of the detector is `True`.
  110. tmp_reg_preds += reference
  111. else:
  112. # When `layer` is 0 and `as_two_stage` of the detector
  113. # is `False`, or when `layer` is greater than 0 and
  114. # `with_box_refine` of the detector is `False`.
  115. assert reference.shape[-1] == 2
  116. tmp_reg_preds[..., :2] += reference
  117. outputs_coord = tmp_reg_preds.sigmoid()
  118. all_layers_outputs_classes.append(outputs_class)
  119. all_layers_outputs_coords.append(outputs_coord)
  120. all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
  121. all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)
  122. return all_layers_outputs_classes, all_layers_outputs_coords
  123. def loss(self, hidden_states: Tensor, references: List[Tensor],
  124. enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
  125. batch_data_samples: SampleList) -> dict:
  126. """Perform forward propagation and loss calculation of the detection
  127. head on the queries of the upstream network.
  128. Args:
  129. hidden_states (Tensor): Hidden states output from each decoder
  130. layer, has shape (num_decoder_layers, num_queries, bs, dim).
  131. references (list[Tensor]): List of the reference from the decoder.
  132. The first reference is the `init_reference` (initial) and the
  133. other num_decoder_layers(6) references are `inter_references`
  134. (intermediate). The `init_reference` has shape (bs,
  135. num_queries, 4) when `as_two_stage` of the detector is `True`,
  136. otherwise (bs, num_queries, 2). Each `inter_reference` has
  137. shape (bs, num_queries, 4) when `with_box_refine` of the
  138. detector is `True`, otherwise (bs, num_queries, 2). The
  139. coordinates are arranged as (cx, cy) when the last dimension is
  140. 2, and (cx, cy, w, h) when it is 4.
  141. enc_outputs_class (Tensor): The score of each point on encode
  142. feature map, has shape (bs, num_feat_points, cls_out_channels).
  143. Only when `as_two_stage` is `True` it would be passed in,
  144. otherwise it would be `None`.
  145. enc_outputs_coord (Tensor): The proposal generate from the encode
  146. feature map, has shape (bs, num_feat_points, 4) with the last
  147. dimension arranged as (cx, cy, w, h). Only when `as_two_stage`
  148. is `True` it would be passed in, otherwise it would be `None`.
  149. batch_data_samples (list[:obj:`DetDataSample`]): The Data
  150. Samples. It usually includes information such as
  151. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  152. Returns:
  153. dict: A dictionary of loss components.
  154. """
  155. batch_gt_instances = []
  156. batch_img_metas = []
  157. for data_sample in batch_data_samples:
  158. batch_img_metas.append(data_sample.metainfo)
  159. batch_gt_instances.append(data_sample.gt_instances)
  160. outs = self(hidden_states, references)
  161. loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
  162. batch_gt_instances, batch_img_metas)
  163. losses = self.loss_by_feat(*loss_inputs)
  164. return losses
  165. def loss_by_feat(
  166. self,
  167. all_layers_cls_scores: Tensor,
  168. all_layers_bbox_preds: Tensor,
  169. enc_cls_scores: Tensor,
  170. enc_bbox_preds: Tensor,
  171. batch_gt_instances: InstanceList,
  172. batch_img_metas: List[dict],
  173. batch_gt_instances_ignore: OptInstanceList = None
  174. ) -> Dict[str, Tensor]:
  175. """Loss function.
  176. Args:
  177. all_layers_cls_scores (Tensor): Classification scores of all
  178. decoder layers, has shape (num_decoder_layers, bs, num_queries,
  179. cls_out_channels).
  180. all_layers_bbox_preds (Tensor): Regression outputs of all decoder
  181. layers. Each is a 4D-tensor with normalized coordinate format
  182. (cx, cy, w, h) and has shape (num_decoder_layers, bs,
  183. num_queries, 4) with the last dimension arranged as
  184. (cx, cy, w, h).
  185. enc_cls_scores (Tensor): The score of each point on encode
  186. feature map, has shape (bs, num_feat_points, cls_out_channels).
  187. Only when `as_two_stage` is `True` it would be passes in,
  188. otherwise, it would be `None`.
  189. enc_bbox_preds (Tensor): The proposal generate from the encode
  190. feature map, has shape (bs, num_feat_points, 4) with the last
  191. dimension arranged as (cx, cy, w, h). Only when `as_two_stage`
  192. is `True` it would be passed in, otherwise it would be `None`.
  193. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  194. gt_instance. It usually includes ``bboxes`` and ``labels``
  195. attributes.
  196. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  197. image size, scaling factor, etc.
  198. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  199. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  200. data that is ignored during training and testing.
  201. Defaults to None.
  202. Returns:
  203. dict[str, Tensor]: A dictionary of loss components.
  204. """
  205. loss_dict = super().loss_by_feat(all_layers_cls_scores,
  206. all_layers_bbox_preds,
  207. batch_gt_instances, batch_img_metas,
  208. batch_gt_instances_ignore)
  209. # loss of proposal generated from encode feature map.
  210. if enc_cls_scores is not None:
  211. proposal_gt_instances = copy.deepcopy(batch_gt_instances)
  212. for i in range(len(proposal_gt_instances)):
  213. proposal_gt_instances[i].labels = torch.zeros_like(
  214. proposal_gt_instances[i].labels)
  215. enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
  216. self.loss_by_feat_single(
  217. enc_cls_scores, enc_bbox_preds,
  218. batch_gt_instances=proposal_gt_instances,
  219. batch_img_metas=batch_img_metas)
  220. loss_dict['enc_loss_cls'] = enc_loss_cls
  221. loss_dict['enc_loss_bbox'] = enc_losses_bbox
  222. loss_dict['enc_loss_iou'] = enc_losses_iou
  223. return loss_dict
  224. def predict(self,
  225. hidden_states: Tensor,
  226. references: List[Tensor],
  227. batch_data_samples: SampleList,
  228. rescale: bool = True) -> InstanceList:
  229. """Perform forward propagation and loss calculation of the detection
  230. head on the queries of the upstream network.
  231. Args:
  232. hidden_states (Tensor): Hidden states output from each decoder
  233. layer, has shape (num_decoder_layers, num_queries, bs, dim).
  234. references (list[Tensor]): List of the reference from the decoder.
  235. The first reference is the `init_reference` (initial) and the
  236. other num_decoder_layers(6) references are `inter_references`
  237. (intermediate). The `init_reference` has shape (bs,
  238. num_queries, 4) when `as_two_stage` of the detector is `True`,
  239. otherwise (bs, num_queries, 2). Each `inter_reference` has
  240. shape (bs, num_queries, 4) when `with_box_refine` of the
  241. detector is `True`, otherwise (bs, num_queries, 2). The
  242. coordinates are arranged as (cx, cy) when the last dimension is
  243. 2, and (cx, cy, w, h) when it is 4.
  244. batch_data_samples (list[:obj:`DetDataSample`]): The Data
  245. Samples. It usually includes information such as
  246. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  247. rescale (bool, optional): If `True`, return boxes in original
  248. image space. Defaults to `True`.
  249. Returns:
  250. list[obj:`InstanceData`]: Detection results of each image
  251. after the post process.
  252. """
  253. batch_img_metas = [
  254. data_samples.metainfo for data_samples in batch_data_samples
  255. ]
  256. outs = self(hidden_states, references)
  257. predictions = self.predict_by_feat(
  258. *outs, batch_img_metas=batch_img_metas, rescale=rescale)
  259. return predictions
  260. def predict_by_feat(self,
  261. all_layers_cls_scores: Tensor,
  262. all_layers_bbox_preds: Tensor,
  263. batch_img_metas: List[Dict],
  264. rescale: bool = False) -> InstanceList:
  265. """Transform a batch of output features extracted from the head into
  266. bbox results.
  267. Args:
  268. all_layers_cls_scores (Tensor): Classification scores of all
  269. decoder layers, has shape (num_decoder_layers, bs, num_queries,
  270. cls_out_channels).
  271. all_layers_bbox_preds (Tensor): Regression outputs of all decoder
  272. layers. Each is a 4D-tensor with normalized coordinate format
  273. (cx, cy, w, h) and shape (num_decoder_layers, bs, num_queries,
  274. 4) with the last dimension arranged as (cx, cy, w, h).
  275. batch_img_metas (list[dict]): Meta information of each image.
  276. rescale (bool, optional): If `True`, return boxes in original
  277. image space. Default `False`.
  278. Returns:
  279. list[obj:`InstanceData`]: Detection results of each image
  280. after the post process.
  281. """
  282. cls_scores = all_layers_cls_scores[-1]
  283. bbox_preds = all_layers_bbox_preds[-1]
  284. result_list = []
  285. for img_id in range(len(batch_img_metas)):
  286. cls_score = cls_scores[img_id]
  287. bbox_pred = bbox_preds[img_id]
  288. img_meta = batch_img_metas[img_id]
  289. results = self._predict_by_feat_single(cls_score, bbox_pred,
  290. img_meta, rescale)
  291. result_list.append(results)
  292. return result_list