# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Scale
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import (ConfigType, InstanceList, MultiConfig,
                         OptInstanceList, RangeType, reduce_mean)
from ..utils import multi_apply
from .anchor_free_head import AnchorFreeHead

INF = 1e8


@MODELS.register_module()
class FCOSHead(AnchorFreeHead):
    """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.

    The FCOS head does not use anchor boxes. Instead, bounding boxes are
    predicted at each pixel and a centerness measure is used to suppress
    low-quality predictions.
    Here norm_on_bbox, centerness_on_reg, dcn_on_last_conv are training
    tricks used in the official repo, which will bring remarkable mAP gains
    of up to 4.9. Please see https://github.com/tianzhi0549/FCOS for
    more detail.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        strides (Sequence[int] or Sequence[Tuple[int, int]]): Strides of
            points in multiple feature levels. Defaults to (4, 8, 16, 32, 64).
        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
            level points.
        center_sampling (bool): If true, use center sampling.
            Defaults to False.
        center_sample_radius (float): Radius of center sampling.
            Defaults to 1.5.
        norm_on_bbox (bool): If true, normalize the regression targets with
            FPN strides. Defaults to False.
        centerness_on_reg (bool): If true, position centerness on the
            regress branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
            Defaults to False.
        conv_bias (bool or str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Defaults to "auto".
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_centerness (:obj:`ConfigDict` or dict): Config of centerness
            loss.
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
            config norm layer. Defaults to
            ``norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)``.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.

    Example:
        >>> self = FCOSHead(11, 7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_score, bbox_pred, centerness = self.forward(feats)
        >>> assert len(cls_score) == len(self.scales)
    """  # noqa: E501
    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 regress_ranges: RangeType = ((-1, 64), (64, 128), (128, 256),
                                              (256, 512), (512, INF)),
                 center_sampling: bool = False,
                 center_sample_radius: float = 1.5,
                 norm_on_bbox: bool = False,
                 centerness_on_reg: bool = False,
                 loss_cls: ConfigType = dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(type='IoULoss', loss_weight=1.0),
                 loss_centerness: ConfigType = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 init_cfg: MultiConfig = dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs) -> None:
        self.regress_ranges = regress_ranges
        self.center_sampling = center_sampling
        self.center_sample_radius = center_sample_radius
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_centerness = MODELS.build(loss_centerness)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        super()._init_layers()
        self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def forward(
            self, x: Tuple[Tensor]
    ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of each level outputs.

            - cls_scores (list[Tensor]): Box scores for each scale level, \
              each is a 4D-tensor, the channel number is \
              num_points * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for each \
              scale level, each is a 4D-tensor, the channel number is \
              num_points * 4.
            - centernesses (list[Tensor]): centerness for each scale level, \
              each is a 4D-tensor, the channel number is num_points * 1.
        """
        return multi_apply(self.forward_single, x, self.scales, self.strides)
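
    # Shape sketch for one level: given a feature map of shape (N, C, H, W),
    # forward_single below returns
    #   cls_score:  (N, num_classes, H, W)  (with the default sigmoid-based
    #               classification loss, cls_out_channels == num_classes)
    #   bbox_pred:  (N, 4, H, W)            (left, top, right, bottom offsets)
    #   centerness: (N, 1, H, W)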

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: scores for each class, bbox predictions and centerness
            predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x)
        if self.centerness_on_reg:
            centerness = self.conv_centerness(reg_feat)
        else:
            centerness = self.conv_centerness(cls_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        if self.norm_on_bbox:
            # bbox_pred needed for gradient computation has been modified
            # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
            # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
            bbox_pred = bbox_pred.clamp(min=0)
            if not self.training:
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()
        return cls_score, bbox_pred, centerness
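
    # Note on norm_on_bbox: during training the branch above keeps the raw
    # clamped offsets while get_targets() divides the regression targets by
    # the per-level stride, so predictions and targets live in the same
    # stride-normalized space; at inference the prediction is scaled back by
    # `stride` here. Without norm_on_bbox, exp() keeps the offsets positive
    # in absolute pixel units instead.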

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            centernesses: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
        labels, bbox_targets = self.get_targets(all_level_points,
                                                batch_gt_instances)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=centerness_denorm)
            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)
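
    # Layout note: after flattening, every tensor above is ordered
    # level-major, then image, then spatial position. For each level the
    # permute/reshape yields num_imgs * H_lvl * W_lvl rows (e.g. a
    # (2, C, 100, 136) P3 map contributes 2 * 13600 rows), and
    # `points.repeat(num_imgs, 1)` tiles the per-level grid the same way, so
    # rows of predictions, targets and points stay aligned.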

    def get_targets(
            self, points: List[Tensor], batch_gt_instances: InstanceList
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels (list[Tensor]): Labels of each level.
            - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \
              level.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)

        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # get labels and bbox_targets of each image
        labels_list, bbox_targets_list = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            num_points_per_lvl=num_points)

        # split to per img, per level
        labels_list = [labels.split(num_points, 0) for labels in labels_list]
        bbox_targets_list = [
            bbox_targets.split(num_points, 0)
            for bbox_targets in bbox_targets_list
        ]

        # concat per level image
        concat_lvl_labels = []
        concat_lvl_bbox_targets = []
        for i in range(num_levels):
            concat_lvl_labels.append(
                torch.cat([labels[i] for labels in labels_list]))
            bbox_targets = torch.cat(
                [bbox_targets[i] for bbox_targets in bbox_targets_list])
            if self.norm_on_bbox:
                bbox_targets = bbox_targets / self.strides[i]
            concat_lvl_bbox_targets.append(bbox_targets)
        return concat_lvl_labels, concat_lvl_bbox_targets
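
    # Bookkeeping sketch: _get_targets_single returns one label tensor of
    # length sum(num_points_per_lvl) per image; split() cuts it back into
    # per-level chunks and the loop above re-concatenates each level across
    # images, so concat_lvl_labels[i] has shape (num_imgs * num_points_lvl_i,)
    # and lines up with the flattened predictions in loss_by_feat().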

    def _get_targets_single(
            self, gt_instances: InstanceData, points: Tensor,
            regress_ranges: Tensor,
            num_points_per_lvl: List[int]) -> Tuple[Tensor, Tensor]:
        """Compute regression and classification targets for a single
        image."""
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels

        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4))

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])
        # TODO: figure out why these two are different
        # areas = areas[None].expand(num_points, num_gts)
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        if self.center_sampling:
            # condition1: inside a `center bbox`
            radius = self.center_sample_radius
            center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
            center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
            center_gts = torch.zeros_like(gt_bboxes)
            stride = center_xs.new_zeros(center_xs.shape)

            # project the points on current lvl back to the `original` sizes
            lvl_begin = 0
            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
                lvl_end = lvl_begin + num_points_lvl
                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
                lvl_begin = lvl_end

            x_mins = center_xs - stride
            y_mins = center_ys - stride
            x_maxs = center_xs + stride
            y_maxs = center_ys + stride
            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
                                             x_mins, gt_bboxes[..., 0])
            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
                                             y_mins, gt_bboxes[..., 1])
            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
                                             gt_bboxes[..., 2], x_maxs)
            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
                                             gt_bboxes[..., 3], y_maxs)

            cb_dist_left = xs - center_gts[..., 0]
            cb_dist_right = center_gts[..., 2] - xs
            cb_dist_top = ys - center_gts[..., 1]
            cb_dist_bottom = center_gts[..., 3] - ys
            center_bbox = torch.stack(
                (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom),
                -1)
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        else:
            # condition1: inside a gt bbox
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels[min_area == INF] = self.num_classes  # set as BG
        bbox_targets = bbox_targets[range(num_points), min_area_inds]
        return labels, bbox_targets
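
    # Worked example for condition2 (illustrative numbers): with the default
    # regress_ranges, a point whose largest offset max(l, t, r, b) to a GT
    # box is 200 px can only be assigned on the level whose range is
    # (128, 256); on coarser or finer levels that GT is masked to INF area,
    # and the point falls back to background unless another GT matches.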

    def centerness_target(self, pos_bbox_targets: Tensor) -> Tensor:
        """Compute centerness targets.

        Args:
            pos_bbox_targets (Tensor): BBox targets of positive bboxes in
                shape (num_pos, 4)

        Returns:
            Tensor: Centerness target.
        """
        # only calculate pos centerness targets, otherwise there may be nan
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        if len(left_right) == 0:
            centerness_targets = left_right[..., 0]
        else:
            centerness_targets = (
                left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                    top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness_targets)
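

# Minimal smoke test, a sketch mirroring the class docstring example. It
# assumes mmdet (with its registered FocalLoss/IoULoss/CrossEntropyLoss) is
# importable; the numbers below are illustrative only.
if __name__ == '__main__':
    head = FCOSHead(num_classes=11, in_channels=7)
    feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
    cls_scores, bbox_preds, centernesses = head(feats)
    assert len(cls_scores) == len(head.scales)

    # Worked centerness example: offsets (l, t, r, b) = (10, 20, 30, 40)
    # give sqrt((10 / 30) * (20 / 40)) ~= 0.408, i.e. off-center points are
    # down-weighted in the IoU loss.
    target = head.centerness_target(torch.tensor([[10., 20., 30., 40.]]))
    assert torch.allclose(target, torch.tensor([0.4082]), atol=1e-3)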