conditional_detr_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmengine.model import bias_init_with_prob
from torch import Tensor

from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from .detr_head import DETRHead


@MODELS.register_module()
class ConditionalDETRHead(DETRHead):
    """Head of Conditional DETR: Conditional DETR for Fast Training
    Convergence. More details can be found in the `paper
    <https://arxiv.org/abs/2108.06152>`_ .
    """

    def init_weights(self) -> None:
        """Initialize weights of the transformer head."""
        super().init_weights()
        # The initialization below for the transformer head is very
        # important as we use focal loss for loss_cls
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)
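            # `bias_init_with_prob(0.01)` returns -log((1 - 0.01) / 0.01), so
            # sigmoid(bias) starts near 0.01: the focal-loss prior trick from
            # RetinaNet, which keeps the many negative queries from swamping
            # the classification logits early in training.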

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from the transformer decoder. If
                `return_intermediate_dec` is True, it has shape
                (num_decoder_layers, bs, num_queries, dim); otherwise it has
                shape (1, bs, num_queries, dim), containing only the last
                layer's outputs.
            references (Tensor): References from the transformer decoder,
                has shape (bs, num_queries, 2).

        Returns:
            tuple[Tensor]: Results of the head, containing the following
            tensors.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, has shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). Note cls_out_channels should include
              background.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), has
              shape (num_decoder_layers, bs, num_queries, 4).
        """
        references_unsigmoid = inverse_sigmoid(references)
        layers_bbox_preds = []
        for layer_id in range(hidden_states.shape[0]):
            tmp_reg_preds = self.fc_reg(
                self.activate(self.reg_ffn(hidden_states[layer_id])))
            tmp_reg_preds[..., :2] += references_unsigmoid
            outputs_coord = tmp_reg_preds.sigmoid()
            layers_bbox_preds.append(outputs_coord)
        layers_bbox_preds = torch.stack(layers_bbox_preds)

        layers_cls_scores = self.fc_cls(hidden_states)
        return layers_cls_scores, layers_bbox_preds
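
    # How `forward` uses the reference points (explanatory sketch): the
    # regression branch predicts the box center as an offset in logit space
    # relative to each query's 2-d reference point. For a reference of
    # (0.5, 0.5):
    #
    #     inverse_sigmoid(torch.tensor([0.5, 0.5]))  # -> tensor([0., 0.])
    #     center = (predicted_offset + 0.).sigmoid()
    #
    # A zero offset therefore lands exactly on the reference point, so each
    # query only regresses a local correction around its reference.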

    def loss(self, hidden_states: Tensor, references: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                has shape (bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. It usually includes information such as
                `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses
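
    # `loss_by_feat` is inherited from DETRHead: under the default config it
    # Hungarian-matches predictions to ground truth for every decoder layer,
    # so each of the num_decoder_layers outputs contributes its own
    # classification and box losses (deep supervision).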

    def loss_and_predict(
            self, hidden_states: Tensor, references: Tensor,
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Overridden because
        img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                has shape (bs, num_queries, 2).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information of each image and corresponding
                annotations.

        Returns:
            tuple: The return value is a tuple that contains:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection results of
              each image after the post process.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas)
        return losses, predictions
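
    # Design note: the head outputs `outs` are computed once and shared by
    # `loss_by_feat` and `predict_by_feat`, avoiding a second forward pass
    # when a training loop also needs post-processed predictions.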

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Overridden
        because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                has shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                has shape (bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. It usually includes information such as
                `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image after
            the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
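        # Only the last decoder layer is needed at inference; `unsqueeze(0)`
        # restores the (num_decoder_layers, bs, num_queries, dim) layout that
        # `forward` expects, with num_decoder_layers == 1.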
        outs = self(last_layer_hidden_state, references)
        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions
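

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream file). Builds the
# head with assumed config values and runs a bare forward pass on random
# decoder outputs; assumes mmdet is installed so DETRHead's default FocalLoss /
# L1Loss / GIoULoss configs can be built from the registry.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    head = ConditionalDETRHead(num_classes=80, embed_dims=256)  # assumed values
    head.init_weights()
    hidden_states = torch.rand(6, 2, 300, 256)  # (layers, bs, queries, dim)
    references = torch.rand(2, 300, 2)          # normalized (cx, cy) points
    cls_scores, bbox_preds = head(hidden_states, references)
    print(cls_scores.shape)  # expected: torch.Size([6, 2, 300, 80])
    print(bbox_preds.shape)  # expected: torch.Size([6, 2, 300, 4])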