ddod_head.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Sequence, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule, Scale
  6. from mmengine.model import bias_init_with_prob, normal_init
  7. from mmengine.structures import InstanceData
  8. from torch import Tensor
  9. from mmdet.registry import MODELS, TASK_UTILS
  10. from mmdet.structures.bbox import bbox_overlaps
  11. from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
  12. OptInstanceList, reduce_mean)
  13. from ..task_modules.prior_generators import anchor_inside_flags
  14. from ..utils import images_to_levels, multi_apply, unmap
  15. from .anchor_head import AnchorHead
  16. EPS = 1e-12
  17. @MODELS.register_module()
  18. class DDODHead(AnchorHead):
  19. """Detection Head of `DDOD <https://arxiv.org/abs/2107.02963>`_.
  20. DDOD head decomposes conjunctions lying in most current one-stage
  21. detectors via label assignment disentanglement, spatial feature
  22. disentanglement, and pyramid supervision disentanglement.
  23. Args:
  24. num_classes (int): Number of categories excluding the
  25. background category.
  26. in_channels (int): Number of channels in the input feature map.
  27. stacked_convs (int): The number of stacked Conv. Defaults to 4.
  28. conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
  29. convolution layer. Defaults to None.
  30. use_dcn (bool): Use dcn, Same as ATSS when False. Defaults to True.
  31. norm_cfg (:obj:`ConfigDict` or dict): Normal config of ddod head.
  32. Defaults to dict(type='GN', num_groups=32, requires_grad=True).
  33. loss_iou (:obj:`ConfigDict` or dict): Config of IoU loss. Defaults to
  34. dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0).
  35. """
  36. def __init__(self,
  37. num_classes: int,
  38. in_channels: int,
  39. stacked_convs: int = 4,
  40. conv_cfg: OptConfigType = None,
  41. use_dcn: bool = True,
  42. norm_cfg: ConfigType = dict(
  43. type='GN', num_groups=32, requires_grad=True),
  44. loss_iou: ConfigType = dict(
  45. type='CrossEntropyLoss',
  46. use_sigmoid=True,
  47. loss_weight=1.0),
  48. **kwargs) -> None:
  49. self.stacked_convs = stacked_convs
  50. self.conv_cfg = conv_cfg
  51. self.norm_cfg = norm_cfg
  52. self.use_dcn = use_dcn
  53. super().__init__(num_classes, in_channels, **kwargs)
  54. if self.train_cfg:
  55. self.cls_assigner = TASK_UTILS.build(self.train_cfg['assigner'])
  56. self.reg_assigner = TASK_UTILS.build(
  57. self.train_cfg['reg_assigner'])
  58. self.loss_iou = MODELS.build(loss_iou)
  59. def _init_layers(self) -> None:
  60. """Initialize layers of the head."""
  61. self.relu = nn.ReLU(inplace=True)
  62. self.cls_convs = nn.ModuleList()
  63. self.reg_convs = nn.ModuleList()
  64. for i in range(self.stacked_convs):
  65. chn = self.in_channels if i == 0 else self.feat_channels
  66. self.cls_convs.append(
  67. ConvModule(
  68. chn,
  69. self.feat_channels,
  70. 3,
  71. stride=1,
  72. padding=1,
  73. conv_cfg=dict(type='DCN', deform_groups=1)
  74. if i == 0 and self.use_dcn else self.conv_cfg,
  75. norm_cfg=self.norm_cfg))
  76. self.reg_convs.append(
  77. ConvModule(
  78. chn,
  79. self.feat_channels,
  80. 3,
  81. stride=1,
  82. padding=1,
  83. conv_cfg=dict(type='DCN', deform_groups=1)
  84. if i == 0 and self.use_dcn else self.conv_cfg,
  85. norm_cfg=self.norm_cfg))
  86. self.atss_cls = nn.Conv2d(
  87. self.feat_channels,
  88. self.num_base_priors * self.cls_out_channels,
  89. 3,
  90. padding=1)
  91. self.atss_reg = nn.Conv2d(
  92. self.feat_channels, self.num_base_priors * 4, 3, padding=1)
  93. self.atss_iou = nn.Conv2d(
  94. self.feat_channels, self.num_base_priors * 1, 3, padding=1)
  95. self.scales = nn.ModuleList(
  96. [Scale(1.0) for _ in self.prior_generator.strides])
  97. # we use the global list in loss
  98. self.cls_num_pos_samples_per_level = [
  99. 0. for _ in range(len(self.prior_generator.strides))
  100. ]
  101. self.reg_num_pos_samples_per_level = [
  102. 0. for _ in range(len(self.prior_generator.strides))
  103. ]
  104. def init_weights(self) -> None:
  105. """Initialize weights of the head."""
  106. for m in self.cls_convs:
  107. normal_init(m.conv, std=0.01)
  108. for m in self.reg_convs:
  109. normal_init(m.conv, std=0.01)
  110. normal_init(self.atss_reg, std=0.01)
  111. normal_init(self.atss_iou, std=0.01)
  112. bias_cls = bias_init_with_prob(0.01)
  113. normal_init(self.atss_cls, std=0.01, bias=bias_cls)
  114. def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
  115. """Forward features from the upstream network.
  116. Args:
  117. x (tuple[Tensor]): Features from the upstream network, each is
  118. a 4D-tensor.
  119. Returns:
  120. tuple: A tuple of classification scores, bbox predictions,
  121. and iou predictions.
  122. - cls_scores (list[Tensor]): Classification scores for all \
  123. scale levels, each is a 4D-tensor, the channels number is \
  124. num_base_priors * num_classes.
  125. - bbox_preds (list[Tensor]): Box energies / deltas for all \
  126. scale levels, each is a 4D-tensor, the channels number is \
  127. num_base_priors * 4.
  128. - iou_preds (list[Tensor]): IoU scores for all scale levels, \
  129. each is a 4D-tensor, the channels number is num_base_priors * 1.
  130. """
  131. return multi_apply(self.forward_single, x, self.scales)
  132. def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]:
  133. """Forward feature of a single scale level.
  134. Args:
  135. x (Tensor): Features of a single scale level.
  136. scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
  137. the bbox prediction.
  138. Returns:
  139. tuple:
  140. - cls_score (Tensor): Cls scores for a single scale level \
  141. the channels number is num_base_priors * num_classes.
  142. - bbox_pred (Tensor): Box energies / deltas for a single \
  143. scale level, the channels number is num_base_priors * 4.
  144. - iou_pred (Tensor): Iou for a single scale level, the \
  145. channel number is (N, num_base_priors * 1, H, W).
  146. """
  147. cls_feat = x
  148. reg_feat = x
  149. for cls_conv in self.cls_convs:
  150. cls_feat = cls_conv(cls_feat)
  151. for reg_conv in self.reg_convs:
  152. reg_feat = reg_conv(reg_feat)
  153. cls_score = self.atss_cls(cls_feat)
  154. # we just follow atss, not apply exp in bbox_pred
  155. bbox_pred = scale(self.atss_reg(reg_feat)).float()
  156. iou_pred = self.atss_iou(reg_feat)
  157. return cls_score, bbox_pred, iou_pred
  158. def loss_cls_by_feat_single(self, cls_score: Tensor, labels: Tensor,
  159. label_weights: Tensor,
  160. reweight_factor: List[float],
  161. avg_factor: float) -> Tuple[Tensor]:
  162. """Compute cls loss of a single scale level.
  163. Args:
  164. cls_score (Tensor): Box scores for each scale level
  165. Has shape (N, num_base_priors * num_classes, H, W).
  166. labels (Tensor): Labels of each anchors with shape
  167. (N, num_total_anchors).
  168. label_weights (Tensor): Label weights of each anchor with shape
  169. (N, num_total_anchors)
  170. reweight_factor (List[float]): Reweight factor for cls and reg
  171. loss.
  172. avg_factor (float): Average factor that is used to average
  173. the loss. When using sampling method, avg_factor is usually
  174. the sum of positive and negative priors. When using
  175. `PseudoSampler`, `avg_factor` is usually equal to the number
  176. of positive priors.
  177. Returns:
  178. Tuple[Tensor]: A tuple of loss components.
  179. """
  180. cls_score = cls_score.permute(0, 2, 3, 1).reshape(
  181. -1, self.cls_out_channels).contiguous()
  182. labels = labels.reshape(-1)
  183. label_weights = label_weights.reshape(-1)
  184. loss_cls = self.loss_cls(
  185. cls_score, labels, label_weights, avg_factor=avg_factor)
  186. return reweight_factor * loss_cls,
  187. def loss_reg_by_feat_single(self, anchors: Tensor, bbox_pred: Tensor,
  188. iou_pred: Tensor, labels,
  189. label_weights: Tensor, bbox_targets: Tensor,
  190. bbox_weights: Tensor,
  191. reweight_factor: List[float],
  192. avg_factor: float) -> Tuple[Tensor, Tensor]:
  193. """Compute reg loss of a single scale level based on the features
  194. extracted by the detection head.
  195. Args:
  196. anchors (Tensor): Box reference for each scale level with shape
  197. (N, num_total_anchors, 4).
  198. bbox_pred (Tensor): Box energies / deltas for each scale
  199. level with shape (N, num_base_priors * 4, H, W).
  200. iou_pred (Tensor): Iou for a single scale level, the
  201. channel number is (N, num_base_priors * 1, H, W).
  202. labels (Tensor): Labels of each anchors with shape
  203. (N, num_total_anchors).
  204. label_weights (Tensor): Label weights of each anchor with shape
  205. (N, num_total_anchors)
  206. bbox_targets (Tensor): BBox regression targets of each anchor
  207. weight shape (N, num_total_anchors, 4).
  208. bbox_weights (Tensor): BBox weights of all anchors in the
  209. image with shape (N, 4)
  210. reweight_factor (List[float]): Reweight factor for cls and reg
  211. loss.
  212. avg_factor (float): Average factor that is used to average
  213. the loss. When using sampling method, avg_factor is usually
  214. the sum of positive and negative priors. When using
  215. `PseudoSampler`, `avg_factor` is usually equal to the number
  216. of positive priors.
  217. Returns:
  218. Tuple[Tensor, Tensor]: A tuple of loss components.
  219. """
  220. anchors = anchors.reshape(-1, 4)
  221. bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
  222. iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1, )
  223. bbox_targets = bbox_targets.reshape(-1, 4)
  224. bbox_weights = bbox_weights.reshape(-1, 4)
  225. labels = labels.reshape(-1)
  226. label_weights = label_weights.reshape(-1)
  227. iou_targets = label_weights.new_zeros(labels.shape)
  228. iou_weights = label_weights.new_zeros(labels.shape)
  229. iou_weights[(bbox_weights.sum(axis=1) > 0).nonzero(
  230. as_tuple=False)] = 1.
  231. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  232. bg_class_ind = self.num_classes
  233. pos_inds = ((labels >= 0)
  234. &
  235. (labels < bg_class_ind)).nonzero(as_tuple=False).squeeze(1)
  236. if len(pos_inds) > 0:
  237. pos_bbox_targets = bbox_targets[pos_inds]
  238. pos_bbox_pred = bbox_pred[pos_inds]
  239. pos_anchors = anchors[pos_inds]
  240. pos_decode_bbox_pred = self.bbox_coder.decode(
  241. pos_anchors, pos_bbox_pred)
  242. pos_decode_bbox_targets = self.bbox_coder.decode(
  243. pos_anchors, pos_bbox_targets)
  244. # regression loss
  245. loss_bbox = self.loss_bbox(
  246. pos_decode_bbox_pred,
  247. pos_decode_bbox_targets,
  248. avg_factor=avg_factor)
  249. iou_targets[pos_inds] = bbox_overlaps(
  250. pos_decode_bbox_pred.detach(),
  251. pos_decode_bbox_targets,
  252. is_aligned=True)
  253. loss_iou = self.loss_iou(
  254. iou_pred, iou_targets, iou_weights, avg_factor=avg_factor)
  255. else:
  256. loss_bbox = bbox_pred.sum() * 0
  257. loss_iou = iou_pred.sum() * 0
  258. return reweight_factor * loss_bbox, reweight_factor * loss_iou
  259. def calc_reweight_factor(self, labels_list: List[Tensor]) -> List[float]:
  260. """Compute reweight_factor for regression and classification loss."""
  261. # get pos samples for each level
  262. bg_class_ind = self.num_classes
  263. for ii, each_level_label in enumerate(labels_list):
  264. pos_inds = ((each_level_label >= 0) &
  265. (each_level_label < bg_class_ind)).nonzero(
  266. as_tuple=False).squeeze(1)
  267. self.cls_num_pos_samples_per_level[ii] += len(pos_inds)
  268. # get reweight factor from 1 ~ 2 with bilinear interpolation
  269. min_pos_samples = min(self.cls_num_pos_samples_per_level)
  270. max_pos_samples = max(self.cls_num_pos_samples_per_level)
  271. interval = 1. / (max_pos_samples - min_pos_samples + 1e-10)
  272. reweight_factor_per_level = []
  273. for pos_samples in self.cls_num_pos_samples_per_level:
  274. factor = 2. - (pos_samples - min_pos_samples) * interval
  275. reweight_factor_per_level.append(factor)
  276. return reweight_factor_per_level
  277. def loss_by_feat(
  278. self,
  279. cls_scores: List[Tensor],
  280. bbox_preds: List[Tensor],
  281. iou_preds: List[Tensor],
  282. batch_gt_instances: InstanceList,
  283. batch_img_metas: List[dict],
  284. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  285. """Calculate the loss based on the features extracted by the detection
  286. head.
  287. Args:
  288. cls_scores (list[Tensor]): Box scores for each scale level
  289. Has shape (N, num_base_priors * num_classes, H, W)
  290. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  291. level with shape (N, num_base_priors * 4, H, W)
  292. iou_preds (list[Tensor]): Score factor for all scale level,
  293. each is a 4D-tensor, has shape (batch_size, 1, H, W).
  294. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  295. gt_instance. It usually includes ``bboxes`` and ``labels``
  296. attributes.
  297. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  298. image size, scaling factor, etc.
  299. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
  300. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  301. data that is ignored during training and testing.
  302. Defaults to None.
  303. Returns:
  304. dict[str, Tensor]: A dictionary of loss components.
  305. """
  306. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  307. assert len(featmap_sizes) == self.prior_generator.num_levels
  308. device = cls_scores[0].device
  309. anchor_list, valid_flag_list = self.get_anchors(
  310. featmap_sizes, batch_img_metas, device=device)
  311. # calculate common vars for cls and reg assigners at once
  312. targets_com = self.process_predictions_and_anchors(
  313. anchor_list, valid_flag_list, cls_scores, bbox_preds,
  314. batch_img_metas, batch_gt_instances_ignore)
  315. (anchor_list, valid_flag_list, num_level_anchors_list, cls_score_list,
  316. bbox_pred_list, batch_gt_instances_ignore) = targets_com
  317. # classification branch assigner
  318. cls_targets = self.get_cls_targets(
  319. anchor_list,
  320. valid_flag_list,
  321. num_level_anchors_list,
  322. cls_score_list,
  323. bbox_pred_list,
  324. batch_gt_instances,
  325. batch_img_metas,
  326. batch_gt_instances_ignore=batch_gt_instances_ignore)
  327. (cls_anchor_list, labels_list, label_weights_list, bbox_targets_list,
  328. bbox_weights_list, avg_factor) = cls_targets
  329. avg_factor = reduce_mean(
  330. torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
  331. avg_factor = max(avg_factor, 1.0)
  332. reweight_factor_per_level = self.calc_reweight_factor(labels_list)
  333. cls_losses_cls, = multi_apply(
  334. self.loss_cls_by_feat_single,
  335. cls_scores,
  336. labels_list,
  337. label_weights_list,
  338. reweight_factor_per_level,
  339. avg_factor=avg_factor)
  340. # regression branch assigner
  341. reg_targets = self.get_reg_targets(
  342. anchor_list,
  343. valid_flag_list,
  344. num_level_anchors_list,
  345. cls_score_list,
  346. bbox_pred_list,
  347. batch_gt_instances,
  348. batch_img_metas,
  349. batch_gt_instances_ignore=batch_gt_instances_ignore)
  350. (reg_anchor_list, labels_list, label_weights_list, bbox_targets_list,
  351. bbox_weights_list, avg_factor) = reg_targets
  352. avg_factor = reduce_mean(
  353. torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
  354. avg_factor = max(avg_factor, 1.0)
  355. reweight_factor_per_level = self.calc_reweight_factor(labels_list)
  356. reg_losses_bbox, reg_losses_iou = multi_apply(
  357. self.loss_reg_by_feat_single,
  358. reg_anchor_list,
  359. bbox_preds,
  360. iou_preds,
  361. labels_list,
  362. label_weights_list,
  363. bbox_targets_list,
  364. bbox_weights_list,
  365. reweight_factor_per_level,
  366. avg_factor=avg_factor)
  367. return dict(
  368. loss_cls=cls_losses_cls,
  369. loss_bbox=reg_losses_bbox,
  370. loss_iou=reg_losses_iou)
  371. def process_predictions_and_anchors(
  372. self,
  373. anchor_list: List[List[Tensor]],
  374. valid_flag_list: List[List[Tensor]],
  375. cls_scores: List[Tensor],
  376. bbox_preds: List[Tensor],
  377. batch_img_metas: List[dict],
  378. batch_gt_instances_ignore: OptInstanceList = None) -> tuple:
  379. """Compute common vars for regression and classification targets.
  380. Args:
  381. anchor_list (List[List[Tensor]]): anchors of each image.
  382. valid_flag_list (List[List[Tensor]]): Valid flags of each image.
  383. cls_scores (List[Tensor]): Classification scores for all scale
  384. levels, each is a 4D-tensor, the channels number is
  385. num_base_priors * num_classes.
  386. bbox_preds (list[Tensor]): Box energies / deltas for all scale
  387. levels, each is a 4D-tensor, the channels number is
  388. num_base_priors * 4.
  389. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  390. image size, scaling factor, etc.
  391. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
  392. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  393. data that is ignored during training and testing.
  394. Defaults to None.
  395. Return:
  396. tuple[Tensor]: A tuple of common loss vars.
  397. """
  398. num_imgs = len(batch_img_metas)
  399. assert len(anchor_list) == len(valid_flag_list) == num_imgs
  400. # anchor number of multi levels
  401. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  402. num_level_anchors_list = [num_level_anchors] * num_imgs
  403. anchor_list_ = []
  404. valid_flag_list_ = []
  405. # concat all level anchors and flags to a single tensor
  406. for i in range(num_imgs):
  407. assert len(anchor_list[i]) == len(valid_flag_list[i])
  408. anchor_list_.append(torch.cat(anchor_list[i]))
  409. valid_flag_list_.append(torch.cat(valid_flag_list[i]))
  410. # compute targets for each image
  411. if batch_gt_instances_ignore is None:
  412. batch_gt_instances_ignore = [None for _ in range(num_imgs)]
  413. num_levels = len(cls_scores)
  414. cls_score_list = []
  415. bbox_pred_list = []
  416. mlvl_cls_score_list = [
  417. cls_score.permute(0, 2, 3, 1).reshape(
  418. num_imgs, -1, self.num_base_priors * self.cls_out_channels)
  419. for cls_score in cls_scores
  420. ]
  421. mlvl_bbox_pred_list = [
  422. bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
  423. self.num_base_priors * 4)
  424. for bbox_pred in bbox_preds
  425. ]
  426. for i in range(num_imgs):
  427. mlvl_cls_tensor_list = [
  428. mlvl_cls_score_list[j][i] for j in range(num_levels)
  429. ]
  430. mlvl_bbox_tensor_list = [
  431. mlvl_bbox_pred_list[j][i] for j in range(num_levels)
  432. ]
  433. cat_mlvl_cls_score = torch.cat(mlvl_cls_tensor_list, dim=0)
  434. cat_mlvl_bbox_pred = torch.cat(mlvl_bbox_tensor_list, dim=0)
  435. cls_score_list.append(cat_mlvl_cls_score)
  436. bbox_pred_list.append(cat_mlvl_bbox_pred)
  437. return (anchor_list_, valid_flag_list_, num_level_anchors_list,
  438. cls_score_list, bbox_pred_list, batch_gt_instances_ignore)
  439. def get_cls_targets(self,
  440. anchor_list: List[Tensor],
  441. valid_flag_list: List[Tensor],
  442. num_level_anchors_list: List[int],
  443. cls_score_list: List[Tensor],
  444. bbox_pred_list: List[Tensor],
  445. batch_gt_instances: InstanceList,
  446. batch_img_metas: List[dict],
  447. batch_gt_instances_ignore: OptInstanceList = None,
  448. unmap_outputs: bool = True) -> tuple:
  449. """Get cls targets for DDOD head.
  450. This method is almost the same as `AnchorHead.get_targets()`.
  451. Besides returning the targets as the parent method does,
  452. it also returns the anchors as the first element of the
  453. returned tuple.
  454. Args:
  455. anchor_list (list[Tensor]): anchors of each image.
  456. valid_flag_list (list[Tensor]): Valid flags of each image.
  457. num_level_anchors_list (list[Tensor]): Number of anchors of each
  458. scale level of all image.
  459. cls_score_list (list[Tensor]): Classification scores for all scale
  460. levels, each is a 4D-tensor, the channels number is
  461. num_base_priors * num_classes.
  462. bbox_pred_list (list[Tensor]): Box energies / deltas for all scale
  463. levels, each is a 4D-tensor, the channels number is
  464. num_base_priors * 4.
  465. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  466. gt_instance. It usually includes ``bboxes`` and ``labels``
  467. attributes.
  468. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  469. image size, scaling factor, etc.
  470. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  471. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  472. data that is ignored during training and testing.
  473. Defaults to None.
  474. unmap_outputs (bool): Whether to map outputs back to the original
  475. set of anchors.
  476. Return:
  477. tuple[Tensor]: A tuple of cls targets components.
  478. """
  479. (all_anchors, all_labels, all_label_weights, all_bbox_targets,
  480. all_bbox_weights, pos_inds_list, neg_inds_list,
  481. sampling_results_list) = multi_apply(
  482. self._get_targets_single,
  483. anchor_list,
  484. valid_flag_list,
  485. cls_score_list,
  486. bbox_pred_list,
  487. num_level_anchors_list,
  488. batch_gt_instances,
  489. batch_img_metas,
  490. batch_gt_instances_ignore,
  491. unmap_outputs=unmap_outputs,
  492. is_cls_assigner=True)
  493. # Get `avg_factor` of all images, which calculate in `SamplingResult`.
  494. # When using sampling method, avg_factor is usually the sum of
  495. # positive and negative priors. When using `PseudoSampler`,
  496. # `avg_factor` is usually equal to the number of positive priors.
  497. avg_factor = sum(
  498. [results.avg_factor for results in sampling_results_list])
  499. # split targets to a list w.r.t. multiple levels
  500. anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
  501. labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
  502. label_weights_list = images_to_levels(all_label_weights,
  503. num_level_anchors_list[0])
  504. bbox_targets_list = images_to_levels(all_bbox_targets,
  505. num_level_anchors_list[0])
  506. bbox_weights_list = images_to_levels(all_bbox_weights,
  507. num_level_anchors_list[0])
  508. return (anchors_list, labels_list, label_weights_list,
  509. bbox_targets_list, bbox_weights_list, avg_factor)
  510. def get_reg_targets(self,
  511. anchor_list: List[Tensor],
  512. valid_flag_list: List[Tensor],
  513. num_level_anchors_list: List[int],
  514. cls_score_list: List[Tensor],
  515. bbox_pred_list: List[Tensor],
  516. batch_gt_instances: InstanceList,
  517. batch_img_metas: List[dict],
  518. batch_gt_instances_ignore: OptInstanceList = None,
  519. unmap_outputs: bool = True) -> tuple:
  520. """Get reg targets for DDOD head.
  521. This method is almost the same as `AnchorHead.get_targets()` when
  522. is_cls_assigner is False. Besides returning the targets as the parent
  523. method does, it also returns the anchors as the first element of the
  524. returned tuple.
  525. Args:
  526. anchor_list (list[Tensor]): anchors of each image.
  527. valid_flag_list (list[Tensor]): Valid flags of each image.
  528. num_level_anchors_list (list[Tensor]): Number of anchors of each
  529. scale level of all image.
  530. cls_score_list (list[Tensor]): Classification scores for all scale
  531. levels, each is a 4D-tensor, the channels number is
  532. num_base_priors * num_classes.
  533. bbox_pred_list (list[Tensor]): Box energies / deltas for all scale
  534. levels, each is a 4D-tensor, the channels number is
  535. num_base_priors * 4.
  536. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  537. gt_instance. It usually includes ``bboxes`` and ``labels``
  538. attributes.
  539. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  540. image size, scaling factor, etc.
  541. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  542. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  543. data that is ignored during training and testing.
  544. Defaults to None.
  545. unmap_outputs (bool): Whether to map outputs back to the original
  546. set of anchors.
  547. Return:
  548. tuple[Tensor]: A tuple of reg targets components.
  549. """
  550. (all_anchors, all_labels, all_label_weights, all_bbox_targets,
  551. all_bbox_weights, pos_inds_list, neg_inds_list,
  552. sampling_results_list) = multi_apply(
  553. self._get_targets_single,
  554. anchor_list,
  555. valid_flag_list,
  556. cls_score_list,
  557. bbox_pred_list,
  558. num_level_anchors_list,
  559. batch_gt_instances,
  560. batch_img_metas,
  561. batch_gt_instances_ignore,
  562. unmap_outputs=unmap_outputs,
  563. is_cls_assigner=False)
  564. # Get `avg_factor` of all images, which calculate in `SamplingResult`.
  565. # When using sampling method, avg_factor is usually the sum of
  566. # positive and negative priors. When using `PseudoSampler`,
  567. # `avg_factor` is usually equal to the number of positive priors.
  568. avg_factor = sum(
  569. [results.avg_factor for results in sampling_results_list])
  570. # split targets to a list w.r.t. multiple levels
  571. anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
  572. labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
  573. label_weights_list = images_to_levels(all_label_weights,
  574. num_level_anchors_list[0])
  575. bbox_targets_list = images_to_levels(all_bbox_targets,
  576. num_level_anchors_list[0])
  577. bbox_weights_list = images_to_levels(all_bbox_weights,
  578. num_level_anchors_list[0])
  579. return (anchors_list, labels_list, label_weights_list,
  580. bbox_targets_list, bbox_weights_list, avg_factor)
  581. def _get_targets_single(self,
  582. flat_anchors: Tensor,
  583. valid_flags: Tensor,
  584. cls_scores: Tensor,
  585. bbox_preds: Tensor,
  586. num_level_anchors: List[int],
  587. gt_instances: InstanceData,
  588. img_meta: dict,
  589. gt_instances_ignore: Optional[InstanceData] = None,
  590. unmap_outputs: bool = True,
  591. is_cls_assigner: bool = True) -> tuple:
  592. """Compute regression, classification targets for anchors in a single
  593. image.
  594. Args:
  595. flat_anchors (Tensor): Multi-level anchors of the image,
  596. which are concatenated into a single tensor of shape
  597. (num_base_priors, 4).
  598. valid_flags (Tensor): Multi level valid flags of the image,
  599. which are concatenated into a single tensor of
  600. shape (num_base_priors,).
  601. cls_scores (Tensor): Classification scores for all scale
  602. levels of the image.
  603. bbox_preds (Tensor): Box energies / deltas for all scale
  604. levels of the image.
  605. num_level_anchors (List[int]): Number of anchors of each
  606. scale level.
  607. gt_instances (:obj:`InstanceData`): Ground truth of instance
  608. annotations. It usually includes ``bboxes`` and ``labels``
  609. attributes.
  610. img_meta (dict): Meta information for current image.
  611. gt_instances_ignore (:obj:`InstanceData`, optional): Instances
  612. to be ignored during training. It includes ``bboxes`` attribute
  613. data that is ignored during training and testing.
  614. Defaults to None.
  615. unmap_outputs (bool): Whether to map outputs back to the original
  616. set of anchors. Defaults to True.
  617. is_cls_assigner (bool): Classification or regression.
  618. Defaults to True.
  619. Returns:
  620. tuple: N is the number of total anchors in the image.
  621. - anchors (Tensor): all anchors in the image with shape (N, 4).
  622. - labels (Tensor): Labels of all anchors in the image with \
  623. shape (N, ).
  624. - label_weights (Tensor): Label weights of all anchor in the \
  625. image with shape (N, ).
  626. - bbox_targets (Tensor): BBox targets of all anchors in the \
  627. image with shape (N, 4).
  628. - bbox_weights (Tensor): BBox weights of all anchors in the \
  629. image with shape (N, 4)
  630. - pos_inds (Tensor): Indices of positive anchor with shape \
  631. (num_pos, ).
  632. - neg_inds (Tensor): Indices of negative anchor with shape \
  633. (num_neg, ).
  634. - sampling_result (:obj:`SamplingResult`): Sampling results.
  635. """
  636. inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
  637. img_meta['img_shape'][:2],
  638. self.train_cfg['allowed_border'])
  639. if not inside_flags.any():
  640. raise ValueError(
  641. 'There is no valid anchor inside the image boundary. Please '
  642. 'check the image size and anchor sizes, or set '
  643. '``allowed_border`` to -1 to skip the condition.')
  644. # assign gt and sample anchors
  645. anchors = flat_anchors[inside_flags, :]
  646. num_level_anchors_inside = self.get_num_level_anchors_inside(
  647. num_level_anchors, inside_flags)
  648. bbox_preds_valid = bbox_preds[inside_flags, :]
  649. cls_scores_valid = cls_scores[inside_flags, :]
  650. assigner = self.cls_assigner if is_cls_assigner else self.reg_assigner
  651. # decode prediction out of assigner
  652. bbox_preds_valid = self.bbox_coder.decode(anchors, bbox_preds_valid)
  653. pred_instances = InstanceData(
  654. priors=anchors, bboxes=bbox_preds_valid, scores=cls_scores_valid)
  655. assign_result = assigner.assign(
  656. pred_instances=pred_instances,
  657. num_level_priors=num_level_anchors_inside,
  658. gt_instances=gt_instances,
  659. gt_instances_ignore=gt_instances_ignore)
  660. sampling_result = self.sampler.sample(
  661. assign_result=assign_result,
  662. pred_instances=pred_instances,
  663. gt_instances=gt_instances)
  664. num_valid_anchors = anchors.shape[0]
  665. bbox_targets = torch.zeros_like(anchors)
  666. bbox_weights = torch.zeros_like(anchors)
  667. labels = anchors.new_full((num_valid_anchors, ),
  668. self.num_classes,
  669. dtype=torch.long)
  670. label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
  671. pos_inds = sampling_result.pos_inds
  672. neg_inds = sampling_result.neg_inds
  673. if len(pos_inds) > 0:
  674. pos_bbox_targets = self.bbox_coder.encode(
  675. sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
  676. bbox_targets[pos_inds, :] = pos_bbox_targets
  677. bbox_weights[pos_inds, :] = 1.0
  678. labels[pos_inds] = sampling_result.pos_gt_labels
  679. if self.train_cfg['pos_weight'] <= 0:
  680. label_weights[pos_inds] = 1.0
  681. else:
  682. label_weights[pos_inds] = self.train_cfg['pos_weight']
  683. if len(neg_inds) > 0:
  684. label_weights[neg_inds] = 1.0
  685. # map up to original set of anchors
  686. if unmap_outputs:
  687. num_total_anchors = flat_anchors.size(0)
  688. anchors = unmap(anchors, num_total_anchors, inside_flags)
  689. labels = unmap(
  690. labels, num_total_anchors, inside_flags, fill=self.num_classes)
  691. label_weights = unmap(label_weights, num_total_anchors,
  692. inside_flags)
  693. bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
  694. bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
  695. return (anchors, labels, label_weights, bbox_targets, bbox_weights,
  696. pos_inds, neg_inds, sampling_result)
  697. def get_num_level_anchors_inside(self, num_level_anchors: List[int],
  698. inside_flags: Tensor) -> List[int]:
  699. """Get the anchors of each scale level inside.
  700. Args:
  701. num_level_anchors (list[int]): Number of anchors of each
  702. scale level.
  703. inside_flags (Tensor): Multi level inside flags of the image,
  704. which are concatenated into a single tensor of
  705. shape (num_base_priors,).
  706. Returns:
  707. list[int]: Number of anchors of each scale level inside.
  708. """
  709. split_inside_flags = torch.split(inside_flags, num_level_anchors)
  710. num_level_anchors_inside = [
  711. int(flags.sum()) for flags in split_inside_flags
  712. ]
  713. return num_level_anchors_inside