detr_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
                         OptMultiConfig, reduce_mean)
from ..utils import multi_apply


@MODELS.register_module()
class DETRHead(BaseModule):
    r"""Head of DETR. DETR: End-to-End Object Detection with Transformers.

    More details can be found in the `paper
    <https://arxiv.org/pdf/2005.12872>`_ .

    Args:
        num_classes (int): Number of categories excluding the background.
        embed_dims (int): The dims of Transformer embedding.
        num_reg_fcs (int): Number of fully-connected layers used in `FFN`,
            which is then used for the regression head. Defaults to 2.
        sync_cls_avg_factor (bool): Whether to sync the `avg_factor` of
            all ranks. Defaults to `False`.
        loss_cls (:obj:`ConfigDict` or dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`.
        loss_bbox (:obj:`ConfigDict` or dict): Config of the regression bbox
            loss. Defaults to `L1Loss`.
        loss_iou (:obj:`ConfigDict` or dict): Config of the regression iou
            loss. Defaults to `GIoULoss`.
        train_cfg (:obj:`ConfigDict` or dict): Training config of the
            transformer head.
        test_cfg (:obj:`ConfigDict` or dict): Testing config of the
            transformer head.
        init_cfg (:obj:`ConfigDict` or dict, optional): The config to control
            the initialization. Defaults to None.
    """

    _version = 2

    def __init__(
            self,
            num_classes: int,
            embed_dims: int = 256,
            num_reg_fcs: int = 2,
            sync_cls_avg_factor: bool = False,
            loss_cls: ConfigType = dict(
                type='CrossEntropyLoss',
                bg_cls_weight=0.1,
                use_sigmoid=False,
                loss_weight=1.0,
                class_weight=1.0),
            loss_bbox: ConfigType = dict(type='L1Loss', loss_weight=5.0),
            loss_iou: ConfigType = dict(type='GIoULoss', loss_weight=2.0),
            train_cfg: ConfigType = dict(
                assigner=dict(
                    type='HungarianAssigner',
                    match_costs=[
                        dict(type='ClassificationCost', weight=1.),
                        dict(type='BBoxL1Cost', weight=5.0,
                             box_format='xywh'),
                        dict(type='IoUCost', iou_mode='giou', weight=2.0)
                    ])),
            test_cfg: ConfigType = dict(max_per_img=100),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        class_weight = loss_cls.get('class_weight', None)
        if class_weight is not None and (self.__class__ is DETRHead):
            assert isinstance(class_weight, float), 'Expected ' \
                'class_weight to have type float. Found ' \
                f'{type(class_weight)}.'
            # NOTE following the official DETR repo, bg_cls_weight means
            # relative classification weight of the no-object class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            class_weight = torch.ones(num_classes + 1) * class_weight
            # set background class as the last index
            class_weight[num_classes] = bg_cls_weight
            loss_cls.update({'class_weight': class_weight})
            if 'bg_cls_weight' in loss_cls:
                loss_cls.pop('bg_cls_weight')
            self.bg_cls_weight = bg_cls_weight

        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided ' \
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            self.assigner = TASK_UTILS.build(assigner)
            if train_cfg.get('sampler', None) is not None:
                raise RuntimeError('DETR does not build a sampler.')

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_reg_fcs = num_reg_fcs
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.loss_iou = MODELS.build(loss_iou)

        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the transformer head."""
        # cls branch
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        # reg branch
        self.activate = nn.ReLU()
        self.reg_ffn = FFN(
            self.embed_dims,
            self.embed_dims,
            self.num_reg_fcs,
            dict(type='ReLU', inplace=True),
            dropout=0.0,
            add_residual=False)
        # NOTE the activations of reg_branch here are the same as
        # those in the transformer, but they are actually different
        # in DAB-DETR (prelu in transformer and relu in reg_branch)
        self.fc_reg = Linear(self.embed_dims, 4)
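
    # Branch summary (an illustrative note, not in the original file): with
    # the default embed_dims=256 and the default softmax loss over, say, 80
    # classes, fc_cls is Linear(256, 81), while the regression path runs
    # reg_ffn (an FFN with num_reg_fcs=2 layers) -> ReLU -> fc_reg
    # Linear(256, 4); `forward` squashes the fc_reg outputs with sigmoid.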

    def forward(self, hidden_states: Tensor) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from transformer decoder. If
                `return_intermediate_dec` in detr.py is True, the output has
                shape (num_decoder_layers, bs, num_queries, dim); otherwise
                it has shape (1, bs, num_queries, dim), which only contains
                the last-layer outputs.

        Returns:
            tuple[Tensor]: results of head containing the following tensors.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). Note cls_out_channels should include
              background.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), has
              shape (num_decoder_layers, bs, num_queries, 4).
        """
        layers_cls_scores = self.fc_cls(hidden_states)
        layers_bbox_preds = self.fc_reg(
            self.activate(self.reg_ffn(hidden_states))).sigmoid()
        return layers_cls_scores, layers_bbox_preds
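
    # A minimal shape sketch for `forward` (illustrative, not part of the
    # original file; assumes an environment where the mmdet registries are
    # set up so the default losses and assigner can be built):
    #
    #   head = DETRHead(num_classes=80)  # default softmax head: 80 + bg
    #   hidden_states = torch.rand(2, 4, 100, 256)  # (layers, bs, q, dim)
    #   cls_scores, bbox_preds = head(hidden_states)
    #   assert cls_scores.shape == (2, 4, 100, 81)
    #   assert bbox_preds.shape == (2, 4, 100, 4)  # normalized (cx,cy,w,h)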

    def loss(self, hidden_states: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            hidden_states (Tensor): Feature from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, cls_out_channels)
                or (num_decoder_layers, num_queries, bs, cls_out_channels).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses
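
    # Usage note (illustrative): during training the parent DETR detector
    # calls `loss` with the stacked decoder hidden states and the batch's
    # data samples; the returned dict of per-layer losses is then summed by
    # the framework's standard loss parsing before backpropagation.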

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Only outputs from the last feature level are used for computing
        losses by default.

        Args:
            all_layers_cls_scores (Tensor): Classification outputs
                of each decoder layer. Each is a 4D-tensor, has shape
                (num_decoder_layers, bs, num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Sigmoid regression
                outputs of each decoder layer. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                (num_decoder_layers, bs, num_queries, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert batch_gt_instances_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            'batch_gt_instances_ignore being None.'
        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_by_feat_single,
            all_layers_cls_scores,
            all_layers_bbox_preds,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in \
                zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1
        return loss_dict
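
    # Key naming note (illustrative): with a 6-layer decoder, the dict above
    # stores the final layer's losses under 'loss_cls'/'loss_bbox'/'loss_iou'
    # and the auxiliary losses of the first five layers under
    # 'd0.loss_cls' ... 'd4.loss_iou'.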

    def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
                            batch_gt_instances: InstanceList,
                            batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer of a single
        feature level.

        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images, has shape (bs, num_queries,
                cls_out_channels).
            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
                for all images, with normalized coordinate (cx, cy, w, h) and
                shape (bs, num_queries, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            Tuple[Tensor]: A tuple including `loss_cls`, `loss_bbox` and
            `loss_iou`.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           batch_gt_instances,
                                           batch_img_metas)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used for rescaling bboxes
        factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds):
            img_h, img_w = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image, thus the learning target is normalized by the image size.
        # So here we need to re-scale them for calculating the IoU loss.
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou
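
    # Worked example for `cls_avg_factor` (illustrative, assumes the default
    # softmax loss with bg_cls_weight=0.1): with 100 queries per image, a
    # batch of 2 images and 5 GT boxes in total, the assigner yields
    # num_total_pos = 5 and num_total_neg = 195, so
    # cls_avg_factor = 5 * 1.0 + 195 * 0.1 = 24.5.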

    def get_targets(self, cls_scores_list: List[Tensor],
                    bbox_preds_list: List[Tensor],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict]) -> tuple:
        """Compute regression and classification targets for a batch of
        images.

        Outputs from a single decoder layer of a single feature level are
        used.

        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image, has shape [num_queries,
                cls_out_channels].
            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
                decoder layer for each image, with normalized coordinate
                (cx, cy, w, h) and shape [num_queries, 4].
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            tuple: a tuple containing the following targets.

            - labels_list (list[Tensor]): Labels for all images.
            - label_weights_list (list[Tensor]): Label weights for all
              images.
            - bbox_targets_list (list[Tensor]): BBox targets for all images.
            - bbox_weights_list (list[Tensor]): BBox weights for all images.
            - num_total_pos (int): Number of positive samples in all images.
            - num_total_neg (int): Number of negative samples in all images.
        """
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list,
         neg_inds_list) = multi_apply(self._get_targets_single,
                                      cls_scores_list, bbox_preds_list,
                                      batch_gt_instances, batch_img_metas)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)
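
    # Counting note (illustrative): Hungarian matching is one-to-one, so
    # `num_total_pos` equals the total number of GT boxes in the batch and
    # `num_total_neg` the number of remaining, unmatched queries.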

    def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> tuple:
        """Compute regression and classification targets for one image.

        Outputs from a single decoder layer of a single feature level are
        used.

        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_queries, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
                for one image, with normalized coordinate (cx, cy, w, h) and
                shape [num_queries, 4].
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for one image.

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.

            - labels (Tensor): Labels of each image.
            - label_weights (Tensor): Label weights of each image.
            - bbox_targets (Tensor): BBox targets of each image.
            - bbox_weights (Tensor): BBox weights of each image.
            - pos_inds (Tensor): Sampled positive indices for each image.
            - neg_inds (Tensor): Sampled negative indices for each image.
        """
        img_h, img_w = img_meta['img_shape']
        factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)
        num_bboxes = bbox_pred.size(0)
        # convert bbox_pred from xywh, normalized to xyxy, unnormalized
        bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
        bbox_pred = bbox_pred * factor

        pred_instances = InstanceData(scores=cls_score, bboxes=bbox_pred)
        # assigner and sampler
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=gt_instances,
            img_meta=img_meta)

        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        pos_inds = torch.nonzero(
            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(
            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
        pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
        pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long(), :]

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(bbox_pred)
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image. Thus the learning target should be normalized by the image
        # size, and the box format should be converted from the default
        # x1y1x2y2 to cxcywh.
        pos_gt_bboxes_normalized = pos_gt_bboxes / factor
        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
        bbox_targets[pos_inds] = pos_gt_bboxes_targets
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)
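
    # Target convention note (illustrative): queries the assigner leaves
    # unmatched keep the background label `num_classes` (e.g. 80 for COCO)
    # and a zero bbox weight, so only matched queries contribute to the
    # L1/GIoU terms while every query contributes to the classification
    # loss.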

    def loss_and_predict(
            self, hidden_states: Tuple[Tensor],
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Overridden because
        img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (tuple[Tensor]): Feature from the transformer
                decoder, has shape (num_decoder_layers, bs, num_queries,
                dim).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information of each image and
                corresponding annotations.

        Returns:
            tuple: the return value is a tuple that contains:

            - losses: (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas)
        return losses, predictions

    def predict(self,
                hidden_states: Tuple[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        Overridden because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)

        return predictions
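
    # Note (illustrative): `hidden_states[-1].unsqueeze(0)` keeps a leading
    # layer dimension of size 1, so that `predict_by_feat` can uniformly
    # index the last decoder layer with `[-1]` regardless of how many layers
    # were passed in.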

    def predict_by_feat(self,
                        layer_cls_scores: Tensor,
                        layer_bbox_preds: Tensor,
                        batch_img_metas: List[dict],
                        rescale: bool = True) -> InstanceList:
        """Transform network outputs for a batch into bbox predictions.

        Args:
            layer_cls_scores (Tensor): Classification outputs of the last or
                all decoder layers. Each is a 4D-tensor, has shape
                (num_decoder_layers, bs, num_queries, cls_out_channels).
            layer_bbox_preds (Tensor): Sigmoid regression outputs of the
                last or all decoder layers. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                (num_decoder_layers, bs, num_queries, 4).
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `True`.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each
            image after the post process. Each item usually contains the
            following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arranged as (x1, y1, x2, y2).
        """
        # NOTE only using outputs from the last feature level,
        # and only the outputs from the last decoder layer are used.
        cls_scores = layer_cls_scores[-1]
        bbox_preds = layer_bbox_preds[-1]

        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_meta = batch_img_metas[img_id]
            results = self._predict_by_feat_single(cls_score, bbox_pred,
                                                   img_meta, rescale)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                cls_score: Tensor,
                                bbox_pred: Tensor,
                                img_meta: dict,
                                rescale: bool = True) -> InstanceData:
        """Transform outputs from the last decoder layer into bbox
        predictions for each image.

        Args:
            cls_score (Tensor): Box score logits from the last decoder layer
                for each image. Shape [num_queries, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
                for each image, with coordinate format (cx, cy, w, h) and
                shape [num_queries, 4].
            img_meta (dict): Image meta info.
            rescale (bool): If True, return boxes in original image
                space. Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arranged as (x1, y1, x2, y2).
        """
        assert len(cls_score) == len(bbox_pred)  # num_queries
        max_per_img = self.test_cfg.get('max_per_img', len(cls_score))
        img_shape = img_meta['img_shape']
        # exclude background
        if self.loss_cls.use_sigmoid:
            cls_score = cls_score.sigmoid()
            scores, indexes = cls_score.view(-1).topk(max_per_img)
            det_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            bbox_pred = bbox_pred[bbox_index]
        else:
            scores, det_labels = F.softmax(
                cls_score, dim=-1)[..., :-1].max(-1)
            scores, bbox_index = scores.topk(max_per_img)
            bbox_pred = bbox_pred[bbox_index]
            det_labels = det_labels[bbox_index]

        det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
        det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
        det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
        det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
        det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
        if rescale:
            assert img_meta.get('scale_factor') is not None
            det_bboxes /= det_bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))

        results = InstanceData()
        results.bboxes = det_bboxes
        results.scores = scores
        results.labels = det_labels
        return results
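
    # Index recovery sketch for the sigmoid branch above (illustrative):
    # flattening a (num_queries, num_classes) score matrix row-major means a
    # flat index k decomposes as k = query * num_classes + class, so
    # `indexes // self.num_classes` recovers the query index and
    # `indexes % self.num_classes` recovers the class label. E.g. with
    # num_classes = 80, a flat index of 163 maps to query 2, class 3.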