gfl_head.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Sequence, Tuple
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmcv.cnn import ConvModule, Scale
  7. from mmengine.config import ConfigDict
  8. from mmengine.structures import InstanceData
  9. from torch import Tensor
  10. from mmdet.registry import MODELS, TASK_UTILS
  11. from mmdet.structures.bbox import bbox_overlaps
  12. from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
  13. OptInstanceList, reduce_mean)
  14. from ..task_modules.prior_generators import anchor_inside_flags
  15. from ..task_modules.samplers import PseudoSampler
  16. from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply,
  17. unmap)
  18. from .anchor_head import AnchorHead
  19. class Integral(nn.Module):
  20. """A fixed layer for calculating integral result from distribution.
  21. This layer calculates the target location by :math: ``sum{P(y_i) * y_i}``,
  22. P(y_i) denotes the softmax vector that represents the discrete distribution
  23. y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max}
  24. Args:
  25. reg_max (int): The maximal value of the discrete set. Defaults to 16.
  26. You may want to reset it according to your new dataset or related
  27. settings.
  28. """
  29. def __init__(self, reg_max: int = 16) -> None:
  30. super().__init__()
  31. self.reg_max = reg_max
  32. self.register_buffer('project',
  33. torch.linspace(0, self.reg_max, self.reg_max + 1))
  34. def forward(self, x: Tensor) -> Tensor:
  35. """Forward feature from the regression head to get integral result of
  36. bounding box location.
  37. Args:
  38. x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
  39. n is self.reg_max.
  40. Returns:
  41. x (Tensor): Integral result of box locations, i.e., distance
  42. offsets from the box center in four directions, shape (N, 4).
  43. """
  44. x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
  45. x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
  46. return x
  47. @MODELS.register_module()
  48. class GFLHead(AnchorHead):
  49. """Generalized Focal Loss: Learning Qualified and Distributed Bounding
  50. Boxes for Dense Object Detection.
  51. GFL head structure is similar with ATSS, however GFL uses
  52. 1) joint representation for classification and localization quality, and
  53. 2) flexible General distribution for bounding box locations,
  54. which are supervised by
  55. Quality Focal Loss (QFL) and Distribution Focal Loss (DFL), respectively
  56. https://arxiv.org/abs/2006.04388
  57. Args:
  58. num_classes (int): Number of categories excluding the background
  59. category.
  60. in_channels (int): Number of channels in the input feature map.
  61. stacked_convs (int): Number of conv layers in cls and reg tower.
  62. Defaults to 4.
  63. conv_cfg (:obj:`ConfigDict` or dict, optional): dictionary to construct
  64. and config conv layer. Defaults to None.
  65. norm_cfg (:obj:`ConfigDict` or dict): dictionary to construct and
  66. config norm layer. Default: dict(type='GN', num_groups=32,
  67. requires_grad=True).
  68. loss_qfl (:obj:`ConfigDict` or dict): Config of Quality Focal Loss
  69. (QFL).
  70. bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults
  71. to 'DistancePointBBoxCoder'.
  72. reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}``
  73. in QFL setting. Defaults to 16.
  74. init_cfg (:obj:`ConfigDict` or dict or list[dict] or
  75. list[:obj:`ConfigDict`]): Initialization config dict.
  76. Example:
  77. >>> self = GFLHead(11, 7)
  78. >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
  79. >>> cls_quality_score, bbox_pred = self.forward(feats)
  80. >>> assert len(cls_quality_score) == len(self.scales)
  81. """
  82. def __init__(self,
  83. num_classes: int,
  84. in_channels: int,
  85. stacked_convs: int = 4,
  86. conv_cfg: OptConfigType = None,
  87. norm_cfg: ConfigType = dict(
  88. type='GN', num_groups=32, requires_grad=True),
  89. loss_dfl: ConfigType = dict(
  90. type='DistributionFocalLoss', loss_weight=0.25),
  91. bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
  92. reg_max: int = 16,
  93. init_cfg: MultiConfig = dict(
  94. type='Normal',
  95. layer='Conv2d',
  96. std=0.01,
  97. override=dict(
  98. type='Normal',
  99. name='gfl_cls',
  100. std=0.01,
  101. bias_prob=0.01)),
  102. **kwargs) -> None:
  103. self.stacked_convs = stacked_convs
  104. self.conv_cfg = conv_cfg
  105. self.norm_cfg = norm_cfg
  106. self.reg_max = reg_max
  107. super().__init__(
  108. num_classes=num_classes,
  109. in_channels=in_channels,
  110. bbox_coder=bbox_coder,
  111. init_cfg=init_cfg,
  112. **kwargs)
  113. if self.train_cfg:
  114. self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
  115. if self.train_cfg.get('sampler', None) is not None:
  116. self.sampler = TASK_UTILS.build(
  117. self.train_cfg['sampler'], default_args=dict(context=self))
  118. else:
  119. self.sampler = PseudoSampler(context=self)
  120. self.integral = Integral(self.reg_max)
  121. self.loss_dfl = MODELS.build(loss_dfl)
  122. def _init_layers(self) -> None:
  123. """Initialize layers of the head."""
  124. self.relu = nn.ReLU()
  125. self.cls_convs = nn.ModuleList()
  126. self.reg_convs = nn.ModuleList()
  127. for i in range(self.stacked_convs):
  128. chn = self.in_channels if i == 0 else self.feat_channels
  129. self.cls_convs.append(
  130. ConvModule(
  131. chn,
  132. self.feat_channels,
  133. 3,
  134. stride=1,
  135. padding=1,
  136. conv_cfg=self.conv_cfg,
  137. norm_cfg=self.norm_cfg))
  138. self.reg_convs.append(
  139. ConvModule(
  140. chn,
  141. self.feat_channels,
  142. 3,
  143. stride=1,
  144. padding=1,
  145. conv_cfg=self.conv_cfg,
  146. norm_cfg=self.norm_cfg))
  147. assert self.num_anchors == 1, 'anchor free version'
  148. self.gfl_cls = nn.Conv2d(
  149. self.feat_channels, self.cls_out_channels, 3, padding=1)
  150. self.gfl_reg = nn.Conv2d(
  151. self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1)
  152. self.scales = nn.ModuleList(
  153. [Scale(1.0) for _ in self.prior_generator.strides])
  154. def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
  155. """Forward features from the upstream network.
  156. Args:
  157. x (tuple[Tensor]): Features from the upstream network, each is
  158. a 4D-tensor.
  159. Returns:
  160. tuple: Usually a tuple of classification scores and bbox prediction
  161. - cls_scores (list[Tensor]): Classification and quality (IoU)
  162. joint scores for all scale levels, each is a 4D-tensor,
  163. the channel number is num_classes.
  164. - bbox_preds (list[Tensor]): Box distribution logits for all
  165. scale levels, each is a 4D-tensor, the channel number is
  166. 4*(n+1), n is max value of integral set.
  167. """
  168. return multi_apply(self.forward_single, x, self.scales)
  169. def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]:
  170. """Forward feature of a single scale level.
  171. Args:
  172. x (Tensor): Features of a single scale level.
  173. scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
  174. the bbox prediction.
  175. Returns:
  176. tuple:
  177. - cls_score (Tensor): Cls and quality joint scores for a single
  178. scale level the channel number is num_classes.
  179. - bbox_pred (Tensor): Box distribution logits for a single scale
  180. level, the channel number is 4*(n+1), n is max value of
  181. integral set.
  182. """
  183. cls_feat = x
  184. reg_feat = x
  185. for cls_conv in self.cls_convs:
  186. cls_feat = cls_conv(cls_feat)
  187. for reg_conv in self.reg_convs:
  188. reg_feat = reg_conv(reg_feat)
  189. cls_score = self.gfl_cls(cls_feat)
  190. bbox_pred = scale(self.gfl_reg(reg_feat)).float()
  191. return cls_score, bbox_pred
  192. def anchor_center(self, anchors: Tensor) -> Tensor:
  193. """Get anchor centers from anchors.
  194. Args:
  195. anchors (Tensor): Anchor list with shape (N, 4), ``xyxy`` format.
  196. Returns:
  197. Tensor: Anchor centers with shape (N, 2), ``xy`` format.
  198. """
  199. anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2
  200. anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2
  201. return torch.stack([anchors_cx, anchors_cy], dim=-1)
  202. def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
  203. bbox_pred: Tensor, labels: Tensor,
  204. label_weights: Tensor, bbox_targets: Tensor,
  205. stride: Tuple[int], avg_factor: int) -> dict:
  206. """Calculate the loss of a single scale level based on the features
  207. extracted by the detection head.
  208. Args:
  209. anchors (Tensor): Box reference for each scale level with shape
  210. (N, num_total_anchors, 4).
  211. cls_score (Tensor): Cls and quality joint scores for each scale
  212. level has shape (N, num_classes, H, W).
  213. bbox_pred (Tensor): Box distribution logits for each scale
  214. level with shape (N, 4*(n+1), H, W), n is max value of integral
  215. set.
  216. labels (Tensor): Labels of each anchors with shape
  217. (N, num_total_anchors).
  218. label_weights (Tensor): Label weights of each anchor with shape
  219. (N, num_total_anchors)
  220. bbox_targets (Tensor): BBox regression targets of each anchor
  221. weight shape (N, num_total_anchors, 4).
  222. stride (Tuple[int]): Stride in this scale level.
  223. avg_factor (int): Average factor that is used to average
  224. the loss. When using sampling method, avg_factor is usually
  225. the sum of positive and negative priors. When using
  226. `PseudoSampler`, `avg_factor` is usually equal to the number
  227. of positive priors.
  228. Returns:
  229. dict[str, Tensor]: A dictionary of loss components.
  230. """
  231. assert stride[0] == stride[1], 'h stride is not equal to w stride!'
  232. anchors = anchors.reshape(-1, 4)
  233. cls_score = cls_score.permute(0, 2, 3,
  234. 1).reshape(-1, self.cls_out_channels)
  235. bbox_pred = bbox_pred.permute(0, 2, 3,
  236. 1).reshape(-1, 4 * (self.reg_max + 1))
  237. bbox_targets = bbox_targets.reshape(-1, 4)
  238. labels = labels.reshape(-1)
  239. label_weights = label_weights.reshape(-1)
  240. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  241. bg_class_ind = self.num_classes
  242. pos_inds = ((labels >= 0)
  243. & (labels < bg_class_ind)).nonzero().squeeze(1)
  244. score = label_weights.new_zeros(labels.shape)
  245. if len(pos_inds) > 0:
  246. pos_bbox_targets = bbox_targets[pos_inds]
  247. pos_bbox_pred = bbox_pred[pos_inds]
  248. pos_anchors = anchors[pos_inds]
  249. pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
  250. weight_targets = cls_score.detach().sigmoid()
  251. weight_targets = weight_targets.max(dim=1)[0][pos_inds]
  252. pos_bbox_pred_corners = self.integral(pos_bbox_pred)
  253. pos_decode_bbox_pred = self.bbox_coder.decode(
  254. pos_anchor_centers, pos_bbox_pred_corners)
  255. pos_decode_bbox_targets = pos_bbox_targets / stride[0]
  256. score[pos_inds] = bbox_overlaps(
  257. pos_decode_bbox_pred.detach(),
  258. pos_decode_bbox_targets,
  259. is_aligned=True)
  260. pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
  261. target_corners = self.bbox_coder.encode(pos_anchor_centers,
  262. pos_decode_bbox_targets,
  263. self.reg_max).reshape(-1)
  264. # regression loss
  265. loss_bbox = self.loss_bbox(
  266. pos_decode_bbox_pred,
  267. pos_decode_bbox_targets,
  268. weight=weight_targets,
  269. avg_factor=1.0)
  270. # dfl loss
  271. loss_dfl = self.loss_dfl(
  272. pred_corners,
  273. target_corners,
  274. weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
  275. avg_factor=4.0)
  276. else:
  277. loss_bbox = bbox_pred.sum() * 0
  278. loss_dfl = bbox_pred.sum() * 0
  279. weight_targets = bbox_pred.new_tensor(0)
  280. # cls (qfl) loss
  281. loss_cls = self.loss_cls(
  282. cls_score, (labels, score),
  283. weight=label_weights,
  284. avg_factor=avg_factor)
  285. return loss_cls, loss_bbox, loss_dfl, weight_targets.sum()
  286. def loss_by_feat(
  287. self,
  288. cls_scores: List[Tensor],
  289. bbox_preds: List[Tensor],
  290. batch_gt_instances: InstanceList,
  291. batch_img_metas: List[dict],
  292. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  293. """Calculate the loss based on the features extracted by the detection
  294. head.
  295. Args:
  296. cls_scores (list[Tensor]): Cls and quality scores for each scale
  297. level has shape (N, num_classes, H, W).
  298. bbox_preds (list[Tensor]): Box distribution logits for each scale
  299. level with shape (N, 4*(n+1), H, W), n is max value of integral
  300. set.
  301. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  302. gt_instance. It usually includes ``bboxes`` and ``labels``
  303. attributes.
  304. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  305. image size, scaling factor, etc.
  306. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
  307. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  308. data that is ignored during training and testing.
  309. Defaults to None.
  310. Returns:
  311. dict[str, Tensor]: A dictionary of loss components.
  312. """
  313. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  314. assert len(featmap_sizes) == self.prior_generator.num_levels
  315. device = cls_scores[0].device
  316. anchor_list, valid_flag_list = self.get_anchors(
  317. featmap_sizes, batch_img_metas, device=device)
  318. cls_reg_targets = self.get_targets(
  319. anchor_list,
  320. valid_flag_list,
  321. batch_gt_instances,
  322. batch_img_metas,
  323. batch_gt_instances_ignore=batch_gt_instances_ignore)
  324. (anchor_list, labels_list, label_weights_list, bbox_targets_list,
  325. bbox_weights_list, avg_factor) = cls_reg_targets
  326. avg_factor = reduce_mean(
  327. torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
  328. losses_cls, losses_bbox, losses_dfl,\
  329. avg_factor = multi_apply(
  330. self.loss_by_feat_single,
  331. anchor_list,
  332. cls_scores,
  333. bbox_preds,
  334. labels_list,
  335. label_weights_list,
  336. bbox_targets_list,
  337. self.prior_generator.strides,
  338. avg_factor=avg_factor)
  339. avg_factor = sum(avg_factor)
  340. avg_factor = reduce_mean(avg_factor).clamp_(min=1).item()
  341. losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
  342. losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl))
  343. return dict(
  344. loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl)
  345. def _predict_by_feat_single(self,
  346. cls_score_list: List[Tensor],
  347. bbox_pred_list: List[Tensor],
  348. score_factor_list: List[Tensor],
  349. mlvl_priors: List[Tensor],
  350. img_meta: dict,
  351. cfg: ConfigDict,
  352. rescale: bool = False,
  353. with_nms: bool = True) -> InstanceData:
  354. """Transform a single image's features extracted from the head into
  355. bbox results.
  356. Args:
  357. cls_score_list (list[Tensor]): Box scores from all scale
  358. levels of a single image, each item has shape
  359. (num_priors * num_classes, H, W).
  360. bbox_pred_list (list[Tensor]): Box energies / deltas from
  361. all scale levels of a single image, each item has shape
  362. (num_priors * 4, H, W).
  363. score_factor_list (list[Tensor]): Score factor from all scale
  364. levels of a single image. GFL head does not need this value.
  365. mlvl_priors (list[Tensor]): Each element in the list is
  366. the priors of a single level in feature pyramid, has shape
  367. (num_priors, 4).
  368. img_meta (dict): Image meta info.
  369. cfg (:obj: `ConfigDict`): Test / postprocessing configuration,
  370. if None, test_cfg would be used.
  371. rescale (bool): If True, return boxes in original image space.
  372. Defaults to False.
  373. with_nms (bool): If True, do nms before return boxes.
  374. Defaults to True.
  375. Returns:
  376. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  377. is False and mlvl_score_factor is None, return mlvl_bboxes and
  378. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  379. mlvl_score_factor. Usually with_nms is False is used for aug
  380. test. If with_nms is True, then return the following format
  381. - det_bboxes (Tensor): Predicted bboxes with shape
  382. [num_bboxes, 5], where the first 4 columns are bounding
  383. box positions (tl_x, tl_y, br_x, br_y) and the 5-th
  384. column are scores between 0 and 1.
  385. - det_labels (Tensor): Predicted labels of the corresponding
  386. box with shape [num_bboxes].
  387. """
  388. cfg = self.test_cfg if cfg is None else cfg
  389. img_shape = img_meta['img_shape']
  390. nms_pre = cfg.get('nms_pre', -1)
  391. mlvl_bboxes = []
  392. mlvl_scores = []
  393. mlvl_labels = []
  394. for level_idx, (cls_score, bbox_pred, stride, priors) in enumerate(
  395. zip(cls_score_list, bbox_pred_list,
  396. self.prior_generator.strides, mlvl_priors)):
  397. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  398. assert stride[0] == stride[1]
  399. bbox_pred = bbox_pred.permute(1, 2, 0)
  400. bbox_pred = self.integral(bbox_pred) * stride[0]
  401. scores = cls_score.permute(1, 2, 0).reshape(
  402. -1, self.cls_out_channels).sigmoid()
  403. # After https://github.com/open-mmlab/mmdetection/pull/6268/,
  404. # this operation keeps fewer bboxes under the same `nms_pre`.
  405. # There is no difference in performance for most models. If you
  406. # find a slight drop in performance, you can set a larger
  407. # `nms_pre` than before.
  408. results = filter_scores_and_topk(
  409. scores, cfg.score_thr, nms_pre,
  410. dict(bbox_pred=bbox_pred, priors=priors))
  411. scores, labels, _, filtered_results = results
  412. bbox_pred = filtered_results['bbox_pred']
  413. priors = filtered_results['priors']
  414. bboxes = self.bbox_coder.decode(
  415. self.anchor_center(priors), bbox_pred, max_shape=img_shape)
  416. mlvl_bboxes.append(bboxes)
  417. mlvl_scores.append(scores)
  418. mlvl_labels.append(labels)
  419. results = InstanceData()
  420. results.bboxes = torch.cat(mlvl_bboxes)
  421. results.scores = torch.cat(mlvl_scores)
  422. results.labels = torch.cat(mlvl_labels)
  423. return self._bbox_post_process(
  424. results=results,
  425. cfg=cfg,
  426. rescale=rescale,
  427. with_nms=with_nms,
  428. img_meta=img_meta)
  429. def get_targets(self,
  430. anchor_list: List[Tensor],
  431. valid_flag_list: List[Tensor],
  432. batch_gt_instances: InstanceList,
  433. batch_img_metas: List[dict],
  434. batch_gt_instances_ignore: OptInstanceList = None,
  435. unmap_outputs=True) -> tuple:
  436. """Get targets for GFL head.
  437. This method is almost the same as `AnchorHead.get_targets()`. Besides
  438. returning the targets as the parent method does, it also returns the
  439. anchors as the first element of the returned tuple.
  440. """
  441. num_imgs = len(batch_img_metas)
  442. assert len(anchor_list) == len(valid_flag_list) == num_imgs
  443. # anchor number of multi levels
  444. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  445. num_level_anchors_list = [num_level_anchors] * num_imgs
  446. # concat all level anchors and flags to a single tensor
  447. for i in range(num_imgs):
  448. assert len(anchor_list[i]) == len(valid_flag_list[i])
  449. anchor_list[i] = torch.cat(anchor_list[i])
  450. valid_flag_list[i] = torch.cat(valid_flag_list[i])
  451. # compute targets for each image
  452. if batch_gt_instances_ignore is None:
  453. batch_gt_instances_ignore = [None] * num_imgs
  454. (all_anchors, all_labels, all_label_weights, all_bbox_targets,
  455. all_bbox_weights, pos_inds_list, neg_inds_list,
  456. sampling_results_list) = multi_apply(
  457. self._get_targets_single,
  458. anchor_list,
  459. valid_flag_list,
  460. num_level_anchors_list,
  461. batch_gt_instances,
  462. batch_img_metas,
  463. batch_gt_instances_ignore,
  464. unmap_outputs=unmap_outputs)
  465. # Get `avg_factor` of all images, which calculate in `SamplingResult`.
  466. # When using sampling method, avg_factor is usually the sum of
  467. # positive and negative priors. When using `PseudoSampler`,
  468. # `avg_factor` is usually equal to the number of positive priors.
  469. avg_factor = sum(
  470. [results.avg_factor for results in sampling_results_list])
  471. # split targets to a list w.r.t. multiple levels
  472. anchors_list = images_to_levels(all_anchors, num_level_anchors)
  473. labels_list = images_to_levels(all_labels, num_level_anchors)
  474. label_weights_list = images_to_levels(all_label_weights,
  475. num_level_anchors)
  476. bbox_targets_list = images_to_levels(all_bbox_targets,
  477. num_level_anchors)
  478. bbox_weights_list = images_to_levels(all_bbox_weights,
  479. num_level_anchors)
  480. return (anchors_list, labels_list, label_weights_list,
  481. bbox_targets_list, bbox_weights_list, avg_factor)
  482. def _get_targets_single(self,
  483. flat_anchors: Tensor,
  484. valid_flags: Tensor,
  485. num_level_anchors: List[int],
  486. gt_instances: InstanceData,
  487. img_meta: dict,
  488. gt_instances_ignore: Optional[InstanceData] = None,
  489. unmap_outputs: bool = True) -> tuple:
  490. """Compute regression, classification targets for anchors in a single
  491. image.
  492. Args:
  493. flat_anchors (Tensor): Multi-level anchors of the image, which are
  494. concatenated into a single tensor of shape (num_anchors, 4)
  495. valid_flags (Tensor): Multi level valid flags of the image,
  496. which are concatenated into a single tensor of
  497. shape (num_anchors,).
  498. num_level_anchors (list[int]): Number of anchors of each scale
  499. level.
  500. gt_instances (:obj:`InstanceData`): Ground truth of instance
  501. annotations. It usually includes ``bboxes`` and ``labels``
  502. attributes.
  503. img_meta (dict): Meta information for current image.
  504. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  505. to be ignored during training. It includes ``bboxes`` attribute
  506. data that is ignored during training and testing.
  507. Defaults to None.
  508. unmap_outputs (bool): Whether to map outputs back to the original
  509. set of anchors. Defaults to True.
  510. Returns:
  511. tuple: N is the number of total anchors in the image.
  512. - anchors (Tensor): All anchors in the image with shape (N, 4).
  513. - labels (Tensor): Labels of all anchors in the image with
  514. shape (N,).
  515. - label_weights (Tensor): Label weights of all anchor in the
  516. image with shape (N,).
  517. - bbox_targets (Tensor): BBox targets of all anchors in the
  518. image with shape (N, 4).
  519. - bbox_weights (Tensor): BBox weights of all anchors in the
  520. image with shape (N, 4).
  521. - pos_inds (Tensor): Indices of positive anchor with shape
  522. (num_pos,).
  523. - neg_inds (Tensor): Indices of negative anchor with shape
  524. (num_neg,).
  525. - sampling_result (:obj:`SamplingResult`): Sampling results.
  526. """
  527. inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
  528. img_meta['img_shape'][:2],
  529. self.train_cfg['allowed_border'])
  530. if not inside_flags.any():
  531. raise ValueError(
  532. 'There is no valid anchor inside the image boundary. Please '
  533. 'check the image size and anchor sizes, or set '
  534. '``allowed_border`` to -1 to skip the condition.')
  535. # assign gt and sample anchors
  536. anchors = flat_anchors[inside_flags, :]
  537. num_level_anchors_inside = self.get_num_level_anchors_inside(
  538. num_level_anchors, inside_flags)
  539. pred_instances = InstanceData(priors=anchors)
  540. assign_result = self.assigner.assign(
  541. pred_instances=pred_instances,
  542. num_level_priors=num_level_anchors_inside,
  543. gt_instances=gt_instances,
  544. gt_instances_ignore=gt_instances_ignore)
  545. sampling_result = self.sampler.sample(
  546. assign_result=assign_result,
  547. pred_instances=pred_instances,
  548. gt_instances=gt_instances)
  549. num_valid_anchors = anchors.shape[0]
  550. bbox_targets = torch.zeros_like(anchors)
  551. bbox_weights = torch.zeros_like(anchors)
  552. labels = anchors.new_full((num_valid_anchors, ),
  553. self.num_classes,
  554. dtype=torch.long)
  555. label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
  556. pos_inds = sampling_result.pos_inds
  557. neg_inds = sampling_result.neg_inds
  558. if len(pos_inds) > 0:
  559. pos_bbox_targets = sampling_result.pos_gt_bboxes
  560. bbox_targets[pos_inds, :] = pos_bbox_targets
  561. bbox_weights[pos_inds, :] = 1.0
  562. labels[pos_inds] = sampling_result.pos_gt_labels
  563. if self.train_cfg['pos_weight'] <= 0:
  564. label_weights[pos_inds] = 1.0
  565. else:
  566. label_weights[pos_inds] = self.train_cfg['pos_weight']
  567. if len(neg_inds) > 0:
  568. label_weights[neg_inds] = 1.0
  569. # map up to original set of anchors
  570. if unmap_outputs:
  571. num_total_anchors = flat_anchors.size(0)
  572. anchors = unmap(anchors, num_total_anchors, inside_flags)
  573. labels = unmap(
  574. labels, num_total_anchors, inside_flags, fill=self.num_classes)
  575. label_weights = unmap(label_weights, num_total_anchors,
  576. inside_flags)
  577. bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
  578. bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
  579. return (anchors, labels, label_weights, bbox_targets, bbox_weights,
  580. pos_inds, neg_inds, sampling_result)
  581. def get_num_level_anchors_inside(self, num_level_anchors: List[int],
  582. inside_flags: Tensor) -> List[int]:
  583. """Get the number of valid anchors in every level."""
  584. split_inside_flags = torch.split(inside_flags, num_level_anchors)
  585. num_level_anchors_inside = [
  586. int(flags.sum()) for flags in split_inside_flags
  587. ]
  588. return num_level_anchors_inside