sabl_retina_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList)
from ..task_modules.samplers import PseudoSampler
from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply,
                     unmap)
from .base_dense_head import BaseDenseHead
from .guided_anchor_head import GuidedAnchorHead


@MODELS.register_module()
class SABLRetinaHead(BaseDenseHead):
    """Side-Aware Boundary Localization (SABL) for RetinaNet.

    The anchor generation, assigning and sampling in SABLRetinaHead
    are the same as GuidedAnchorHead for guided anchoring.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int): Number of Convs for classification and
            regression branches. Defaults to 4.
        feat_channels (int): Number of hidden channels. Defaults to 256.
        approx_anchor_generator (:obj:`ConfigDict` or dict): Config dict for
            approx generator.
        square_anchor_generator (:obj:`ConfigDict` or dict): Config dict for
            square generator.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            ConvModule. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            Norm Layer. Defaults to None.
        bbox_coder (:obj:`ConfigDict` or dict): Config dict for bbox coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to False. It should be ``True``
            when using ``IoULoss``, ``GIoULoss``, or ``DIoULoss`` in the
            bbox head.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            SABLRetinaHead.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            SABLRetinaHead.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox_cls (:obj:`ConfigDict` or dict): Config of classification
            loss for bbox branch.
        loss_bbox_reg (:obj:`ConfigDict` or dict): Config of regression loss
            for bbox branch.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        stacked_convs: int = 4,
        feat_channels: int = 256,
        approx_anchor_generator: ConfigType = dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator: ConfigType = dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        bbox_coder: ConfigType = dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        reg_decoded_bbox: bool = False,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        loss_cls: ConfigType = dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls: ConfigType = dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg: ConfigType = dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5),
        init_cfg: MultiConfig = dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='retina_cls', std=0.01, bias_prob=0.01))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.num_buckets = bbox_coder['num_buckets']
        self.side_num = int(np.ceil(self.num_buckets / 2))
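        # The bucketing coder splits each box axis into ``num_buckets``
        # buckets; each of the two opposite sides scores its nearest
        # ``side_num = ceil(num_buckets / 2)`` buckets, i.e.
        # ceil(14 / 2) = 7 with the default coder.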
        assert (approx_anchor_generator['octave_base_scale'] ==
                square_anchor_generator['scales'][0])
        assert (approx_anchor_generator['strides'] ==
                square_anchor_generator['strides'])

        self.approx_anchor_generator = TASK_UTILS.build(
            approx_anchor_generator)
        self.square_anchor_generator = TASK_UTILS.build(
            square_anchor_generator)
        self.approxs_per_octave = (
            self.approx_anchor_generator.num_base_priors[0])

        # one anchor per location
        self.num_base_priors = self.square_anchor_generator.num_base_priors[0]

        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.reg_decoded_bbox = reg_decoded_bbox

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
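        # Sigmoid-based losses (e.g. FocalLoss) score each class
        # independently, so no background channel is needed; a softmax
        # loss reserves one extra channel for the background class.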
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox_cls = MODELS.build(loss_bbox_cls)
        self.loss_bbox_reg = MODELS.build(loss_bbox_reg)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            # fall back to PseudoSampler when no sampler is configured
            if 'sampler' in self.train_cfg:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'], default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

        self._init_layers()

    def _init_layers(self) -> None:
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
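        # Output heads: a classification map plus two side-aware box maps,
        # each carrying side_num predictions for every one of the 4 sides.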
        self.retina_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.retina_bbox_reg = nn.Conv2d(
            self.feat_channels, self.side_num * 4, 3, padding=1)
        self.retina_bbox_cls = nn.Conv2d(
            self.feat_channels, self.side_num * 4, 3, padding=1)

    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
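        # Classification and regression run through separate conv towers;
        # the regression features feed both box heads: per-side bucket
        # logits (``bbox_cls_pred``) and fine offsets within the buckets
        # (``bbox_reg_pred``).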
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.retina_cls(cls_feat)
        bbox_cls_pred = self.retina_bbox_cls(reg_feat)
        bbox_reg_pred = self.retina_bbox_reg(reg_feat)
        bbox_pred = (bbox_cls_pred, bbox_reg_pred)
        return cls_score, bbox_pred

    def forward(self, feats: List[Tensor]) -> Tuple[List[Tensor]]:
        return multi_apply(self.forward_single, feats)

    def get_anchors(
        self,
        featmap_sizes: List[tuple],
        img_metas: List[dict],
        device: Union[torch.device, str] = 'cuda'
    ) -> List[List[Tensor]]:
        """Get squares according to feature map sizes and guided anchors.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device for returned tensors.

        Returns:
            list[list[Tensor]]: Multi-level squares of each image.
        """
        num_imgs = len(img_metas)

        # since feature map sizes of all images are the same, we only compute
        # squares for one time
        multi_level_squares = self.square_anchor_generator.grid_priors(
            featmap_sizes, device=device)
        squares_list = [multi_level_squares for _ in range(num_imgs)]

        return squares_list

    def get_targets(self,
                    approx_list: List[List[Tensor]],
                    inside_flag_list: List[List[Tensor]],
                    square_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True) -> tuple:
        """Compute bucketing targets.

        Args:
            approx_list (list[list[Tensor]]): Multi level approxs of each
                image.
            inside_flag_list (list[list[Tensor]]): Multi level inside flags of
                each image.
            square_list (list[list[Tensor]]): Multi level squares of each
                image.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.

        Returns:
            tuple: Returns a tuple containing learning targets.

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each level.
            - bbox_cls_targets_list (list[Tensor]): BBox cls targets of \
              each level.
            - bbox_cls_weights_list (list[Tensor]): BBox cls weights of \
              each level.
            - bbox_reg_targets_list (list[Tensor]): BBox reg targets of \
              each level.
            - bbox_reg_weights_list (list[Tensor]): BBox reg weights of \
              each level.
            - avg_factor (int): Average factor that is used to average \
              the loss.
        """
        num_imgs = len(batch_img_metas)
        assert len(approx_list) == len(inside_flag_list) == len(
            square_list) == num_imgs

        # anchor number of multi levels
        num_level_squares = [squares.size(0) for squares in square_list[0]]

        # concat all level anchors and flags to a single tensor
        inside_flag_flat_list = []
        approx_flat_list = []
        square_flat_list = []
        for i in range(num_imgs):
            assert len(square_list[i]) == len(inside_flag_list[i])
            inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
            approx_flat_list.append(torch.cat(approx_list[i]))
            square_flat_list.append(torch.cat(square_list[i]))

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_cls_targets,
         all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = multi_apply(
             self._get_targets_single,
             approx_flat_list,
             inside_flag_flat_list,
             square_flat_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             unmap_outputs=unmap_outputs)

        # sampled anchors of all images
        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])

        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_squares)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_squares)
        bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets,
                                                 num_level_squares)
        bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights,
                                                 num_level_squares)
        bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets,
                                                 num_level_squares)
        bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights,
                                                 num_level_squares)
        return (labels_list, label_weights_list, bbox_cls_targets_list,
                bbox_cls_weights_list, bbox_reg_targets_list,
                bbox_reg_weights_list, avg_factor)

    def _get_targets_single(self,
                            flat_approxs: Tensor,
                            inside_flags: Tensor,
                            flat_squares: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_approxs (Tensor): Flat approxs of a single image,
                shape (approxs_per_octave * n, 4).
            inside_flags (Tensor): Inside flags of a single image,
                shape (n, ).
            flat_squares (Tensor): Flat squares of a single image,
                shape (n, 4).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.

        Returns:
            tuple:

            - labels (Tensor): Labels in a single image.
            - label_weights (Tensor): Label weights in a single image.
            - bbox_cls_targets (Tensor): BBox cls targets in a single image.
            - bbox_cls_weights (Tensor): BBox cls weights in a single image.
            - bbox_reg_targets (Tensor): BBox reg targets in a single image.
            - bbox_reg_weights (Tensor): BBox reg weights in a single image.
            - pos_inds (Tensor): Indices of positive samples in a single \
              image.
            - neg_inds (Tensor): Indices of negative samples in a single \
              image.
            - sampling_result (:obj:`SamplingResult`): Sampling result object.
        """
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image size and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')

        # assign gt and sample anchors
        num_square = flat_squares.size(0)
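        # Every square anchor owns ``approxs_per_octave`` approx anchors
        # (scales_per_octave * len(ratios) = 9 with the default generator);
        # the assigner matches squares to gts via their best approx.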
        approxs = flat_approxs.view(num_square, self.approxs_per_octave, 4)
        approxs = approxs[inside_flags, ...]
        squares = flat_squares[inside_flags, :]

        pred_instances = InstanceData()
        pred_instances.priors = squares
        pred_instances.approxs = approxs

        assign_result = self.assigner.assign(pred_instances, gt_instances,
                                             gt_instances_ignore)
        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_squares = squares.shape[0]
        bbox_cls_targets = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        bbox_cls_weights = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        bbox_reg_targets = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        bbox_reg_weights = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        labels = squares.new_full((num_valid_squares, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = squares.new_zeros(
            num_valid_squares, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            (pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets,
             pos_bbox_cls_weights) = self.bbox_coder.encode(
                 sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)

            bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets
            bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets
            bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights
            bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights

            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_squares.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors,
                                     inside_flags)
            bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors,
                                     inside_flags)
            bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors,
                                     inside_flags)
            bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors,
                                     inside_flags)
        return (labels, label_weights, bbox_cls_targets, bbox_cls_weights,
                bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds,
                sampling_result)

    def loss_by_feat_single(self, cls_score: Tensor,
                            bbox_pred: Tuple[Tensor, Tensor], labels: Tensor,
                            label_weights: Tensor, bbox_cls_targets: Tensor,
                            bbox_cls_weights: Tensor,
                            bbox_reg_targets: Tensor,
                            bbox_reg_weights: Tensor,
                            avg_factor: float) -> Tuple[Tensor]:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            cls_score (Tensor): Box scores for a single scale level,
                has shape (N, num_anchors * num_classes, H, W).
            bbox_pred (tuple[Tensor, Tensor]): Side-aware box predictions
                (bbox_cls_pred, bbox_reg_pred) for a single scale level,
                each with shape (N, side_num * 4, H, W).
            labels (Tensor): Labels in a single level.
            label_weights (Tensor): Label weights in a single level.
            bbox_cls_targets (Tensor): BBox cls targets in a single level.
            bbox_cls_weights (Tensor): BBox cls weights in a single level.
            bbox_reg_targets (Tensor): BBox reg targets in a single level.
            bbox_reg_weights (Tensor): BBox reg weights in a single level.
            avg_factor (float): Average factor that is used to average
                the loss.

        Returns:
            tuple: Loss components.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)

        # regression loss
        bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4)
        bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4)
        bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4)
        bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4)
        (bbox_cls_pred, bbox_reg_pred) = bbox_pred
        bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
            -1, self.side_num * 4)
        bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
            -1, self.side_num * 4)
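        # Each anchor carries side_num bucket logits per side (4 sides), so
        # the bucket-cls loss averages over avg_factor * 4 * side_num
        # entries; only the top ``offset_topk`` buckets per side receive
        # regression targets, hence the smaller reg denominator.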
        loss_bbox_cls = self.loss_bbox_cls(
            bbox_cls_pred,
            bbox_cls_targets.long(),
            bbox_cls_weights,
            avg_factor=avg_factor * 4 * self.side_num)
        loss_bbox_reg = self.loss_bbox_reg(
            bbox_reg_pred,
            bbox_reg_targets,
            bbox_reg_weights,
            avg_factor=avg_factor * 4 * self.bbox_coder.offset_topk)
        return loss_cls, loss_bbox_cls, loss_bbox_reg

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[tuple[Tensor, Tensor]]): Side-aware box
                predictions (bbox_cls_pred, bbox_reg_pred) for each scale
                level, each with shape (N, side_num * 4, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.approx_anchor_generator.num_levels

        device = cls_scores[0].device

        # get sampled approxes
        approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs(
            self, featmap_sizes, batch_img_metas, device=device)

        square_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        cls_reg_targets = self.get_targets(
            approxs_list,
            inside_flag_list,
            square_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (labels_list, label_weights_list, bbox_cls_targets_list,
         bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list,
         avg_factor) = cls_reg_targets

        losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply(
            self.loss_by_feat_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_cls_targets_list,
            bbox_cls_weights_list,
            bbox_reg_targets_list,
            bbox_reg_weights_list,
            avg_factor=avg_factor)
        return dict(
            loss_cls=losses_cls,
            loss_bbox_cls=losses_bbox_cls,
            loss_bbox_reg=losses_bbox_reg)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        batch_img_metas: List[dict],
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are
        usually multiplied by it to obtain the real score used in NMS,
        such as CenterNess in FCOS, IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[tuple[Tensor, Tensor]]): Side-aware box
                predictions (bbox_cls_pred, bbox_reg_pred) for all scale
                levels, each is a 4D-tensor with shape
                (batch_size, side_num * 4, H, W).
            batch_img_metas (list[dict], optional): Batch image meta info.
            cfg (:obj:`ConfigDict`, optional): Test / postprocessing
                configuration. If None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do NMS before returning boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]

        device = cls_scores[0].device
        mlvl_anchors = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_cls_pred_list = [
                bbox_preds[i][0][img_id].detach() for i in range(num_levels)
            ]
            bbox_reg_pred_list = [
                bbox_preds[i][1][img_id].detach() for i in range(num_levels)
            ]
            proposals = self._predict_by_feat_single(
                cls_scores=cls_score_list,
                bbox_cls_preds=bbox_cls_pred_list,
                bbox_reg_preds=bbox_reg_pred_list,
                mlvl_anchors=mlvl_anchors[img_id],
                img_meta=batch_img_metas[img_id],
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(proposals)
        return result_list

    def _predict_by_feat_single(self,
                                cls_scores: List[Tensor],
                                bbox_cls_preds: List[Tensor],
                                bbox_reg_preds: List[Tensor],
                                mlvl_anchors: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform the outputs of a single image into bbox results."""
        cfg = self.test_cfg if cfg is None else cfg
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_confids = []
        mlvl_labels = []
        assert len(cls_scores) == len(bbox_cls_preds) == len(
            bbox_reg_preds) == len(mlvl_anchors)
        for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
                cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_cls_pred.size(
            )[-2:] == bbox_reg_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
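            # A sigmoid head emits per-class scores directly; the softmax
            # head appends a background column, which is dropped below.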
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)[:, :-1]
            bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape(
                -1, self.side_num * 4)
            bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape(
                -1, self.side_num * 4)

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(
                    anchors=anchors,
                    bbox_cls_pred=bbox_cls_pred,
                    bbox_reg_pred=bbox_reg_pred))
            scores, labels, _, filtered_results = results

            anchors = filtered_results['anchors']
            bbox_cls_pred = filtered_results['bbox_cls_pred']
            bbox_reg_pred = filtered_results['bbox_reg_pred']

            bbox_preds = [
                bbox_cls_pred.contiguous(),
                bbox_reg_pred.contiguous()
            ]
            bboxes, confids = self.bbox_coder.decode(
                anchors.contiguous(),
                bbox_preds,
                max_shape=img_meta['img_shape'])

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_confids.append(confids)
            mlvl_labels.append(labels)

        results = InstanceData()
        results.bboxes = torch.cat(mlvl_bboxes)
        results.scores = torch.cat(mlvl_scores)
        results.score_factors = torch.cat(mlvl_confids)
        results.labels = torch.cat(mlvl_labels)
        return self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)
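

if __name__ == '__main__':
    # Minimal smoke test -- an illustrative sketch, not part of the original
    # module. It assumes mmdet 3.x is installed and must be run as a module
    # (python -m mmdet.models.dense_heads.sabl_retina_head) so the relative
    # imports above resolve. register_all_modules() populates the MODELS and
    # TASK_UTILS registries used by the default configs.
    from mmdet.utils import register_all_modules
    register_all_modules(init_default_scope=True)

    head = SABLRetinaHead(num_classes=80, in_channels=256)
    # five pyramid levels matching the default strides [8, 16, 32, 64, 128]
    feats = [torch.rand(2, 256, 64 // 2**i, 64 // 2**i) for i in range(5)]
    cls_scores, bbox_preds = head(feats)

    # Per level: cls_score has shape (2, num_classes, H, W); bbox_pred is a
    # tuple (bbox_cls_pred, bbox_reg_pred), each (2, side_num * 4, H, W),
    # i.e. (2, 28, H, W) with the default num_buckets=14 (side_num = 7).
    assert cls_scores[0].shape == (2, 80, 64, 64)
    assert bbox_preds[0][0].shape == (2, 28, 64, 64)
    assert bbox_preds[0][1].shape == (2, 28, 64, 64)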