bbox_head.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmengine.config import ConfigDict
  7. from mmengine.model import BaseModule
  8. from mmengine.structures import InstanceData
  9. from torch import Tensor
  10. from torch.nn.modules.utils import _pair
  11. from mmdet.models.layers import multiclass_nms
  12. from mmdet.models.losses import accuracy
  13. from mmdet.models.task_modules.samplers import SamplingResult
  14. from mmdet.models.utils import empty_instances, multi_apply
  15. from mmdet.registry import MODELS, TASK_UTILS
  16. from mmdet.structures.bbox import get_box_tensor, scale_boxes
  17. from mmdet.utils import ConfigType, InstanceList, OptMultiConfig
  18. @MODELS.register_module()
  19. class BBoxHead(BaseModule):
  20. """Simplest RoI head, with only two fc layers for classification and
  21. regression respectively."""
  def __init__(
          self,
          with_avg_pool: bool = False,
          with_cls: bool = True,
          with_reg: bool = True,
          roi_feat_size: int = 7,
          in_channels: int = 256,
          num_classes: int = 80,
          bbox_coder: ConfigType = dict(
              type='DeltaXYWHBBoxCoder',
              clip_border=True,
              target_means=[0., 0., 0., 0.],
              target_stds=[0.1, 0.1, 0.2, 0.2]),
          predict_box_type: str = 'hbox',
          reg_class_agnostic: bool = False,
          reg_decoded_bbox: bool = False,
          reg_predictor_cfg: ConfigType = dict(type='Linear'),
          cls_predictor_cfg: ConfigType = dict(type='Linear'),
          loss_cls: ConfigType = dict(
              type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
          loss_bbox: ConfigType = dict(
              type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
          init_cfg: OptMultiConfig = None) -> None:
      """Build the bbox coder, the losses and the cls / reg predictors.

      Args:
          with_avg_pool (bool): Whether RoI features are average-pooled
              over ``roi_feat_size`` before the predictors.
              Defaults to False.
          with_cls (bool): Whether to build the classification predictor
              ``fc_cls``. Defaults to True.
          with_reg (bool): Whether to build the regression predictor
              ``fc_reg``. Defaults to True.
          roi_feat_size (int): Spatial size of the RoI features, expanded
              to a pair with ``_pair``. Defaults to 7.
          in_channels (int): Number of channels of the RoI features.
              Defaults to 256.
          num_classes (int): Number of foreground classes. One extra
              classification channel is added for the background unless
              the classification loss defines its own channel count.
              Defaults to 80.
          bbox_coder (ConfigType): Config built with ``TASK_UTILS.build``.
          predict_box_type (str): Box type forwarded to
              ``empty_instances`` for empty predictions.
              Defaults to 'hbox'.
          reg_class_agnostic (bool): If True a single box is regressed per
              RoI, otherwise one box per class. Defaults to False.
          reg_decoded_bbox (bool): If True the regression loss is applied
              to decoded boxes instead of encoded deltas.
              Defaults to False.
          reg_predictor_cfg (ConfigType): Config of the regression
              predictor, built with ``MODELS.build``.
          cls_predictor_cfg (ConfigType): Config of the classification
              predictor, built with ``MODELS.build``.
          loss_cls (ConfigType): Config of the classification loss.
          loss_bbox (ConfigType): Config of the regression loss.
          init_cfg (OptMultiConfig): Initialization config. When None,
              ``fc_cls`` and ``fc_reg`` get a ``Normal`` init with std
              0.01 and 0.001 respectively.
      """
      super().__init__(init_cfg=init_cfg)
      assert with_cls or with_reg

      # Plain configuration values.
      self.with_avg_pool = with_avg_pool
      self.with_cls = with_cls
      self.with_reg = with_reg
      self.roi_feat_size = _pair(roi_feat_size)
      self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
      self.in_channels = in_channels
      self.num_classes = num_classes
      self.predict_box_type = predict_box_type
      self.reg_class_agnostic = reg_class_agnostic
      self.reg_decoded_bbox = reg_decoded_bbox
      self.reg_predictor_cfg = reg_predictor_cfg
      self.cls_predictor_cfg = cls_predictor_cfg

      # Components built from configs.
      self.bbox_coder = TASK_UTILS.build(bbox_coder)
      self.loss_cls = MODELS.build(loss_cls)
      self.loss_bbox = MODELS.build(loss_bbox)

      # Pooled features keep `in_channels` values per RoI; un-pooled
      # features carry one value per channel and spatial location.
      if self.with_avg_pool:
          self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
          predictor_in_features = self.in_channels
      else:
          predictor_in_features = self.in_channels * self.roi_feat_area

      if self.with_cls:
          # need to add background class
          if self.custom_cls_channels:
              num_cls_outputs = self.loss_cls.get_cls_channels(
                  self.num_classes)
          else:
              num_cls_outputs = num_classes + 1
          cls_cfg = self.cls_predictor_cfg.copy()
          cls_cfg.update(
              in_features=predictor_in_features,
              out_features=num_cls_outputs)
          self.fc_cls = MODELS.build(cls_cfg)

      if self.with_reg:
          encode_size = self.bbox_coder.encode_size
          if reg_class_agnostic:
              num_reg_outputs = encode_size
          else:
              num_reg_outputs = encode_size * num_classes
          reg_cfg = self.reg_predictor_cfg.copy()
          if isinstance(reg_cfg, (dict, ConfigDict)):
              reg_cfg.update(
                  in_features=predictor_in_features,
                  out_features=num_reg_outputs)
          self.fc_reg = MODELS.build(reg_cfg)

      self.debug_imgs = None

      if init_cfg is None:
          default_init_cfg = []
          if self.with_cls:
              default_init_cfg.append(
                  dict(
                      type='Normal',
                      std=0.01,
                      override=dict(name='fc_cls')))
          if self.with_reg:
              default_init_cfg.append(
                  dict(
                      type='Normal',
                      std=0.001,
                      override=dict(name='fc_reg')))
          self.init_cfg = default_init_cfg
  # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
  @property
  def custom_cls_channels(self) -> bool:
      """Whether ``loss_cls`` declares ``custom_cls_channels``.

      Falls back to False when the loss has no such attribute.
      """
      cls_loss = self.loss_cls
      return getattr(cls_loss, 'custom_cls_channels', False)
  # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
  @property
  def custom_activation(self) -> bool:
      """Whether ``loss_cls`` declares ``custom_activation``.

      Falls back to False when the loss has no such attribute.
      """
      cls_loss = self.loss_cls
      return getattr(cls_loss, 'custom_activation', False)
  # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
  @property
  def custom_accuracy(self) -> bool:
      """Whether ``loss_cls`` declares ``custom_accuracy``.

      Falls back to False when the loss has no such attribute.
      """
      cls_loss = self.loss_cls
      return getattr(cls_loss, 'custom_accuracy', False)
  def forward(self, x: Tensor) -> tuple:
      """Predict classification scores and box deltas from RoI features.

      Args:
          x (Tensor): RoI features. The first dimension indexes the RoIs;
              any further dimensions are pooled (``with_avg_pool=True``)
              or flattened before the predictors.
              NOTE(review): presumably (num_rois, C, H, W) coming from
              the RoI extractor - confirm against the caller.

      Returns:
          tuple: ``(cls_score, bbox_pred)``.

              - cls_score (Tensor | None): Output of ``fc_cls``, or None
                when ``with_cls`` is False.
              - bbox_pred (Tensor | None): Output of ``fc_reg``, or None
                when ``with_reg`` is False.
      """
      if self.with_avg_pool:
          if x.numel() > 0:
              x = self.avg_pool(x)
              x = x.view(x.size(0), -1)
          else:
              # avg_pool does not support empty tensors, so average over
              # the two trailing (spatial) dimensions instead.
              x = torch.mean(x, dim=(-1, -2))
      elif x.dim() > 2:
          # The predictors are built with
          # `in_features = in_channels * roi_feat_area`, so un-pooled
          # features have to be flattened to (num_rois, -1) first.
          # Inputs that are already 2D are passed through untouched.
          x = x.flatten(1)
      cls_score = self.fc_cls(x) if self.with_cls else None
      bbox_pred = self.fc_reg(x) if self.with_reg else None
      return cls_score, bbox_pred
  def _get_targets_single(self, pos_priors: Tensor, neg_priors: Tensor,
                          pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
                          cfg: ConfigDict) -> tuple:
      """Build the training targets of one image from its sampled priors.

      Positive samples occupy the first ``num_pos`` rows of every output,
      negative samples the remaining ``num_neg`` rows.

      Args:
          pos_priors (Tensor): Positive boxes, has shape (num_pos, 4),
              the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
          neg_priors (Tensor): Negative boxes, has shape (num_neg, 4),
              the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
          pos_gt_bboxes (Tensor): Gt boxes assigned to the positive
              samples, has shape (num_pos, 4), the last dimension 4
              represents [tl_x, tl_y, br_x, br_y].
          pos_gt_labels (Tensor): Gt labels assigned to the positive
              samples, has shape (num_pos, ).
          cfg (obj:`ConfigDict`): `train_cfg` of R-CNN. Only
              ``cfg.pos_weight`` is read here.

      Returns:
          Tuple[Tensor]: Targets of a single image:

              - labels (Tensor): Labels of all samples, has shape
                (num_proposals,). Negatives are set to ``num_classes``.
              - label_weights (Tensor): Label weights of all samples,
                has shape (num_proposals,).
              - bbox_targets (Tensor): Regression targets of all samples,
                has shape (num_proposals, reg_dim). Rows of negatives
                are zero.
              - bbox_weights (Tensor): Regression weights of all samples,
                has shape (num_proposals, reg_dim). Rows of negatives
                are zero.
      """
      num_pos, num_neg = pos_priors.size(0), neg_priors.size(0)
      num_samples = num_pos + num_neg

      # Decoded targets keep the width of the gt boxes, encoded targets
      # use the width defined by the coder.
      if self.reg_decoded_bbox:
          reg_dim = pos_gt_bboxes.size(-1)
      else:
          reg_dim = self.bbox_coder.encode_size

      # BG cat_id = num_classes, FG cat_id = [0, num_classes-1], so every
      # label starts out as background.
      labels = pos_priors.new_full((num_samples, ),
                                   self.num_classes,
                                   dtype=torch.long)
      label_weights = pos_priors.new_zeros(num_samples)
      bbox_targets = pos_priors.new_zeros(num_samples, reg_dim)
      bbox_weights = pos_priors.new_zeros(num_samples, reg_dim)

      if num_pos > 0:
          labels[:num_pos] = pos_gt_labels

          # A non-positive `pos_weight` means "use the default of 1.0".
          pos_weight = cfg.pos_weight
          if pos_weight <= 0:
              pos_weight = 1.0
          label_weights[:num_pos] = pos_weight

          if self.reg_decoded_bbox:
              # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
              # is applied directly on the decoded bounding boxes, both
              # the predicted boxes and regression targets should be
              # with absolute coordinate format.
              pos_bbox_targets = get_box_tensor(pos_gt_bboxes)
          else:
              pos_bbox_targets = self.bbox_coder.encode(
                  pos_priors, pos_gt_bboxes)
          bbox_targets[:num_pos, :] = pos_bbox_targets
          bbox_weights[:num_pos, :] = 1

      if num_neg > 0:
          label_weights[-num_neg:] = 1.0

      return labels, label_weights, bbox_targets, bbox_weights
  def get_targets(self,
                  sampling_results: List[SamplingResult],
                  rcnn_train_cfg: ConfigDict,
                  concat: bool = True) -> tuple:
      """Calculate the training targets for all samples in a batch.

      ``_get_targets_single`` is applied to the sampling result of every
      image; the per-image outputs are optionally concatenated.

      Args:
          sampling_results (List[obj:SamplingResult]): Assign results of
              all images in a batch after sampling.
          rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
          concat (bool): Whether to concatenate the results of all
              the images in a single batch.

      Returns:
          Tuple[Tensor]: Targets for the proposals of the whole batch:

              - labels (list[Tensor],Tensor): Gt_labels for all
                proposals in a batch, each tensor in list has
                shape (num_proposals,) when `concat=False`, otherwise
                just a single tensor has shape (num_all_proposals,).
              - label_weights (list[Tensor],Tensor): Labels_weights for
                all proposals in a batch, each tensor in list has
                shape (num_proposals,) when `concat=False`, otherwise
                just a single tensor has shape (num_all_proposals,).
              - bbox_targets (list[Tensor],Tensor): Regression target
                for all proposals in a batch, each tensor in list
                has shape (num_proposals, 4) when `concat=False`,
                otherwise just a single tensor has shape
                (num_all_proposals, 4), the last dimension 4 represents
                [tl_x, tl_y, br_x, br_y].
              - bbox_weights (list[Tensor],Tensor): Regression weights
                for all proposals in a batch, each tensor in list has
                shape (num_proposals, 4) when `concat=False`, otherwise
                just a single tensor has shape (num_all_proposals, 4).
      """
      # One list per argument of `_get_targets_single`, one entry per
      # image.
      all_pos_priors = [res.pos_priors for res in sampling_results]
      all_neg_priors = [res.neg_priors for res in sampling_results]
      all_pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
      all_pos_gt_labels = [res.pos_gt_labels for res in sampling_results]

      per_image_targets = multi_apply(
          self._get_targets_single,
          all_pos_priors,
          all_neg_priors,
          all_pos_gt_bboxes,
          all_pos_gt_labels,
          cfg=rcnn_train_cfg)
      labels, label_weights, bbox_targets, bbox_weights = per_image_targets

      if not concat:
          return labels, label_weights, bbox_targets, bbox_weights

      return (torch.cat(labels, 0), torch.cat(label_weights, 0),
              torch.cat(bbox_targets, 0), torch.cat(bbox_weights, 0))
  def loss_and_target(self,
                      cls_score: Tensor,
                      bbox_pred: Tensor,
                      rois: Tensor,
                      sampling_results: List[SamplingResult],
                      rcnn_train_cfg: ConfigDict,
                      concat: bool = True,
                      reduction_override: Optional[str] = None) -> dict:
      """Compute the targets, then the losses of the bbox head.

      Args:
          cls_score (Tensor): Classification prediction
              results of all class, has shape
              (batch_size * num_proposals_single_image, num_classes)
          bbox_pred (Tensor): Regression prediction results,
              has shape
              (batch_size * num_proposals_single_image, 4), the last
              dimension 4 represents [tl_x, tl_y, br_x, br_y].
          rois (Tensor): RoIs with the shape
              (batch_size * num_proposals_single_image, 5) where the first
              column indicates batch id of each RoI.
          sampling_results (List[obj:SamplingResult]): Assign results of
              all images in a batch after sampling.
          rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
          concat (bool): Whether to concatenate the results of all
              the images in a single batch. Defaults to True.
          reduction_override (str, optional): The reduction
              method used to override the original reduction
              method of the loss. Options are "none",
              "mean" and "sum". Defaults to None.

      Returns:
          dict: Has two keys. ``loss_bbox`` holds the dictionary returned
          by :meth:`loss`; ``bbox_targets`` holds the tuple returned by
          :meth:`get_targets` (only used for cascade rcnn).
      """
      targets = self.get_targets(
          sampling_results, rcnn_train_cfg, concat=concat)
      loss_dict = self.loss(
          cls_score,
          bbox_pred,
          rois,
          *targets,
          reduction_override=reduction_override)
      # The targets are returned alongside the losses because cascade
      # rcnn reuses them.
      return {'loss_bbox': loss_dict, 'bbox_targets': targets}
  def loss(self,
           cls_score: Tensor,
           bbox_pred: Tensor,
           rois: Tensor,
           labels: Tensor,
           label_weights: Tensor,
           bbox_targets: Tensor,
           bbox_weights: Tensor,
           reduction_override: Optional[str] = None) -> dict:
      """Calculate the loss based on the network predictions and targets.

      Args:
          cls_score (Tensor): Classification prediction
              results of all class, has shape
              (batch_size * num_proposals_single_image, num_classes)
          bbox_pred (Tensor): Regression prediction results,
              has shape
              (batch_size * num_proposals_single_image, 4), the last
              dimension 4 represents [tl_x, tl_y, br_x, br_y].
          rois (Tensor): RoIs with the shape
              (batch_size * num_proposals_single_image, 5) where the first
              column indicates batch id of each RoI.
          labels (Tensor): Gt_labels for all proposals in a batch, has
              shape (batch_size * num_proposals_single_image, ).
          label_weights (Tensor): Labels_weights for all proposals in a
              batch, has shape (batch_size * num_proposals_single_image, ).
          bbox_targets (Tensor): Regression target for all proposals in a
              batch, has shape (batch_size * num_proposals_single_image, 4),
              the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
          bbox_weights (Tensor): Regression weights for all proposals in a
              batch, has shape (batch_size * num_proposals_single_image, 4).
          reduction_override (str, optional): The reduction
              method used to override the original reduction
              method of the loss. Options are "none",
              "mean" and "sum". Defaults to None.

      Returns:
          dict: A dictionary of loss. May contain ``loss_cls`` (or the
          keys of a dict-valued classification loss), ``acc`` (or the keys
          returned by ``loss_cls.get_accuracy``) and ``loss_bbox``.
      """
      losses = dict()

      if cls_score is not None:
          # Normalise by the number of samples with a positive weight,
          # but never by less than one.
          num_weighted = torch.sum(label_weights > 0).float().item()
          avg_factor = max(num_weighted, 1.)
          if cls_score.numel() > 0:
              cls_loss = self.loss_cls(
                  cls_score,
                  labels,
                  label_weights,
                  avg_factor=avg_factor,
                  reduction_override=reduction_override)
              if isinstance(cls_loss, dict):
                  losses.update(cls_loss)
              else:
                  losses['loss_cls'] = cls_loss
              # NOTE(review): the accuracy function is selected by
              # `custom_activation`, not `custom_accuracy` - confirm
              # this is intended.
              if self.custom_activation:
                  losses.update(
                      self.loss_cls.get_accuracy(cls_score, labels))
              else:
                  losses['acc'] = accuracy(cls_score, labels)

      if bbox_pred is None:
          return losses

      # 0~self.num_classes-1 are FG, self.num_classes is BG
      fg_mask = (labels >= 0) & (labels < self.num_classes)

      if not fg_mask.any():
          # No foreground sample: the sum over an empty selection keeps
          # `loss_bbox` present in the output.
          losses['loss_bbox'] = bbox_pred[fg_mask].sum()
          return losses

      # do not perform bounding box regression for BG anymore.
      if self.reg_decoded_bbox:
          # When the regression loss (e.g. `IouLoss`,
          # `GIouLoss`, `DIouLoss`) is applied directly on
          # the decoded bounding boxes, it decodes the
          # already encoded coordinates to absolute format.
          bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
          bbox_pred = get_box_tensor(bbox_pred)

      num_samples = bbox_pred.size(0)
      if self.reg_class_agnostic:
          pos_bbox_pred = bbox_pred.view(num_samples, -1)[fg_mask]
      else:
          # Pick, for every foreground sample, the box predicted for its
          # own class.
          per_class_pred = bbox_pred.view(num_samples, self.num_classes,
                                          -1)
          pos_bbox_pred = per_class_pred[fg_mask, labels[fg_mask]]

      losses['loss_bbox'] = self.loss_bbox(
          pos_bbox_pred,
          bbox_targets[fg_mask],
          bbox_weights[fg_mask],
          avg_factor=bbox_targets.size(0),
          reduction_override=reduction_override)
      return losses
  def predict_by_feat(self,
                      rois: Tuple[Tensor],
                      cls_scores: Tuple[Tensor],
                      bbox_preds: Tuple[Tensor],
                      batch_img_metas: List[dict],
                      rcnn_test_cfg: Optional[ConfigDict] = None,
                      rescale: bool = False) -> InstanceList:
      """Turn the per-image head outputs of a batch into bbox results.

      Args:
          rois (tuple[Tensor]): Tuple of boxes to be transformed.
              Each has shape (num_boxes, 5). last dimension 5 arrange as
              (batch_index, x1, y1, x2, y2).
          cls_scores (tuple[Tensor]): Tuple of box scores, each has shape
              (num_boxes, num_classes + 1).
          bbox_preds (tuple[Tensor]): Tuple of box energies / deltas, each
              has shape (num_boxes, num_classes * 4).
          batch_img_metas (list[dict]): List of image information.
          rcnn_test_cfg (obj:`ConfigDict`, optional): `test_cfg` of R-CNN.
              Defaults to None.
          rescale (bool): If True, return boxes in original image space.
              Defaults to False.

      Returns:
          list[:obj:`InstanceData`]: One result per entry of
          ``batch_img_metas``, each produced by
          :meth:`_predict_by_feat_single`.
          Each item usually contains following keys.

              - scores (Tensor): Classification scores, has a shape
                (num_instance, )
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
              - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
      """
      assert len(cls_scores) == len(bbox_preds)
      # The image metas drive the iteration; the other sequences are
      # indexed by image id.
      return [
          self._predict_by_feat_single(
              roi=rois[img_id],
              cls_score=cls_scores[img_id],
              bbox_pred=bbox_preds[img_id],
              img_meta=img_meta,
              rescale=rescale,
              rcnn_test_cfg=rcnn_test_cfg)
          for img_id, img_meta in enumerate(batch_img_metas)
      ]
  def _predict_by_feat_single(
          self,
          roi: Tensor,
          cls_score: Tensor,
          bbox_pred: Tensor,
          img_meta: dict,
          rescale: bool = False,
          rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData:
      """Transform a single image's features extracted from the head into
      bbox results.

      Args:
          roi (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
              last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
          cls_score (Tensor): Box scores, has shape
              (num_boxes, num_classes + 1).
          bbox_pred (Tensor): Box energies / deltas.
              has shape (num_boxes, num_classes * 4). May be None, in
              which case the RoIs themselves (clamped to the image) are
              used as boxes.
          img_meta (dict): image information. ``img_shape`` is always
              read; ``scale_factor`` is required when ``rescale`` is True.
          rescale (bool): If True, return boxes in original image space.
              Defaults to False.
          rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
              When None, raw boxes and scores are returned without NMS.
              Defaults to None

      Returns:
          :obj:`InstanceData`: Detection results of each image.
          Each item usually contains following keys.

              - scores (Tensor): Classification scores, has a shape
                (num_instance, )
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ). Only set when NMS is applied.
              - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
      """
      results = InstanceData()
      if roi.shape[0] == 0:
          return empty_instances([img_meta],
                                 roi.device,
                                 task_type='bbox',
                                 instance_results=[results],
                                 box_type=self.predict_box_type,
                                 use_box_type=False,
                                 num_classes=self.num_classes,
                                 score_per_cls=rcnn_test_cfg is None)[0]

      # some loss (Seesaw loss..) may have custom activation
      # NOTE(review): the activation is selected by
      # `custom_cls_channels`, not `custom_activation` - confirm this is
      # intended.
      if self.custom_cls_channels:
          scores = self.loss_cls.get_activation(cls_score)
      else:
          scores = F.softmax(
              cls_score, dim=-1) if cls_score is not None else None

      img_shape = img_meta['img_shape']
      num_rois = roi.size(0)
      # bbox_pred would be None in some detector when with_reg is False,
      # e.g. Grid R-CNN.
      if bbox_pred is not None:
          num_classes = 1 if self.reg_class_agnostic else self.num_classes
          # One RoI row per predicted box so that priors and deltas line
          # up for decoding.
          roi = roi.repeat_interleave(num_classes, dim=0)
          bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size)
          bboxes = self.bbox_coder.decode(
              roi[..., 1:], bbox_pred, max_shape=img_shape)
      else:
          bboxes = roi[:, 1:].clone()
          if img_shape is not None and bboxes.size(-1) == 4:
              # Indexing with a list (`bboxes[:, [0, 2]]`) returns a
              # copy, so an in-place clamp on it never reached `bboxes`.
              # Strided slices are views and are clamped in place.
              bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
              bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])

      if rescale and bboxes.size(0) > 0:
          assert img_meta.get('scale_factor') is not None
          scale_factor = [1 / s for s in img_meta['scale_factor']]
          bboxes = scale_boxes(bboxes, scale_factor)

      # Get the inside tensor when `bboxes` is a box type
      bboxes = get_box_tensor(bboxes)
      box_dim = bboxes.size(-1)
      bboxes = bboxes.view(num_rois, -1)

      if rcnn_test_cfg is None:
          # This means that it is aug test.
          # It needs to return the raw results without nms.
          results.bboxes = bboxes
          results.scores = scores
      else:
          det_bboxes, det_labels = multiclass_nms(
              bboxes,
              scores,
              rcnn_test_cfg.score_thr,
              rcnn_test_cfg.nms,
              rcnn_test_cfg.max_per_img,
              box_dim=box_dim)
          # `multiclass_nms` appends the score as the last column.
          results.bboxes = det_bboxes[:, :-1]
          results.scores = det_bboxes[:, -1]
          results.labels = det_labels
      return results
  def refine_bboxes(self, sampling_results: Union[List[SamplingResult],
                                                  InstanceList],
                    bbox_results: dict,
                    batch_img_metas: List[dict]) -> InstanceList:
      """Refine bboxes during training.

      Background-labelled RoIs are regressed with their highest-scoring
      foreground class; positive samples that are ground-truth boxes are
      dropped from the refined result.

      Args:
          sampling_results (List[:obj:`SamplingResult`] or
              List[:obj:`InstanceData`]): Sampling results.
              :obj:`SamplingResult` is the real sampling results
              calculate from bbox_head, while :obj:`InstanceData` is
              fake sampling results, e.g., in Sparse R-CNN or QueryInst, etc.
              Only ``pos_is_gt`` is read from each entry.
          bbox_results (dict): Usually is a dictionary with keys:

              - `cls_score` (Tensor): Classification scores.
              - `bbox_pred` (Tensor): Box energies / deltas.
              - `rois` (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.
              - `bbox_targets` (tuple): Ground truth for proposals in a
                single image. Containing the following list of Tensors:
                (labels, label_weights, bbox_targets, bbox_weights)
          batch_img_metas (List[dict]): List of image information.

      Returns:
          list[:obj:`InstanceData`]: Refined bboxes of each image, or None
          when ``cls_score`` is empty.

      Raises:
          ValueError: If the last dimension of ``cls_score`` is neither
              ``num_classes`` nor ``num_classes + 1``.

      Example:
          >>> # xdoctest: +REQUIRES(module:kwarray)
          >>> import numpy as np
          >>> from mmdet.models.task_modules.samplers.sampling_result import (
          ...     random_boxes)
          >>> from mmdet.models.task_modules.samplers import SamplingResult
          >>> self = BBoxHead(reg_class_agnostic=True)
          >>> n_roi = 2
          >>> n_img = 4
          >>> scale = 512
          >>> rng = np.random.RandomState(0)
          >>> batch_img_metas = [{'img_shape': (scale, scale)}
          ...                    for _ in range(n_img)]
          >>> sampling_results = [SamplingResult.random(rng=10)
          ...                     for _ in range(n_img)]
          >>> # Create rois in the expected format
          >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
          >>> img_ids = torch.randint(0, n_img, (n_roi,))
          >>> img_ids = img_ids.float()
          >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
          >>> # Create other args
          >>> labels = torch.randint(0, 81, (scale,)).long()
          >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
          >>> cls_score = torch.randn((scale, 81))
          >>> # For each image, pretend random positive boxes are gts
          >>> bbox_targets = (labels, None, None, None)
          >>> bbox_results = dict(rois=rois, bbox_pred=bbox_preds,
          ...                     cls_score=cls_score,
          ...                     bbox_targets=bbox_targets)
          >>> bboxes_list = self.refine_bboxes(sampling_results,
          ...                                  bbox_results,
          ...                                  batch_img_metas)
          >>> print(bboxes_list)
      """
      pos_is_gts = [res.pos_is_gt for res in sampling_results]
      # bbox_targets is a tuple
      labels = bbox_results['bbox_targets'][0]
      cls_scores = bbox_results['cls_score']
      rois = bbox_results['rois']
      bbox_preds = bbox_results['bbox_pred']

      if self.custom_activation:
          # TODO: Create a SeesawBBoxHead to simplify the logic in BBoxHead
          cls_scores = self.loss_cls.get_activation(cls_scores)
      if cls_scores.numel() == 0:
          return None
      if cls_scores.shape[-1] == self.num_classes + 1:
          # remove background class
          cls_scores = cls_scores[:, :-1]
      elif cls_scores.shape[-1] != self.num_classes:
          raise ValueError('The last dim of `cls_scores` should equal to '
                           '`num_classes` or `num_classes + 1`,'
                           f'but got {cls_scores.shape[-1]}.')

      # RoIs labelled as background are regressed with their best
      # foreground class instead.
      labels = torch.where(labels == self.num_classes, cls_scores.argmax(1),
                           labels)

      img_ids = rois[:, 0].long().unique(sorted=True)
      assert img_ids.numel() <= len(batch_img_metas)

      results_list = []
      for i, img_meta_ in enumerate(batch_img_metas):
          inds = torch.nonzero(
              rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
          num_rois = inds.numel()

          bboxes_ = rois[inds, 1:]
          label_ = labels[inds]
          bbox_pred_ = bbox_preds[inds]
          pos_is_gts_ = pos_is_gts[i]

          bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                         img_meta_)

          # filter gt bboxes: the first `len(pos_is_gts_)` RoIs are the
          # positive samples; those flagged as gt are dropped, all other
          # RoIs are kept. The mask is built with a logical not rather
          # than `1 - flags`, which raises for bool tensors.
          keep_mask = pos_is_gts_.new_ones(num_rois, dtype=torch.bool)
          keep_mask[:len(pos_is_gts_)] = ~pos_is_gts_.bool()
          results = InstanceData(bboxes=bboxes[keep_mask])
          results_list.append(results)

      return results_list
  def regress_by_class(self, priors: Tensor, label: Tensor,
                       bbox_pred: Tensor, img_meta: dict) -> Tensor:
      """Regress the bbox for the predicted class. Used in Cascade R-CNN.

      Args:
          priors (Tensor): Priors from `rpn_head` or last stage
              `bbox_head`, has shape (num_proposals, 4).
          label (Tensor): Only used when `self.reg_class_agnostic`
              is False, has shape (num_proposals, ).
          bbox_pred (Tensor): Regression prediction of
              current stage `bbox_head`. When `self.reg_class_agnostic`
              is False, it has shape (n, num_classes * 4), otherwise
              it has shape (n, 4).
          img_meta (dict): Image meta info. ``img_shape`` is passed to
              the coder as ``max_shape``.

      Returns:
          Tensor: Regressed bboxes, the same shape as input rois.
      """
      encode_size = self.bbox_coder.encode_size

      if not self.reg_class_agnostic:
          # The deltas of class `c` occupy the columns
          # [c * encode_size, (c + 1) * encode_size) of `bbox_pred`.
          first_col = label * encode_size
          gather_inds = torch.stack(
              [first_col + offset for offset in range(encode_size)], 1)
          bbox_pred = torch.gather(bbox_pred, 1, gather_inds)

      assert bbox_pred.size()[1] == encode_size

      return self.bbox_coder.decode(
          priors, bbox_pred, max_shape=img_meta['img_shape'])
  643. return regressed_bboxes