atss_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
                         OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import anchor_inside_flags
from ..utils import images_to_levels, multi_apply, unmap
from .anchor_head import AnchorHead


@MODELS.register_module()
class ATSSHead(AnchorHead):
    """Detection Head of `ATSS <https://arxiv.org/abs/1912.02424>`_.

    The ATSS head structure is similar to that of FCOS, but ATSS uses anchor
    boxes and assigns labels by Adaptive Training Sample Selection instead of
    the max-IoU rule.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        pred_kernel_size (int): Kernel size of ``nn.Conv2d`` prediction
            layers. Defaults to 3.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='GN', num_groups=32,
            requires_grad=True)``.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to True. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        loss_centerness (:obj:`ConfigDict` or dict): Config of centerness
            loss. Defaults to ``dict(type='CrossEntropyLoss',
            use_sigmoid=True, loss_weight=1.0)``.
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`]): Initialization config dict.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 pred_kernel_size: int = 3,
                 stacked_convs: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 reg_decoded_bbox: bool = True,
                 loss_centerness: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 init_cfg: MultiConfig = dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='atss_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs) -> None:
        self.pred_kernel_size = pred_kernel_size
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            reg_decoded_bbox=reg_decoded_bbox,
            init_cfg=init_cfg,
            **kwargs)

        self.sampling = False
        self.loss_centerness = MODELS.build(loss_centerness)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
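        # Three prediction branches sit on top of the stacked convs:
        # per-anchor classification scores, 4 box regression values, and a
        # single centerness score. The centerness branch shares the
        # regression features (see ``forward_single``).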
  101. pred_pad_size = self.pred_kernel_size // 2
  102. self.atss_cls = nn.Conv2d(
  103. self.feat_channels,
  104. self.num_anchors * self.cls_out_channels,
  105. self.pred_kernel_size,
  106. padding=pred_pad_size)
  107. self.atss_reg = nn.Conv2d(
  108. self.feat_channels,
  109. self.num_base_priors * 4,
  110. self.pred_kernel_size,
  111. padding=pred_pad_size)
  112. self.atss_centerness = nn.Conv2d(
  113. self.feat_channels,
  114. self.num_base_priors * 1,
  115. self.pred_kernel_size,
  116. padding=pred_pad_size)
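        # One learnable scalar per pyramid level: the regression head is
        # shared across levels, so a per-level ``Scale`` lets the network
        # adapt the magnitude of box predictions to each stride.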
        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores, bbox predictions
            and centerness predictions

                cls_scores (list[Tensor]): Classification scores for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_anchors * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_anchors * 4.
                centernesses (list[Tensor]): Centerness predictions for all
                    scale levels, each is a 4D-tensor, the channels number
                    is num_anchors * 1.
        """
        return multi_apply(self.forward_single, x, self.scales)

    def forward_single(self, x: Tensor, scale: Scale) -> Sequence[Tensor]:
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level,
                    the channels number is num_anchors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, the channels number is num_anchors * 4.
                centerness (Tensor): Centerness for a single scale level
                    with shape (N, num_anchors * 1, H, W).
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.atss_cls(cls_feat)
        # Following the official ATSS implementation, exp() is not applied
        # to bbox_pred here.
        bbox_pred = scale(self.atss_reg(reg_feat)).float()
        centerness = self.atss_centerness(reg_feat)
        return cls_score, bbox_pred, centerness

    def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
                            bbox_pred: Tensor, centerness: Tensor,
                            labels: Tensor, label_weights: Tensor,
                            bbox_targets: Tensor,
                            avg_factor: float) -> tuple:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            centerness (Tensor): Centerness for each scale level with shape
                (N, num_anchors * 1, H, W).
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            avg_factor (float): Average factor that is used to average
                the loss. When using sampling method, avg_factor is usually
                the sum of positive and negative priors. When using
                `PseudoSampler`, `avg_factor` is usually equal to the number
                of positive priors.

        Returns:
            tuple[Tensor]: Classification loss, regression loss, centerness
            loss and the sum of the centerness targets of this level.
        """
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels).contiguous()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
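        # For positive anchors, the regression loss is weighted by the
        # per-anchor centerness target, so boxes near object centers
        # contribute more; the same targets also supervise the centerness
        # branch via (sigmoid) cross-entropy.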
        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_bbox_targets,
                weight=centerness_targets,
                avg_factor=1.0)

            # centerness loss
            loss_centerness = self.loss_centerness(
                pos_centerness, centerness_targets, avg_factor=avg_factor)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            centerness_targets = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            centernesses: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            centernesses (list[Tensor]): Centerness for each scale
                level with shape (N, num_anchors * 1, H, W)
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor) = cls_reg_targets
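        # Average the classification normalizer across GPUs so that every
        # rank divides by the same number of positive samples in
        # distributed training.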
        avg_factor = reduce_mean(
            torch.tensor(avg_factor, dtype=torch.float, device=device)).item()

        losses_cls, losses_bbox, loss_centerness, \
            bbox_avg_factor = multi_apply(
                self.loss_by_feat_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                centernesses,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                avg_factor=avg_factor)
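        # The per-level bbox losses were left unnormalized (avg_factor=1.0
        # in ``loss_by_feat_single``); normalize them here by the total sum
        # of centerness targets, again averaged across GPUs and clamped to
        # avoid division by zero.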
        bbox_avg_factor = sum(bbox_avg_factor)
        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_centerness=loss_centerness)

    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
        """Calculate the centerness between anchors and gts.

        Only the centerness targets of positive anchors are calculated;
        otherwise the result may contain NaN.

        Args:
            anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
            gts (Tensor): Ground truth bboxes with shape (N, 4),
                "xyxy" format.

        Returns:
            Tensor: Centerness between anchors and gts.
        """
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        l_ = anchors_cx - gts[:, 0]
        t_ = anchors_cy - gts[:, 1]
        r_ = gts[:, 2] - anchors_cx
        b_ = gts[:, 3] - anchors_cy

        left_right = torch.stack([l_, r_], dim=1)
        top_bottom = torch.stack([t_, b_], dim=1)
        centerness = torch.sqrt(
            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
        assert not torch.isnan(centerness).any()
        return centerness

    def get_targets(self,
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True) -> tuple:
        """Get targets for ATSS head.

        This method is almost the same as `AnchorHead.get_targets()`. Besides
        returning the targets as the parent method does, it also returns the
        anchors as the first element of the returned tuple.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = torch.cat(anchor_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, pos_inds_list, neg_inds_list,
         sampling_results_list) = multi_apply(
             self._get_targets_single,
             anchor_list,
             valid_flag_list,
             num_level_anchors_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             unmap_outputs=unmap_outputs)
        # Get `avg_factor` of all images, which is calculated in
        # `SamplingResult`. When using a sampling method, `avg_factor` is
        # usually the sum of positive and negative priors. When using
        # `PseudoSampler`, `avg_factor` is usually equal to the number of
        # positive priors.
        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])
        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, bbox_weights_list, avg_factor)

    def _get_targets_single(self,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            num_level_anchors: List[int],
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi-level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            num_level_anchors (List[int]): Number of anchors of each scale
                level.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.
                anchors (Tensor): All anchors in the image with shape (N, 4).
                labels (Tensor): Labels of all anchors in the image with
                    shape (N,).
                label_weights (Tensor): Label weights of all anchors in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, 4).
                pos_inds (Tensor): Indices of positive anchors with shape
                    (num_pos,).
                neg_inds (Tensor): Indices of negative anchors with shape
                    (num_neg,).
                sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg['allowed_border'])
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image size and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')

        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]
        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        pred_instances = InstanceData(priors=anchors)
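        # The ATSS assigner selects top-k candidates per pyramid level by
        # center distance and thresholds their IoU with the mean + std of
        # the candidate IoUs, so it needs the per-level counts of the
        # anchors that survived the inside-image filtering above.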
        assign_result = self.assigner.assign(pred_instances,
                                             num_level_anchors_inside,
                                             gt_instances,
                                             gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if self.reg_decoded_bbox:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            else:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_priors, sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors,
                                 inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors,
                                 inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds, sampling_result)

    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
        """Get the number of valid anchors in every level."""
        split_inside_flags = torch.split(inside_flags, num_level_anchors)
        num_level_anchors_inside = [
            int(flags.sum()) for flags in split_inside_flags
        ]
        return num_level_anchors_inside
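

# Shape sketch for the forward pass (illustrative only; assumes the common
# ATSS setup with a single anchor per location, 5 pyramid levels and 80
# classes):
#   feats: tuple of 5 FPN maps, each (B, in_channels, H_l, W_l)
#   cls_scores[l]:   (B, 1 * 80, H_l, W_l)
#   bbox_preds[l]:   (B, 1 * 4,  H_l, W_l)
#   centernesses[l]: (B, 1 * 1,  H_l, W_l)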