# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmengine.model import bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import levels_to_images, multi_apply
from .fcos_head import FCOSHead

EPS = 1e-12


class CenterPrior(nn.Module):
    """Center Weighting module to adjust the category-specific prior
    distributions.

    Args:
        force_topk (bool): When no point falls into gt_bbox, forcibly
            select the k points closest to the center to calculate
            the center prior. Defaults to False.
        topk (int): The number of points used to calculate the
            center prior when no point falls in gt_bbox. Only works
            when ``force_topk`` is True. Defaults to 9.
        num_classes (int): The number of classes in the dataset.
            Defaults to 80.
        strides (Sequence[int]): The stride of each input feature map.
            Defaults to (8, 16, 32, 64, 128).
    """

    def __init__(
        self,
        force_topk: bool = False,
        topk: int = 9,
        num_classes: int = 80,
        strides: Sequence[int] = (8, 16, 32, 64, 128)
    ) -> None:
        super().__init__()
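        # Learnable per-class Gaussian parameters of the center prior:
        # ``mean`` is the offset of a category's typical center and
        # ``sigma`` its spread, one value per (x, y) axis, so both have
        # shape (num_classes, 2).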
        self.mean = nn.Parameter(torch.zeros(num_classes, 2))
        self.sigma = nn.Parameter(torch.ones(num_classes, 2))
        self.strides = strides
        self.force_topk = force_topk
        self.topk = topk

    def forward(self, anchor_points_list: List[Tensor],
                gt_instances: InstanceData,
                inside_gt_bbox_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Get the center prior of each point on the feature map for each
        instance.

        Args:
            anchor_points_list (list[Tensor]): list of coordinates of
                points on the feature map. Each with shape
                (num_points, 2).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            inside_gt_bbox_mask (Tensor): Tensor of bool type,
                with shape of (num_points, num_gt), each
                value is used to mark whether this point falls
                within a certain gt.

        Returns:
            tuple[Tensor, Tensor]:

            - center_prior_weights (Tensor): Float tensor with shape of \
            (num_points, num_gt). Each value represents the center \
            weighting coefficient.
            - inside_gt_bbox_mask (Tensor): Tensor of bool type, with shape \
            of (num_points, num_gt), each value is used to mark whether this \
            point falls within a certain gt or is among the topk nearest \
            points for a specific gt_bbox.
        """
        gt_bboxes = gt_instances.bboxes
        labels = gt_instances.labels
        inside_gt_bbox_mask = inside_gt_bbox_mask.clone()
        num_gts = len(labels)
        num_points = sum([len(item) for item in anchor_points_list])
        if num_gts == 0:
            return gt_bboxes.new_zeros(num_points,
                                       num_gts), inside_gt_bbox_mask
        center_prior_list = []
        for slvl_points, stride in zip(anchor_points_list, self.strides):
            # slvl_points: points from single level in FPN, has shape (h*w, 2)
            # single_level_points has shape (h*w, num_gt, 2)
            single_level_points = slvl_points[:, None, :].expand(
                (slvl_points.size(0), len(gt_bboxes), 2))
            gt_center_x = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2)
            gt_center_y = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2)
            gt_center = torch.stack((gt_center_x, gt_center_y), dim=1)
            gt_center = gt_center[None]
            # instance_center has shape (1, num_gt, 2)
            instance_center = self.mean[labels][None]
            # instance_sigma has shape (1, num_gt, 2)
            instance_sigma = self.sigma[labels][None]
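            # The center prior of a point p for a gt centered at c on a
            # level of stride s is a per-axis Gaussian evaluated at the
            # stride-normalized offset:
            #   G(p) = prod_axis exp(-((p - c) / s - mean)^2 /
            #                        (2 * sigma^2))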
            # distance has shape (num_points, num_gt, 2)
            distance = (((single_level_points - gt_center) / float(stride) -
                         instance_center)**2)
            center_prior = torch.exp(-distance /
                                     (2 * instance_sigma**2)).prod(dim=-1)
            center_prior_list.append(center_prior)
        center_prior_weights = torch.cat(center_prior_list, dim=0)
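
        # If some gt boxes contain no points at all (e.g. very small
        # boxes), optionally mark their topk highest-prior points as
        # inside so they still receive positive supervision.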
        if self.force_topk:
            gt_inds_no_points_inside = torch.nonzero(
                inside_gt_bbox_mask.sum(0) == 0).reshape(-1)
            if gt_inds_no_points_inside.numel():
                topk_center_index = \
                    center_prior_weights[:, gt_inds_no_points_inside].topk(
                        self.topk,
                        dim=0)[1]
                temp_mask = inside_gt_bbox_mask[:, gt_inds_no_points_inside]
                inside_gt_bbox_mask[:, gt_inds_no_points_inside] = \
                    torch.scatter(temp_mask,
                                  dim=0,
                                  index=topk_center_index,
                                  src=torch.ones_like(
                                      topk_center_index,
                                      dtype=torch.bool))

        center_prior_weights[~inside_gt_bbox_mask] = 0
        return center_prior_weights, inside_gt_bbox_mask


@MODELS.register_module()
class AutoAssignHead(FCOSHead):
    """AutoAssignHead used in AutoAssign.

    More details can be found in the `paper
    <https://arxiv.org/abs/2007.03496>`_ .

    Args:
        force_topk (bool): Used in center prior initialization to
            handle extremely small gt. Defaults to False.
        topk (int): The number of points used to calculate the
            center prior when no point falls in gt_bbox. Only works
            when ``force_topk`` is True. Defaults to 9.
        pos_loss_weight (float): The loss weight of the positive loss.
            Defaults to 0.25.
        neg_loss_weight (float): The loss weight of the negative loss.
            Defaults to 0.75.
        center_loss_weight (float): The loss weight of the center prior
            loss. Defaults to 0.75.
    """

    def __init__(self,
                 *args,
                 force_topk: bool = False,
                 topk: int = 9,
                 pos_loss_weight: float = 0.25,
                 neg_loss_weight: float = 0.75,
                 center_loss_weight: float = 0.75,
                 **kwargs) -> None:
        super().__init__(*args, conv_bias=True, **kwargs)
        self.center_prior = CenterPrior(
            force_topk=force_topk,
            topk=topk,
            num_classes=self.num_classes,
            strides=self.strides)
        self.pos_loss_weight = pos_loss_weight
        self.neg_loss_weight = neg_loss_weight
        self.center_loss_weight = center_loss_weight
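        # AutoAssign places prior points at the grid corners (offset=0)
        # instead of the half-stride offset used by the FCOS base class.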
        self.prior_generator = MlvlPointGenerator(self.strides, offset=0)

    def init_weights(self) -> None:
        """Initialize weights of the head.

        In particular, we use a special initialization for the
        classification conv's and regression conv's biases.
        """
        super(AutoAssignHead, self).init_weights()
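        # Start classification from a low foreground prior
        # (sigmoid(bias_cls) ~= 0.02) and bias the regression branch
        # positively so the initial predicted boxes have a usable size.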
        bias_cls = bias_init_with_prob(0.02)
        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
        normal_init(self.conv_reg, std=0.01, bias=4.0)

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple[Tensor, Tensor, Tensor]: scores for each class, bbox
            predictions and centerness predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = super(
            FCOSHead, self).forward_single(x)
        centerness = self.conv_centerness(reg_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        # bbox_pred needed for gradient computation has been modified
        # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
        # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
        bbox_pred = bbox_pred.clamp(min=0)
        bbox_pred *= stride
        return cls_score, bbox_pred, centerness

    def get_pos_loss_single(self, cls_score: Tensor, objectness: Tensor,
                            reg_loss: Tensor, gt_instances: InstanceData,
                            center_prior_weights: Tensor) -> Tuple[Tensor]:
        """Calculate the positive loss of all points in gt_bboxes.

        Args:
            cls_score (Tensor): All category scores for each point on
                the feature map. The shape is (num_points, num_class).
            objectness (Tensor): Foreground probability of all points,
                has shape (num_points, 1).
            reg_loss (Tensor): The regression loss of each gt_bbox and each
                prediction box, has shape of (num_points, num_gt).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            center_prior_weights (Tensor): Float tensor with shape
                of (num_points, num_gt). Each value represents
                the center weighting coefficient.

        Returns:
            tuple[Tensor]:

            - pos_loss (Tensor): The positive loss of all points in the \
            gt_bboxes.
        """
        gt_labels = gt_instances.labels
        # p_loc: localization confidence
        p_loc = torch.exp(-reg_loss)
        # p_cls: classification confidence
        p_cls = (cls_score * objectness)[:, gt_labels]
        # p_pos: joint confidence indicator
        p_pos = p_cls * p_loc

        # 3 is a hyper-parameter to control the contributions of high and
        # low confidence locations towards positive losses.
        confidence_weight = torch.exp(p_pos * 3)
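        # Normalize the weights over the candidate points of each gt
        # (each column sums to 1), so confident, well-centered locations
        # dominate the aggregated positive confidence.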
        p_pos_weight = (confidence_weight * center_prior_weights) / (
            (confidence_weight * center_prior_weights).sum(
                0, keepdim=True)).clamp(min=EPS)
        reweighted_p_pos = (p_pos * p_pos_weight).sum(0)
        pos_loss = F.binary_cross_entropy(
            reweighted_p_pos,
            torch.ones_like(reweighted_p_pos),
            reduction='none')
        pos_loss = pos_loss.sum() * self.pos_loss_weight
        return pos_loss,

    def get_neg_loss_single(self, cls_score: Tensor, objectness: Tensor,
                            gt_instances: InstanceData, ious: Tensor,
                            inside_gt_bbox_mask: Tensor) -> Tuple[Tensor]:
        """Calculate the negative loss of all points in the feature map.

        Args:
            cls_score (Tensor): All category scores for each point on
                the feature map. The shape is (num_points, num_class).
            objectness (Tensor): Foreground probability of all points,
                has shape (num_points, 1).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            ious (Tensor): Float tensor with shape of (num_points, num_gt).
                Each value represents the iou of the pred_bbox and the
                gt_bboxes.
            inside_gt_bbox_mask (Tensor): Tensor of bool type,
                with shape of (num_points, num_gt), each
                value is used to mark whether this point falls
                within a certain gt.

        Returns:
            tuple[Tensor]:

            - neg_loss (Tensor): The negative loss of all points in the \
            feature map.
        """
        gt_labels = gt_instances.labels
        num_gts = len(gt_labels)
        joint_conf = (cls_score * objectness)
        p_neg_weight = torch.ones_like(joint_conf)
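        # Down-weight the negative loss of points inside gt boxes in
        # proportion to the (normalized) IoU of their predicted boxes
        # with the gt: well-localized points are pushed towards the
        # background less strongly.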
        if num_gts > 0:
            # the order of dimensions would affect the value of
            # p_neg_weight; we strictly follow the original
            # implementation.
            inside_gt_bbox_mask = inside_gt_bbox_mask.permute(1, 0)
            ious = ious.permute(1, 0)

            foreground_idxs = torch.nonzero(
                inside_gt_bbox_mask, as_tuple=True)
            temp_weight = (1 / (1 - ious[foreground_idxs]).clamp_(EPS))

            def normalize(x):
                return (x - x.min() + EPS) / (x.max() - x.min() + EPS)

            for instance_idx in range(num_gts):
                idxs = foreground_idxs[0] == instance_idx
                if idxs.any():
                    temp_weight[idxs] = normalize(temp_weight[idxs])

            p_neg_weight[foreground_idxs[1],
                         gt_labels[foreground_idxs[0]]] = 1 - temp_weight

        logits = (joint_conf * p_neg_weight)
        neg_loss = (
            logits**2 * F.binary_cross_entropy(
                logits, torch.zeros_like(logits), reduction='none'))
        neg_loss = neg_loss.sum() * self.neg_loss_weight
        return neg_loss,

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        objectnesses: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            objectnesses (list[Tensor]): objectness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
        all_num_gt = sum([len(item) for item in batch_gt_instances])
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
        inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets(
            all_level_points, batch_gt_instances)
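
        # Compute per-image center prior weights; the inside masks are
        # updated as well, since force_topk may mark extra points for
        # gts that originally contained none.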
        center_prior_weight_list = []
        temp_inside_gt_bbox_mask_list = []
        for gt_instances, inside_gt_bbox_mask in zip(batch_gt_instances,
                                                     inside_gt_bbox_mask_list):
            center_prior_weight, inside_gt_bbox_mask = \
                self.center_prior(all_level_points, gt_instances,
                                  inside_gt_bbox_mask)
            center_prior_weight_list.append(center_prior_weight)
            temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask)
        inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list
        mlvl_points = torch.cat(all_level_points, dim=0)
        bbox_preds = levels_to_images(bbox_preds)
        cls_scores = levels_to_images(cls_scores)
        objectnesses = levels_to_images(objectnesses)
        reg_loss_list = []
        ious_list = []
        num_points = len(mlvl_points)
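
        # For each image, decode predictions and targets into absolute
        # boxes, then compute per-(point, gt) IoUs and regression losses.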
        for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip(
                bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list):
            temp_num_gt = encoded_targets.size(1)
            expand_mlvl_points = mlvl_points[:, None, :].expand(
                num_points, temp_num_gt, 2).reshape(-1, 2)
            encoded_targets = encoded_targets.reshape(-1, 4)
            expand_bbox_pred = bbox_pred[:, None, :].expand(
                num_points, temp_num_gt, 4).reshape(-1, 4)
            decoded_bbox_preds = self.bbox_coder.decode(
                expand_mlvl_points, expand_bbox_pred)
            decoded_target_preds = self.bbox_coder.decode(
                expand_mlvl_points, encoded_targets)
            with torch.no_grad():
                ious = bbox_overlaps(
                    decoded_bbox_preds, decoded_target_preds, is_aligned=True)
                ious = ious.reshape(num_points, temp_num_gt)
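                # Each point keeps its best IoU over all gts; the shared
                # value later weights the negative loss of that point.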
                if temp_num_gt:
                    ious = ious.max(
                        dim=-1, keepdim=True).values.repeat(1, temp_num_gt)
                else:
                    ious = ious.new_zeros(num_points, temp_num_gt)
                ious[~inside_gt_bbox_mask] = 0
                ious_list.append(ious)
            loss_bbox = self.loss_bbox(
                decoded_bbox_preds,
                decoded_target_preds,
                weight=None,
                reduction_override='none')
            reg_loss_list.append(loss_bbox.reshape(num_points, temp_num_gt))

        cls_scores = [item.sigmoid() for item in cls_scores]
        objectnesses = [item.sigmoid() for item in objectnesses]
        pos_loss_list, = multi_apply(self.get_pos_loss_single, cls_scores,
                                     objectnesses, reg_loss_list,
                                     batch_gt_instances,
                                     center_prior_weight_list)
        pos_avg_factor = reduce_mean(
            bbox_pred.new_tensor(all_num_gt)).clamp_(min=1)
        pos_loss = sum(pos_loss_list) / pos_avg_factor
        neg_loss_list, = multi_apply(self.get_neg_loss_single, cls_scores,
                                     objectnesses, batch_gt_instances,
                                     ious_list, inside_gt_bbox_mask_list)
        neg_avg_factor = sum(item.data.sum()
                             for item in center_prior_weight_list)
        neg_avg_factor = reduce_mean(neg_avg_factor).clamp_(min=1)
        neg_loss = sum(neg_loss_list) / neg_avg_factor
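
        # Center prior loss: num_gt / sum(center prior weights) per
        # image, so minimizing it encourages the learnable priors to
        # concentrate weight on points inside the gt boxes.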
        center_loss = []
        for i in range(len(batch_img_metas)):
            if inside_gt_bbox_mask_list[i].any():
                center_loss.append(
                    len(batch_gt_instances[i]) /
                    center_prior_weight_list[i].sum().clamp_(min=EPS))
            else:
                # no point falls inside any gt, e.g. when the width or
                # height of a gt_bbox is smaller than the stride of p3
                center_loss.append(center_prior_weight_list[i].sum() * 0)
        center_loss = torch.stack(center_loss).mean() * self.center_loss_weight

        # avoid a deadlock in DDP: every parameter must contribute to
        # the loss even when the batch contains no gt at all
        if all_num_gt == 0:
            pos_loss = bbox_preds[0].sum() * 0
            dummy_center_prior_loss = self.center_prior.mean.sum(
            ) * 0 + self.center_prior.sigma.sum() * 0
            center_loss = objectnesses[0].sum() * 0 + dummy_center_prior_loss

        loss = dict(
            loss_pos=pos_loss, loss_neg=neg_loss, loss_center=center_loss)
        return loss

    def get_targets(
        self, points: List[Tensor], batch_gt_instances: InstanceList
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Compute regression targets and mark whether each point is inside
        or outside a gt_bbox for multiple images.

        Args:
            points (list[Tensor]): Points of all fpn levels, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple(list[Tensor], list[Tensor]):

            - inside_gt_bbox_mask_list (list[Tensor]): Each Tensor is with \
            bool type and shape of (num_points, num_gt), each value is used \
            to mark whether this point falls within a certain gt.
            - bbox_targets_list (list[Tensor]): BBox targets of each \
            image. Each tensor has shape (num_points, num_gt, 4).
        """
        # concatenate the points of all levels; targets are then
        # computed per image over the full set of points
        concat_points = torch.cat(points, dim=0)
        inside_gt_bbox_mask_list, bbox_targets_list = multi_apply(
            self._get_targets_single, batch_gt_instances, points=concat_points)
        return inside_gt_bbox_mask_list, bbox_targets_list

    def _get_targets_single(self, gt_instances: InstanceData,
                            points: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute regression targets and mark whether each point is inside
        or outside a gt_bbox for a single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            points (Tensor): Points of all fpn levels, has shape
                (num_points, 2).

        Returns:
            tuple[Tensor, Tensor]: Containing the following Tensors:

            - inside_gt_bbox_mask (Tensor): Bool tensor with shape \
            (num_points, num_gt), each value is used to mark whether this \
            point falls within a certain gt.
            - bbox_targets (Tensor): BBox targets of each point with each \
            gt_bboxes, has shape (num_points, num_gt, 4).
        """
        gt_bboxes = gt_instances.bboxes
        num_points = points.size(0)
        num_gts = gt_bboxes.size(0)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None]
        ys = ys[:, None]
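        # Signed distances from each point to the four edges of every
        # gt box (left, top, right, bottom); all four are positive iff
        # the point lies strictly inside the box.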
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)
        if num_gts:
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
        else:
            inside_gt_bbox_mask = bbox_targets.new_zeros((num_points, num_gts),
                                                         dtype=torch.bool)
        return inside_gt_bbox_mask, bbox_targets