dab_detr_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from ..layers import MLP, inverse_sigmoid
from .conditional_detr_head import ConditionalDETRHead


@MODELS.register_module()
class DABDETRHead(ConditionalDETRHead):
    """Head of DAB-DETR. DAB-DETR: Dynamic Anchor Boxes are Better Queries for
    DETR.

    More details can be found in the `paper
    <https://arxiv.org/abs/2201.12329>`_ .
    """

    def _init_layers(self) -> None:
        """Initialize layers of the transformer head."""
        # cls branch
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        # reg branch
        self.fc_reg = MLP(self.embed_dims, self.embed_dims, 4, 3)
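
    # Example (assuming embed_dims=256 and 80 classes with sigmoid
    # classification): fc_cls is Linear(256, 80), and fc_reg is a 3-layer
    # MLP, 256 -> 256 -> 256 -> 4, predicting unsigmoided (cx, cy, w, h).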

    def init_weights(self) -> None:
        """Initialize weights."""
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)
        constant_init(self.fc_reg.layers[-1], 0., bias=0.)
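
    # Note: bias_init_with_prob(0.01) sets the classification bias so that
    # initial sigmoid scores are roughly 0.01, the usual focal-loss prior
    # that keeps the many negative queries from dominating early training.
    # Zero-initializing the last reg layer makes the initial box predictions
    # coincide with the sigmoid of the reference anchors (see forward below).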

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from the transformer decoder. If
                `return_intermediate_dec` is True, the output has shape
                (num_decoder_layers, bs, num_queries, dim); otherwise it has
                shape (1, bs, num_queries, dim), containing only the last
                layer's outputs.
            references (Tensor): References from the transformer decoder. If
                `return_intermediate_dec` is True, the output has shape
                (num_decoder_layers, bs, num_queries, 2/4); otherwise it has
                shape (1, bs, num_queries, 2/4), containing only the last
                layer's reference.

        Returns:
            tuple[Tensor]: Results of the head, containing the following
            tensors.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, with shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). Note that cls_out_channels should include
              the background class.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), with
              shape (num_decoder_layers, bs, num_queries, 4).
        """
        layers_cls_scores = self.fc_cls(hidden_states)
        references_before_sigmoid = inverse_sigmoid(references, eps=1e-3)
        tmp_reg_preds = self.fc_reg(hidden_states)
        tmp_reg_preds[..., :references_before_sigmoid.
                      size(-1)] += references_before_sigmoid
        layers_bbox_preds = tmp_reg_preds.sigmoid()
        return layers_cls_scores, layers_bbox_preds
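
    # In effect, the reg branch predicts offsets in logit (unsigmoided)
    # space: for a reference box r in (0, 1) the prediction is
    #     b = sigmoid(inverse_sigmoid(r) + fc_reg(h)),
    # so each decoder layer refines the dynamic anchor box rather than
    # regressing coordinates from scratch. With 2-d references (cx, cy
    # only), just the first two of the four channels receive the offset.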

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        Overridden because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                with shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                with shape (num_decoder_layers, bs, num_queries, 2/4).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. Each usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        last_layer_reference = references[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state, last_layer_reference)
        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions
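

# A minimal usage sketch, to be run from a separate script with mmdet 3.x
# installed. The num_classes, loss config and tensor shapes below are
# illustrative assumptions, not values prescribed by this file:
#
#     import torch
#     from mmdet.models.dense_heads import DABDETRHead
#
#     head = DABDETRHead(
#         num_classes=80,  # assumed dataset size, e.g. COCO
#         embed_dims=256,
#         loss_cls=dict(
#             type='FocalLoss',
#             use_sigmoid=True,
#             gamma=2.0,
#             alpha=0.25,
#             loss_weight=2.0))
#     head.init_weights()
#
#     num_layers, bs, num_queries = 6, 2, 300
#     hidden_states = torch.randn(num_layers, bs, num_queries, 256)
#     references = torch.rand(num_layers, bs, num_queries, 4)  # in (0, 1)
#
#     cls_scores, bbox_preds = head(hidden_states, references)
#     # cls_scores: (6, 2, 300, 80) classification logits per decoder layer
#     # bbox_preds: (6, 2, 300, 4) sigmoid-normalized (cx, cy, w, h) boxes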