paa_head.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import numpy as np
  4. import torch
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures.bbox import bbox_overlaps
  9. from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
  10. OptInstanceList)
  11. from ..layers import multiclass_nms
  12. from ..utils import levels_to_images, multi_apply
  13. from . import ATSSHead
  14. EPS = 1e-12
  15. try:
  16. import sklearn.mixture as skm
  17. except ImportError:
  18. skm = None
  19. @MODELS.register_module()
  20. class PAAHead(ATSSHead):
  21. """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU
  22. Prediction for Object Detection.
  23. Code is modified from the `official github repo
  24. <https://github.com/kkhoot/PAA/blob/master/paa_core
  25. /modeling/rpn/paa/loss.py>`_.
  26. More details can be found in the `paper
  27. <https://arxiv.org/abs/2007.08103>`_ .
  28. Args:
  29. topk (int): Select topk samples with smallest loss in
  30. each level.
  31. score_voting (bool): Whether to use score voting in post-process.
  32. covariance_type : String describing the type of covariance parameters
  33. to be used in :class:`sklearn.mixture.GaussianMixture`.
  34. It must be one of:
  35. - 'full': each component has its own general covariance matrix
  36. - 'tied': all components share the same general covariance matrix
  37. - 'diag': each component has its own diagonal covariance matrix
  38. - 'spherical': each component has its own single variance
  39. Default: 'diag'. From 'full' to 'spherical', the gmm fitting
  40. process is faster yet the performance could be influenced. For most
  41. cases, 'diag' should be a good choice.
  42. """
  43. def __init__(self,
  44. *args,
  45. topk: int = 9,
  46. score_voting: bool = True,
  47. covariance_type: str = 'diag',
  48. **kwargs):
  49. # topk used in paa reassign process
  50. self.topk = topk
  51. self.with_score_voting = score_voting
  52. self.covariance_type = covariance_type
  53. super().__init__(*args, **kwargs)
  54. def loss_by_feat(
  55. self,
  56. cls_scores: List[Tensor],
  57. bbox_preds: List[Tensor],
  58. iou_preds: List[Tensor],
  59. batch_gt_instances: InstanceList,
  60. batch_img_metas: List[dict],
  61. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  62. """Calculate the loss based on the features extracted by the detection
  63. head.
  64. Args:
  65. cls_scores (list[Tensor]): Box scores for each scale level
  66. Has shape (N, num_anchors * num_classes, H, W)
  67. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  68. level with shape (N, num_anchors * 4, H, W)
  69. iou_preds (list[Tensor]): iou_preds for each scale
  70. level with shape (N, num_anchors * 1, H, W)
  71. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  72. gt_instance. It usually includes ``bboxes`` and ``labels``
  73. attributes.
  74. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  75. image size, scaling factor, etc.
  76. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  77. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  78. data that is ignored during training and testing.
  79. Defaults to None.
  80. Returns:
  81. dict[str, Tensor]: A dictionary of loss gmm_assignment.
  82. """
  83. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  84. assert len(featmap_sizes) == self.prior_generator.num_levels
  85. device = cls_scores[0].device
  86. anchor_list, valid_flag_list = self.get_anchors(
  87. featmap_sizes, batch_img_metas, device=device)
  88. cls_reg_targets = self.get_targets(
  89. anchor_list,
  90. valid_flag_list,
  91. batch_gt_instances,
  92. batch_img_metas,
  93. batch_gt_instances_ignore=batch_gt_instances_ignore,
  94. )
  95. (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
  96. pos_gt_index) = cls_reg_targets
  97. cls_scores = levels_to_images(cls_scores)
  98. cls_scores = [
  99. item.reshape(-1, self.cls_out_channels) for item in cls_scores
  100. ]
  101. bbox_preds = levels_to_images(bbox_preds)
  102. bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
  103. iou_preds = levels_to_images(iou_preds)
  104. iou_preds = [item.reshape(-1, 1) for item in iou_preds]
  105. pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
  106. cls_scores, bbox_preds, labels,
  107. labels_weight, bboxes_target,
  108. bboxes_weight, pos_inds)
  109. with torch.no_grad():
  110. reassign_labels, reassign_label_weight, \
  111. reassign_bbox_weights, num_pos = multi_apply(
  112. self.paa_reassign,
  113. pos_losses_list,
  114. labels,
  115. labels_weight,
  116. bboxes_weight,
  117. pos_inds,
  118. pos_gt_index,
  119. anchor_list)
  120. num_pos = sum(num_pos)
  121. # convert all tensor list to a flatten tensor
  122. cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
  123. bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
  124. iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
  125. labels = torch.cat(reassign_labels, 0).view(-1)
  126. flatten_anchors = torch.cat(
  127. [torch.cat(item, 0) for item in anchor_list])
  128. labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
  129. bboxes_target = torch.cat(bboxes_target,
  130. 0).view(-1, bboxes_target[0].size(-1))
  131. pos_inds_flatten = ((labels >= 0)
  132. &
  133. (labels < self.num_classes)).nonzero().reshape(-1)
  134. losses_cls = self.loss_cls(
  135. cls_scores,
  136. labels,
  137. labels_weight,
  138. avg_factor=max(num_pos, len(batch_img_metas))) # avoid num_pos=0
  139. if num_pos:
  140. pos_bbox_pred = self.bbox_coder.decode(
  141. flatten_anchors[pos_inds_flatten],
  142. bbox_preds[pos_inds_flatten])
  143. pos_bbox_target = bboxes_target[pos_inds_flatten]
  144. iou_target = bbox_overlaps(
  145. pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
  146. losses_iou = self.loss_centerness(
  147. iou_preds[pos_inds_flatten],
  148. iou_target.unsqueeze(-1),
  149. avg_factor=num_pos)
  150. losses_bbox = self.loss_bbox(
  151. pos_bbox_pred,
  152. pos_bbox_target,
  153. iou_target.clamp(min=EPS),
  154. avg_factor=iou_target.sum())
  155. else:
  156. losses_iou = iou_preds.sum() * 0
  157. losses_bbox = bbox_preds.sum() * 0
  158. return dict(
  159. loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
  160. def get_pos_loss(self, anchors: List[Tensor], cls_score: Tensor,
  161. bbox_pred: Tensor, label: Tensor, label_weight: Tensor,
  162. bbox_target: dict, bbox_weight: Tensor,
  163. pos_inds: Tensor) -> Tensor:
  164. """Calculate loss of all potential positive samples obtained from first
  165. match process.
  166. Args:
  167. anchors (list[Tensor]): Anchors of each scale.
  168. cls_score (Tensor): Box scores of single image with shape
  169. (num_anchors, num_classes)
  170. bbox_pred (Tensor): Box energies / deltas of single image
  171. with shape (num_anchors, 4)
  172. label (Tensor): classification target of each anchor with
  173. shape (num_anchors,)
  174. label_weight (Tensor): Classification loss weight of each
  175. anchor with shape (num_anchors).
  176. bbox_target (dict): Regression target of each anchor with
  177. shape (num_anchors, 4).
  178. bbox_weight (Tensor): Bbox weight of each anchor with shape
  179. (num_anchors, 4).
  180. pos_inds (Tensor): Index of all positive samples got from
  181. first assign process.
  182. Returns:
  183. Tensor: Losses of all positive samples in single image.
  184. """
  185. if not len(pos_inds):
  186. return cls_score.new([]),
  187. anchors_all_level = torch.cat(anchors, 0)
  188. pos_scores = cls_score[pos_inds]
  189. pos_bbox_pred = bbox_pred[pos_inds]
  190. pos_label = label[pos_inds]
  191. pos_label_weight = label_weight[pos_inds]
  192. pos_bbox_target = bbox_target[pos_inds]
  193. pos_bbox_weight = bbox_weight[pos_inds]
  194. pos_anchors = anchors_all_level[pos_inds]
  195. pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
  196. # to keep loss dimension
  197. loss_cls = self.loss_cls(
  198. pos_scores,
  199. pos_label,
  200. pos_label_weight,
  201. avg_factor=1.0,
  202. reduction_override='none')
  203. loss_bbox = self.loss_bbox(
  204. pos_bbox_pred,
  205. pos_bbox_target,
  206. pos_bbox_weight,
  207. avg_factor=1.0, # keep same loss weight before reassign
  208. reduction_override='none')
  209. loss_cls = loss_cls.sum(-1)
  210. pos_loss = loss_bbox + loss_cls
  211. return pos_loss,
  212. def paa_reassign(self, pos_losses: Tensor, label: Tensor,
  213. label_weight: Tensor, bbox_weight: Tensor,
  214. pos_inds: Tensor, pos_gt_inds: Tensor,
  215. anchors: List[Tensor]) -> tuple:
  216. """Fit loss to GMM distribution and separate positive, ignore, negative
  217. samples again with GMM model.
  218. Args:
  219. pos_losses (Tensor): Losses of all positive samples in
  220. single image.
  221. label (Tensor): classification target of each anchor with
  222. shape (num_anchors,)
  223. label_weight (Tensor): Classification loss weight of each
  224. anchor with shape (num_anchors).
  225. bbox_weight (Tensor): Bbox weight of each anchor with shape
  226. (num_anchors, 4).
  227. pos_inds (Tensor): Index of all positive samples got from
  228. first assign process.
  229. pos_gt_inds (Tensor): Gt_index of all positive samples got
  230. from first assign process.
  231. anchors (list[Tensor]): Anchors of each scale.
  232. Returns:
  233. tuple: Usually returns a tuple containing learning targets.
  234. - label (Tensor): classification target of each anchor after
  235. paa assign, with shape (num_anchors,)
  236. - label_weight (Tensor): Classification loss weight of each
  237. anchor after paa assign, with shape (num_anchors).
  238. - bbox_weight (Tensor): Bbox weight of each anchor with shape
  239. (num_anchors, 4).
  240. - num_pos (int): The number of positive samples after paa
  241. assign.
  242. """
  243. if not len(pos_inds):
  244. return label, label_weight, bbox_weight, 0
  245. label = label.clone()
  246. label_weight = label_weight.clone()
  247. bbox_weight = bbox_weight.clone()
  248. num_gt = pos_gt_inds.max() + 1
  249. num_level = len(anchors)
  250. num_anchors_each_level = [item.size(0) for item in anchors]
  251. num_anchors_each_level.insert(0, 0)
  252. inds_level_interval = np.cumsum(num_anchors_each_level)
  253. pos_level_mask = []
  254. for i in range(num_level):
  255. mask = (pos_inds >= inds_level_interval[i]) & (
  256. pos_inds < inds_level_interval[i + 1])
  257. pos_level_mask.append(mask)
  258. pos_inds_after_paa = [label.new_tensor([])]
  259. ignore_inds_after_paa = [label.new_tensor([])]
  260. for gt_ind in range(num_gt):
  261. pos_inds_gmm = []
  262. pos_loss_gmm = []
  263. gt_mask = pos_gt_inds == gt_ind
  264. for level in range(num_level):
  265. level_mask = pos_level_mask[level]
  266. level_gt_mask = level_mask & gt_mask
  267. value, topk_inds = pos_losses[level_gt_mask].topk(
  268. min(level_gt_mask.sum(), self.topk), largest=False)
  269. pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
  270. pos_loss_gmm.append(value)
  271. pos_inds_gmm = torch.cat(pos_inds_gmm)
  272. pos_loss_gmm = torch.cat(pos_loss_gmm)
  273. # fix gmm need at least two sample
  274. if len(pos_inds_gmm) < 2:
  275. continue
  276. device = pos_inds_gmm.device
  277. pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
  278. pos_inds_gmm = pos_inds_gmm[sort_inds]
  279. pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
  280. min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
  281. means_init = np.array([min_loss, max_loss]).reshape(2, 1)
  282. weights_init = np.array([0.5, 0.5])
  283. precisions_init = np.array([1.0, 1.0]).reshape(2, 1, 1) # full
  284. if self.covariance_type == 'spherical':
  285. precisions_init = precisions_init.reshape(2)
  286. elif self.covariance_type == 'diag':
  287. precisions_init = precisions_init.reshape(2, 1)
  288. elif self.covariance_type == 'tied':
  289. precisions_init = np.array([[1.0]])
  290. if skm is None:
  291. raise ImportError('Please run "pip install sklearn" '
  292. 'to install sklearn first.')
  293. gmm = skm.GaussianMixture(
  294. 2,
  295. weights_init=weights_init,
  296. means_init=means_init,
  297. precisions_init=precisions_init,
  298. covariance_type=self.covariance_type)
  299. gmm.fit(pos_loss_gmm)
  300. gmm_assignment = gmm.predict(pos_loss_gmm)
  301. scores = gmm.score_samples(pos_loss_gmm)
  302. gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
  303. scores = torch.from_numpy(scores).to(device)
  304. pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
  305. gmm_assignment, scores, pos_inds_gmm)
  306. pos_inds_after_paa.append(pos_inds_temp)
  307. ignore_inds_after_paa.append(ignore_inds_temp)
  308. pos_inds_after_paa = torch.cat(pos_inds_after_paa)
  309. ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
  310. reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
  311. reassign_ids = pos_inds[reassign_mask]
  312. label[reassign_ids] = self.num_classes
  313. label_weight[ignore_inds_after_paa] = 0
  314. bbox_weight[reassign_ids] = 0
  315. num_pos = len(pos_inds_after_paa)
  316. return label, label_weight, bbox_weight, num_pos
  317. def gmm_separation_scheme(self, gmm_assignment: Tensor, scores: Tensor,
  318. pos_inds_gmm: Tensor) -> Tuple[Tensor, Tensor]:
  319. """A general separation scheme for gmm model.
  320. It separates a GMM distribution of candidate samples into three
  321. parts, 0 1 and uncertain areas, and you can implement other
  322. separation schemes by rewriting this function.
  323. Args:
  324. gmm_assignment (Tensor): The prediction of GMM which is of shape
  325. (num_samples,). The 0/1 value indicates the distribution
  326. that each sample comes from.
  327. scores (Tensor): The probability of sample coming from the
  328. fit GMM distribution. The tensor is of shape (num_samples,).
  329. pos_inds_gmm (Tensor): All the indexes of samples which are used
  330. to fit GMM model. The tensor is of shape (num_samples,)
  331. Returns:
  332. tuple[Tensor, Tensor]: The indices of positive and ignored samples.
  333. - pos_inds_temp (Tensor): Indices of positive samples.
  334. - ignore_inds_temp (Tensor): Indices of ignore samples.
  335. """
  336. # The implementation is (c) in Fig.3 in origin paper instead of (b).
  337. # You can refer to issues such as
  338. # https://github.com/kkhoot/PAA/issues/8 and
  339. # https://github.com/kkhoot/PAA/issues/9.
  340. fgs = gmm_assignment == 0
  341. pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
  342. ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
  343. if fgs.nonzero().numel():
  344. _, pos_thr_ind = scores[fgs].topk(1)
  345. pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
  346. ignore_inds_temp = pos_inds_gmm.new_tensor([])
  347. return pos_inds_temp, ignore_inds_temp
  348. def get_targets(self,
  349. anchor_list: List[List[Tensor]],
  350. valid_flag_list: List[List[Tensor]],
  351. batch_gt_instances: InstanceList,
  352. batch_img_metas: List[dict],
  353. batch_gt_instances_ignore: OptInstanceList = None,
  354. unmap_outputs: bool = True) -> tuple:
  355. """Get targets for PAA head.
  356. This method is almost the same as `AnchorHead.get_targets()`. We direct
  357. return the results from _get_targets_single instead map it to levels
  358. by images_to_levels function.
  359. Args:
  360. anchor_list (list[list[Tensor]]): Multi level anchors of each
  361. image. The outer list indicates images, and the inner list
  362. corresponds to feature levels of the image. Each element of
  363. the inner list is a tensor of shape (num_anchors, 4).
  364. valid_flag_list (list[list[Tensor]]): Multi level valid flags of
  365. each image. The outer list indicates images, and the inner list
  366. corresponds to feature levels of the image. Each element of
  367. the inner list is a tensor of shape (num_anchors, )
  368. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  369. gt_instance. It usually includes ``bboxes`` and ``labels``
  370. attributes.
  371. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  372. image size, scaling factor, etc.
  373. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  374. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  375. data that is ignored during training and testing.
  376. Defaults to None.
  377. unmap_outputs (bool): Whether to map outputs back to the original
  378. set of anchors. Defaults to True.
  379. Returns:
  380. tuple: Usually returns a tuple containing learning targets.
  381. - labels (list[Tensor]): Labels of all anchors, each with
  382. shape (num_anchors,).
  383. - label_weights (list[Tensor]): Label weights of all anchor.
  384. each with shape (num_anchors,).
  385. - bbox_targets (list[Tensor]): BBox targets of all anchors.
  386. each with shape (num_anchors, 4).
  387. - bbox_weights (list[Tensor]): BBox weights of all anchors.
  388. each with shape (num_anchors, 4).
  389. - pos_inds (list[Tensor]): Contains all index of positive
  390. sample in all anchor.
  391. - gt_inds (list[Tensor]): Contains all gt_index of positive
  392. sample in all anchor.
  393. """
  394. num_imgs = len(batch_img_metas)
  395. assert len(anchor_list) == len(valid_flag_list) == num_imgs
  396. concat_anchor_list = []
  397. concat_valid_flag_list = []
  398. for i in range(num_imgs):
  399. assert len(anchor_list[i]) == len(valid_flag_list[i])
  400. concat_anchor_list.append(torch.cat(anchor_list[i]))
  401. concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
  402. # compute targets for each image
  403. if batch_gt_instances_ignore is None:
  404. batch_gt_instances_ignore = [None] * num_imgs
  405. results = multi_apply(
  406. self._get_targets_single,
  407. concat_anchor_list,
  408. concat_valid_flag_list,
  409. batch_gt_instances,
  410. batch_img_metas,
  411. batch_gt_instances_ignore,
  412. unmap_outputs=unmap_outputs)
  413. (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
  414. valid_neg_inds, sampling_result) = results
  415. # Due to valid flag of anchors, we have to calculate the real pos_inds
  416. # in origin anchor set.
  417. pos_inds = []
  418. for i, single_labels in enumerate(labels):
  419. pos_mask = (0 <= single_labels) & (
  420. single_labels < self.num_classes)
  421. pos_inds.append(pos_mask.nonzero().view(-1))
  422. gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
  423. return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
  424. gt_inds)
  425. def _get_targets_single(self,
  426. flat_anchors: Tensor,
  427. valid_flags: Tensor,
  428. gt_instances: InstanceData,
  429. img_meta: dict,
  430. gt_instances_ignore: Optional[InstanceData] = None,
  431. unmap_outputs: bool = True) -> tuple:
  432. """Compute regression and classification targets for anchors in a
  433. single image.
  434. This method is same as `AnchorHead._get_targets_single()`.
  435. """
  436. assert unmap_outputs, 'We must map outputs back to the original' \
  437. 'set of anchors in PAAhead'
  438. return super(ATSSHead, self)._get_targets_single(
  439. flat_anchors,
  440. valid_flags,
  441. gt_instances,
  442. img_meta,
  443. gt_instances_ignore,
  444. unmap_outputs=True)
  445. def predict_by_feat(self,
  446. cls_scores: List[Tensor],
  447. bbox_preds: List[Tensor],
  448. score_factors: Optional[List[Tensor]] = None,
  449. batch_img_metas: Optional[List[dict]] = None,
  450. cfg: OptConfigType = None,
  451. rescale: bool = False,
  452. with_nms: bool = True) -> InstanceList:
  453. """Transform a batch of output features extracted from the head into
  454. bbox results.
  455. This method is same as `BaseDenseHead.get_results()`.
  456. """
  457. assert with_nms, 'PAA only supports "with_nms=True" now and it ' \
  458. 'means PAAHead does not support ' \
  459. 'test-time augmentation'
  460. return super().predict_by_feat(
  461. cls_scores=cls_scores,
  462. bbox_preds=bbox_preds,
  463. score_factors=score_factors,
  464. batch_img_metas=batch_img_metas,
  465. cfg=cfg,
  466. rescale=rescale,
  467. with_nms=with_nms)
  468. def _predict_by_feat_single(self,
  469. cls_score_list: List[Tensor],
  470. bbox_pred_list: List[Tensor],
  471. score_factor_list: List[Tensor],
  472. mlvl_priors: List[Tensor],
  473. img_meta: dict,
  474. cfg: OptConfigType = None,
  475. rescale: bool = False,
  476. with_nms: bool = True) -> InstanceData:
  477. """Transform a single image's features extracted from the head into
  478. bbox results.
  479. Args:
  480. cls_score_list (list[Tensor]): Box scores from all scale
  481. levels of a single image, each item has shape
  482. (num_priors * num_classes, H, W).
  483. bbox_pred_list (list[Tensor]): Box energies / deltas from
  484. all scale levels of a single image, each item has shape
  485. (num_priors * 4, H, W).
  486. score_factor_list (list[Tensor]): Score factors from all scale
  487. levels of a single image, each item has shape
  488. (num_priors * 1, H, W).
  489. mlvl_priors (list[Tensor]): Each element in the list is
  490. the priors of a single level in feature pyramid, has shape
  491. (num_priors, 4).
  492. img_meta (dict): Image meta info.
  493. cfg (:obj:`ConfigDict` or dict, optional): Test / postprocessing
  494. configuration, if None, test_cfg would be used.
  495. rescale (bool): If True, return boxes in original image space.
  496. Default: False.
  497. with_nms (bool): If True, do nms before return boxes.
  498. Default: True.
  499. Returns:
  500. :obj:`InstanceData`: Detection results of each image
  501. after the post process.
  502. Each item usually contains following keys.
  503. - scores (Tensor): Classification scores, has a shape
  504. (num_instance, )
  505. - labels (Tensor): Labels of bboxes, has a shape
  506. (num_instances, ).
  507. - bboxes (Tensor): Has a shape (num_instances, 4),
  508. the last dimension 4 arrange as (x1, y1, x2, y2).
  509. """
  510. cfg = self.test_cfg if cfg is None else cfg
  511. img_shape = img_meta['img_shape']
  512. nms_pre = cfg.get('nms_pre', -1)
  513. mlvl_bboxes = []
  514. mlvl_scores = []
  515. mlvl_score_factors = []
  516. for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
  517. enumerate(zip(cls_score_list, bbox_pred_list,
  518. score_factor_list, mlvl_priors)):
  519. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  520. scores = cls_score.permute(1, 2, 0).reshape(
  521. -1, self.cls_out_channels).sigmoid()
  522. bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
  523. score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()
  524. if 0 < nms_pre < scores.shape[0]:
  525. max_scores, _ = (scores *
  526. score_factor[:, None]).sqrt().max(dim=1)
  527. _, topk_inds = max_scores.topk(nms_pre)
  528. priors = priors[topk_inds, :]
  529. bbox_pred = bbox_pred[topk_inds, :]
  530. scores = scores[topk_inds, :]
  531. score_factor = score_factor[topk_inds]
  532. bboxes = self.bbox_coder.decode(
  533. priors, bbox_pred, max_shape=img_shape)
  534. mlvl_bboxes.append(bboxes)
  535. mlvl_scores.append(scores)
  536. mlvl_score_factors.append(score_factor)
  537. results = InstanceData()
  538. results.bboxes = torch.cat(mlvl_bboxes)
  539. results.scores = torch.cat(mlvl_scores)
  540. results.score_factors = torch.cat(mlvl_score_factors)
  541. return self._bbox_post_process(results, cfg, rescale, with_nms,
  542. img_meta)
  543. def _bbox_post_process(self,
  544. results: InstanceData,
  545. cfg: ConfigType,
  546. rescale: bool = False,
  547. with_nms: bool = True,
  548. img_meta: Optional[dict] = None):
  549. """bbox post-processing method.
  550. The boxes would be rescaled to the original image scale and do
  551. the nms operation. Usually with_nms is False is used for aug test.
  552. Args:
  553. results (:obj:`InstaceData`): Detection instance results,
  554. each item has shape (num_bboxes, ).
  555. cfg (:obj:`ConfigDict` or dict): Test / postprocessing
  556. configuration, if None, test_cfg would be used.
  557. rescale (bool): If True, return boxes in original image space.
  558. Default: False.
  559. with_nms (bool): If True, do nms before return boxes.
  560. Default: True.
  561. img_meta (dict, optional): Image meta info. Defaults to None.
  562. Returns:
  563. :obj:`InstanceData`: Detection results of each image
  564. after the post process.
  565. Each item usually contains following keys.
  566. - scores (Tensor): Classification scores, has a shape
  567. (num_instance, )
  568. - labels (Tensor): Labels of bboxes, has a shape
  569. (num_instances, ).
  570. - bboxes (Tensor): Has a shape (num_instances, 4),
  571. the last dimension 4 arrange as (x1, y1, x2, y2).
  572. """
  573. if rescale:
  574. results.bboxes /= results.bboxes.new_tensor(
  575. img_meta['scale_factor']).repeat((1, 2))
  576. # Add a dummy background class to the backend when using sigmoid
  577. # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
  578. # BG cat_id: num_class
  579. padding = results.scores.new_zeros(results.scores.shape[0], 1)
  580. mlvl_scores = torch.cat([results.scores, padding], dim=1)
  581. mlvl_nms_scores = (mlvl_scores * results.score_factors[:, None]).sqrt()
  582. det_bboxes, det_labels = multiclass_nms(
  583. results.bboxes,
  584. mlvl_nms_scores,
  585. cfg.score_thr,
  586. cfg.nms,
  587. cfg.max_per_img,
  588. score_factors=None)
  589. if self.with_score_voting and len(det_bboxes) > 0:
  590. det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels,
  591. results.bboxes,
  592. mlvl_nms_scores,
  593. cfg.score_thr)
  594. nms_results = InstanceData()
  595. nms_results.bboxes = det_bboxes[:, :-1]
  596. nms_results.scores = det_bboxes[:, -1]
  597. nms_results.labels = det_labels
  598. return nms_results
  599. def score_voting(self, det_bboxes: Tensor, det_labels: Tensor,
  600. mlvl_bboxes: Tensor, mlvl_nms_scores: Tensor,
  601. score_thr: float) -> Tuple[Tensor, Tensor]:
  602. """Implementation of score voting method works on each remaining boxes
  603. after NMS procedure.
  604. Args:
  605. det_bboxes (Tensor): Remaining boxes after NMS procedure,
  606. with shape (k, 5), each dimension means
  607. (x1, y1, x2, y2, score).
  608. det_labels (Tensor): The label of remaining boxes, with shape
  609. (k, 1),Labels are 0-based.
  610. mlvl_bboxes (Tensor): All boxes before the NMS procedure,
  611. with shape (num_anchors,4).
  612. mlvl_nms_scores (Tensor): The scores of all boxes which is used
  613. in the NMS procedure, with shape (num_anchors, num_class)
  614. score_thr (float): The score threshold of bboxes.
  615. Returns:
  616. tuple: Usually returns a tuple containing voting results.
  617. - det_bboxes_voted (Tensor): Remaining boxes after
  618. score voting procedure, with shape (k, 5), each
  619. dimension means (x1, y1, x2, y2, score).
  620. - det_labels_voted (Tensor): Label of remaining bboxes
  621. after voting, with shape (num_anchors,).
  622. """
  623. candidate_mask = mlvl_nms_scores > score_thr
  624. candidate_mask_nonzeros = candidate_mask.nonzero(as_tuple=False)
  625. candidate_inds = candidate_mask_nonzeros[:, 0]
  626. candidate_labels = candidate_mask_nonzeros[:, 1]
  627. candidate_bboxes = mlvl_bboxes[candidate_inds]
  628. candidate_scores = mlvl_nms_scores[candidate_mask]
  629. det_bboxes_voted = []
  630. det_labels_voted = []
  631. for cls in range(self.cls_out_channels):
  632. candidate_cls_mask = candidate_labels == cls
  633. if not candidate_cls_mask.any():
  634. continue
  635. candidate_cls_scores = candidate_scores[candidate_cls_mask]
  636. candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
  637. det_cls_mask = det_labels == cls
  638. det_cls_bboxes = det_bboxes[det_cls_mask].view(
  639. -1, det_bboxes.size(-1))
  640. det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
  641. candidate_cls_bboxes)
  642. for det_ind in range(len(det_cls_bboxes)):
  643. single_det_ious = det_candidate_ious[det_ind]
  644. pos_ious_mask = single_det_ious > 0.01
  645. pos_ious = single_det_ious[pos_ious_mask]
  646. pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
  647. pos_scores = candidate_cls_scores[pos_ious_mask]
  648. pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
  649. pos_scores)[:, None]
  650. voted_box = torch.sum(
  651. pis * pos_bboxes, dim=0) / torch.sum(
  652. pis, dim=0)
  653. voted_score = det_cls_bboxes[det_ind][-1:][None, :]
  654. det_bboxes_voted.append(
  655. torch.cat((voted_box[None, :], voted_score), dim=1))
  656. det_labels_voted.append(cls)
  657. det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
  658. det_labels_voted = det_labels.new_tensor(det_labels_voted)
  659. return det_bboxes_voted, det_labels_voted