# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Scale
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2distance
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, reduce_mean)

from ..utils import multi_apply
from .anchor_free_head import AnchorFreeHead

INF = 1000000000
RangeType = Sequence[Tuple[int, int]]


def _transpose(tensor_list: List[Tensor],
               num_point_list: list) -> List[Tensor]:
    """This function is used to transpose image first tensors to level first
    ones."""
    for img_idx in range(len(tensor_list)):
        tensor_list[img_idx] = torch.split(
            tensor_list[img_idx], num_point_list, dim=0)

    tensors_level_first = []
    for targets_per_level in zip(*tensor_list):
        tensors_level_first.append(torch.cat(targets_per_level, dim=0))
    return tensors_level_first


@MODELS.register_module()
class CenterNetUpdateHead(AnchorFreeHead):
    """CenterNetUpdateHead is an improved version of CenterNet in CenterNet2.

    Paper link `<https://arxiv.org/abs/2103.07461>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        regress_ranges (Sequence[Tuple[int, int]]): Regress range of multiple
            level points.
        hm_min_radius (int): Heatmap target minimum radius of cls branch.
            Defaults to 4.
        hm_min_overlap (float): Heatmap target minimum overlap of cls branch.
            Defaults to 0.8.
        more_pos_thresh (float): The filtering threshold when the cls branch
            adds more positive samples. Defaults to 0.2.
        more_pos_topk (int): The maximum number of additional positive samples
            added to each gt. Defaults to 9.
        soft_weight_on_reg (bool): Whether to use the soft target of the
            cls branch as the soft weight of the bbox branch.
            Defaults to False.
        loss_cls (:obj:`ConfigDict` or dict): Config of cls loss. Defaults to
            ``dict(type='GaussianFocalLoss', loss_weight=1.0)``.
        loss_bbox (:obj:`ConfigDict` or dict): Config of bbox loss. Defaults
            to ``dict(type='GIoULoss', loss_weight=2.0)``.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to
            construct and config the norm layer. Defaults to
            ``dict(type='GN', num_groups=32, requires_grad=True)``.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            Unused in CenterNet. Reserved for compatibility with
            SingleStageDetector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
            of CenterNet.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 regress_ranges: RangeType = ((0, 80), (64, 160), (128, 320),
                                              (256, 640), (512, INF)),
                 hm_min_radius: int = 4,
                 hm_min_overlap: float = 0.8,
                 more_pos_thresh: float = 0.2,
                 more_pos_topk: int = 9,
                 soft_weight_on_reg: bool = False,
                 loss_cls: ConfigType = dict(
                     type='GaussianFocalLoss',
                     pos_weight=0.25,
                     neg_weight=0.75,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='GIoULoss', loss_weight=2.0),
                 norm_cfg: OptConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            norm_cfg=norm_cfg,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            **kwargs)
        self.soft_weight_on_reg = soft_weight_on_reg
        self.hm_min_radius = hm_min_radius
        self.more_pos_thresh = more_pos_thresh
        self.more_pos_topk = more_pos_topk
        self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap)
        self.sigmoid_clamp = 0.0001

        # GaussianFocalLoss must be sigmoid mode
        self.use_sigmoid_cls = True
        self.cls_out_channels = num_classes

        self.regress_ranges = regress_ranges
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def _init_predictor(self) -> None:
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.num_classes, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of each level outputs.

            - cls_scores (list[Tensor]): Box scores for each scale level, \
              each is a 4D-tensor, the channel number is num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for each \
              scale level, each is a 4D-tensor, the channel number is 4.
        """
        return multi_apply(self.forward_single, x, self.scales, self.strides)

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps.

        Returns:
            tuple: scores for each class, bbox predictions of
            input feature maps.
        """
        cls_score, bbox_pred, _, _ = super().forward_single(x)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        # bbox_pred needed for gradient computation has been modified
        # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
        # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
        bbox_pred = bbox_pred.clamp(min=0)
        if not self.training:
            bbox_pred *= stride
        return cls_score, bbox_pred
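
    # During training the regression targets produced in
    # ``_get_targets_single`` are divided by the stride, so ``bbox_pred``
    # stays in stride units and is only rescaled to image units at test time
    # above. A rough sketch, with made-up numbers:
    #
    #   stride = 8
    #   bbox_pred = torch.tensor([[2.0, 3.0, 4.0, 1.0]])  # (l, t, r, b) / stride
    #   bbox_pred_test = bbox_pred * stride               # image-scale distances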

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is 4.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = cls_scores[0].size(0)
        assert len(cls_scores) == len(bbox_preds)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        # 1 flatten outputs
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)

        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])
        assert (torch.isfinite(flatten_bbox_preds).all().item())

        # 2 calc reg and cls branch targets
        cls_targets, bbox_targets = self.get_targets(all_level_points,
                                                     batch_gt_instances)

        # 3 add more pos index for cls branch
        featmap_sizes = flatten_points.new_tensor(featmap_sizes)
        pos_inds, cls_labels = self.add_cls_pos_inds(flatten_points,
                                                     flatten_bbox_preds,
                                                     featmap_sizes,
                                                     batch_gt_instances)

        # 4 calc cls loss
        if pos_inds is None:
            # num_gts=0
            num_pos_cls = bbox_preds[0].new_tensor(0, dtype=torch.float)
        else:
            num_pos_cls = bbox_preds[0].new_tensor(
                len(pos_inds), dtype=torch.float)
        num_pos_cls = max(reduce_mean(num_pos_cls), 1.0)
        flatten_cls_scores = flatten_cls_scores.sigmoid().clamp(
            min=self.sigmoid_clamp, max=1 - self.sigmoid_clamp)
        cls_loss = self.loss_cls(
            flatten_cls_scores,
            cls_targets,
            pos_inds=pos_inds,
            pos_labels=cls_labels,
            avg_factor=num_pos_cls)

        # 5 calc reg loss
        pos_bbox_inds = torch.nonzero(
            bbox_targets.max(dim=1)[0] >= 0).squeeze(1)
        pos_bbox_preds = flatten_bbox_preds[pos_bbox_inds]
        pos_bbox_targets = bbox_targets[pos_bbox_inds]

        bbox_weight_map = cls_targets.max(dim=1)[0]
        bbox_weight_map = bbox_weight_map[pos_bbox_inds]
        bbox_weight_map = bbox_weight_map if self.soft_weight_on_reg \
            else torch.ones_like(bbox_weight_map)
        num_pos_bbox = max(reduce_mean(bbox_weight_map.sum()), 1.0)

        if len(pos_bbox_inds) > 0:
            pos_points = flatten_points[pos_bbox_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            bbox_loss = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=bbox_weight_map,
                avg_factor=num_pos_bbox)
        else:
            bbox_loss = flatten_bbox_preds.sum() * 0

        return dict(loss_cls=cls_loss, loss_bbox=bbox_loss)
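
    # How the positive bbox locations above are identified, as a toy check
    # (values are illustrative only): rows of ``bbox_targets`` that were
    # filled with -INF in ``_get_targets_single`` are skipped.
    #
    #   t = torch.tensor([[1., 2., 3., 4.], [-INF, -INF, -INF, -INF]])
    #   torch.nonzero(t.max(dim=1)[0] >= 0).squeeze(1)  # -> tensor([0])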

    def get_targets(
        self,
        points: List[Tensor],
        batch_gt_instances: InstanceList,
    ) -> Tuple[Tensor, Tensor]:
        """Compute classification and bbox targets for points in multiple
        images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple: Targets of each level.

            - concat_lvl_labels (Tensor): Labels of all level and batch.
            - concat_lvl_bbox_targets (Tensor): BBox targets of all \
              level and batch.
        """
        assert len(points) == len(self.regress_ranges)

        num_levels = len(points)
        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)
        concat_strides = torch.cat([
            concat_points.new_ones(num_points[i]) * self.strides[i]
            for i in range(num_levels)
        ])

        # get labels and bbox_targets of each image
        cls_targets_list, bbox_targets_list = multi_apply(
            self._get_targets_single,
            batch_gt_instances,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            strides=concat_strides)

        bbox_targets_list = _transpose(bbox_targets_list, num_points)
        cls_targets_list = _transpose(cls_targets_list, num_points)
        concat_lvl_bbox_targets = torch.cat(bbox_targets_list, 0)
        concat_lvl_cls_targets = torch.cat(cls_targets_list, dim=0)
        return concat_lvl_cls_targets, concat_lvl_bbox_targets

    def _get_targets_single(self, gt_instances: InstanceData, points: Tensor,
                            regress_ranges: Tensor,
                            strides: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute classification and bbox targets for a single image."""
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels

        if num_gts == 0:
            return gt_labels.new_full((num_points,
                                       self.num_classes),
                                      self.num_classes), \
                   gt_bboxes.new_full((num_points, 4), -1)

        # Calculate the regression tblr target corresponding to all points
        points = points[:, None].expand(num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        strides = strides[:, None, None].expand(num_points, num_gts, 2)
        bbox_target = bbox2distance(points, gt_bboxes)  # M x N x 4

        # condition1: inside a gt bbox
        inside_gt_bbox_mask = bbox_target.min(dim=2)[0] > 0  # M x N

        # condition2: Calculate the nearest points from
        # the upper, lower, left and right ranges from
        # the center of the gt bbox
        centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
        centers_discret = ((centers / strides).int() * strides).float() + \
            strides / 2

        centers_discret_dist = points - centers_discret
        dist_x = centers_discret_dist[..., 0].abs()
        dist_y = centers_discret_dist[..., 1].abs()
        inside_gt_center3x3_mask = (dist_x <= strides[..., 0]) & \
                                   (dist_y <= strides[..., 0])

        # condition3: limit the regression range for each location
        bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
        crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
        inside_fpn_level_mask = (crit >= regress_ranges[:, [0]]) & \
                                (crit <= regress_ranges[:, [1]])
        bbox_target_mask = inside_gt_bbox_mask & \
            inside_gt_center3x3_mask & \
            inside_fpn_level_mask

        # Calculate the distance weight map
        gt_center_peak_mask = ((centers_discret_dist**2).sum(dim=2) == 0)
        weighted_dist = ((points - centers)**2).sum(dim=2)  # M x N
        weighted_dist[gt_center_peak_mask] = 0

        areas = (gt_bboxes[..., 2] - gt_bboxes[..., 0]) * (
            gt_bboxes[..., 3] - gt_bboxes[..., 1])
        radius = self.delta**2 * 2 * areas
        radius = torch.clamp(radius, min=self.hm_min_radius**2)
        weighted_dist = weighted_dist / radius

        # Calculate bbox_target
        bbox_weighted_dist = weighted_dist.clone()
        bbox_weighted_dist[bbox_target_mask == 0] = INF * 1.0
        min_dist, min_inds = bbox_weighted_dist.min(dim=1)
        bbox_target = bbox_target[range(len(bbox_target)),
                                  min_inds]  # M x N x 4 --> M x 4
        bbox_target[min_dist == INF] = -INF

        # Convert to feature map scale
        bbox_target /= strides[:, 0, :].repeat(1, 2)

        # Calculate cls_target
        cls_target = self._create_heatmaps_from_dist(weighted_dist, gt_labels)

        return cls_target, bbox_target
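
    # Worked example for the heatmap radius (illustrative numbers only):
    # with the default hm_min_overlap=0.8, delta = (1 - 0.8) / (1 + 0.8),
    # i.e. about 0.111, so a 64 x 64 gt box (area 4096) gets a squared radius
    # of delta**2 * 2 * 4096, roughly 101, well above the hm_min_radius**2
    # floor of 16. A point 10 px from the gt center then has
    # weighted_dist = 10**2 / 101, about 0.99, which becomes a heatmap value
    # of exp(-0.99), about 0.37, in ``_create_heatmaps_from_dist``.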

    @torch.no_grad()
    def add_cls_pos_inds(
        self, flatten_points: Tensor, flatten_bbox_preds: Tensor,
        featmap_sizes: Tensor, batch_gt_instances: InstanceList
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Provide additional adaptive positive samples to the classification
        branch.

        Args:
            flatten_points (Tensor): The point after flatten, including
                batch image and all levels. The shape is (N, 2).
            flatten_bbox_preds (Tensor): The bbox predicts after flatten,
                including batch image and all levels. The shape is (N, 4).
            featmap_sizes (Tensor): Feature map size of all layers.
                The shape is (5, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple:

            - pos_inds (Tensor): Adaptively selected positive sample index.
            - cls_labels (Tensor): Corresponding positive class label.
        """
        outputs = self._get_center3x3_region_index_targets(
            batch_gt_instances, featmap_sizes)
        cls_labels, fpn_level_masks, center3x3_inds, \
            center3x3_bbox_targets, center3x3_masks = outputs

        num_gts, total_level, K = cls_labels.shape[0], len(
            self.strides), center3x3_masks.shape[-1]
        if num_gts == 0:
            return None, None

        # The out-of-bounds index is forcibly set to 0
        # to prevent loss calculation errors
        center3x3_inds[center3x3_masks == 0] = 0
        reg_pred_center3x3 = flatten_bbox_preds[center3x3_inds]
        center3x3_points = flatten_points[center3x3_inds].view(-1, 2)

        center3x3_bbox_targets_expand = center3x3_bbox_targets.view(
            -1, 4).clamp(min=0)
        pos_decoded_bbox_preds = self.bbox_coder.decode(
            center3x3_points, reg_pred_center3x3.view(-1, 4))
        pos_decoded_target_preds = self.bbox_coder.decode(
            center3x3_points, center3x3_bbox_targets_expand)
        center3x3_bbox_loss = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            None,
            reduction_override='none').view(num_gts, total_level,
                                            K) / self.loss_bbox.loss_weight

        # The loss at invalid indices is set to infinity
        center3x3_bbox_loss[center3x3_masks == 0] = INF

        # Index 4 is the center of the sampled 3x3 region, i.e. the
        # discretized center of the gt bbox. The discretized gt center must
        # be a positive sample, so its loss is forced to 0.
        center3x3_bbox_loss.view(-1, K)[fpn_level_masks.view(-1), 4] = 0
        center3x3_bbox_loss = center3x3_bbox_loss.view(num_gts, -1)

        loss_thr = torch.kthvalue(
            center3x3_bbox_loss, self.more_pos_topk, dim=1)[0]

        loss_thr[loss_thr > self.more_pos_thresh] = self.more_pos_thresh
        new_pos = center3x3_bbox_loss < loss_thr.view(num_gts, 1)
        pos_inds = center3x3_inds.view(num_gts, -1)[new_pos]
        cls_labels = cls_labels.view(num_gts,
                                     1).expand(num_gts,
                                               total_level * K)[new_pos]
        return pos_inds, cls_labels
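
    # Toy illustration of the adaptive threshold (illustrative numbers only):
    # with the default five strides each gt has 5 levels x 9 points = 45
    # candidates. ``torch.kthvalue(..., more_pos_topk=9)`` picks the 9th
    # smallest candidate loss per gt, that value is capped at
    # ``more_pos_thresh=0.2``, and every candidate whose loss is strictly
    # below the capped threshold becomes an extra cls positive sample.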

    def _create_heatmaps_from_dist(self, weighted_dist: Tensor,
                                   cls_labels: Tensor) -> Tensor:
        """Generate heatmaps of classification branch based on weighted
        distance map."""
        heatmaps = weighted_dist.new_zeros(
            (weighted_dist.shape[0], self.num_classes))
        for c in range(self.num_classes):
            inds = (cls_labels == c)  # N
            if inds.int().sum() == 0:
                continue
            heatmaps[:, c] = torch.exp(-weighted_dist[:, inds].min(dim=1)[0])
            zeros = heatmaps[:, c] < 1e-4
            heatmaps[zeros, c] = 0
        return heatmaps

    def _get_center3x3_region_index_targets(self,
                                            batch_gt_instances: InstanceList,
                                            shapes_per_level: Tensor) -> tuple:
        """Get the center (and the 3x3 region near center) locations and
        targets of each object."""
        cls_labels = []
        inside_fpn_level_masks = []
        center3x3_inds = []
        center3x3_masks = []
        center3x3_bbox_targets = []

        total_levels = len(self.strides)
        batch = len(batch_gt_instances)

        shapes_per_level = shapes_per_level.long()
        area_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1])

        # Select a total of 9 positions of 3x3 in the center of the gt bbox
        # as candidate positive samples
        K = 9
        dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0,
                                          1]).view(1, 1, K)
        dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1,
                                          1]).view(1, 1, K)

        regress_ranges = shapes_per_level.new_tensor(self.regress_ranges).view(
            len(self.regress_ranges), 2)  # L x 2
        strides = shapes_per_level.new_tensor(self.strides)

        start_coord_pre_level = []
        _start = 0
        for level in range(total_levels):
            start_coord_pre_level.append(_start)
            _start = _start + batch * area_per_level[level]
        start_coord_pre_level = shapes_per_level.new_tensor(
            start_coord_pre_level).view(1, total_levels, 1)
        area_per_level = area_per_level.view(1, total_levels, 1)

        for im_i in range(batch):
            gt_instance = batch_gt_instances[im_i]
            gt_bboxes = gt_instance.bboxes
            gt_labels = gt_instance.labels
            num_gts = gt_bboxes.shape[0]
            if num_gts == 0:
                continue

            cls_labels.append(gt_labels)

            gt_bboxes = gt_bboxes[:, None].expand(num_gts, total_levels, 4)
            expanded_strides = strides[None, :,
                                       None].expand(num_gts, total_levels, 2)
            expanded_regress_ranges = regress_ranges[None].expand(
                num_gts, total_levels, 2)
            expanded_shapes_per_level = shapes_per_level[None].expand(
                num_gts, total_levels, 2)

            # calc reg_target
            centers = ((gt_bboxes[..., [0, 1]] + gt_bboxes[..., [2, 3]]) / 2)
            centers_inds = (centers / expanded_strides).long()
            centers_discret = centers_inds * expanded_strides \
                + expanded_strides // 2
            bbox_target = bbox2distance(centers_discret,
                                        gt_bboxes)  # M x N x 4

            # calc inside_fpn_level_mask
            bbox_target_wh = bbox_target[..., :2] + bbox_target[..., 2:]
            crit = (bbox_target_wh**2).sum(dim=2)**0.5 / 2
            inside_fpn_level_mask = \
                (crit >= expanded_regress_ranges[..., 0]) & \
                (crit <= expanded_regress_ranges[..., 1])
            inside_gt_bbox_mask = bbox_target.min(dim=2)[0] >= 0
            inside_fpn_level_mask = inside_gt_bbox_mask & inside_fpn_level_mask
            inside_fpn_level_masks.append(inside_fpn_level_mask)

            # calc center3x3_ind and mask
            expand_ws = expanded_shapes_per_level[..., 1:2].expand(
                num_gts, total_levels, K)
            expand_hs = expanded_shapes_per_level[..., 0:1].expand(
                num_gts, total_levels, K)
            centers_inds_x = centers_inds[..., 0:1]
            centers_inds_y = centers_inds[..., 1:2]

            center3x3_idx = start_coord_pre_level + \
                im_i * area_per_level + \
                (centers_inds_y + dy) * expand_ws + \
                (centers_inds_x + dx)
            center3x3_mask = \
                ((centers_inds_y + dy) < expand_hs) & \
                ((centers_inds_y + dy) >= 0) & \
                ((centers_inds_x + dx) < expand_ws) & \
                ((centers_inds_x + dx) >= 0)

            # recalc center3x3 region reg target
            bbox_target = bbox_target / expanded_strides.repeat(1, 1, 2)
            center3x3_bbox_target = bbox_target[..., None, :].expand(
                num_gts, total_levels, K, 4).clone()
            center3x3_bbox_target[..., 0] += dx
            center3x3_bbox_target[..., 1] += dy
            center3x3_bbox_target[..., 2] -= dx
            center3x3_bbox_target[..., 3] -= dy
            # update center3x3_mask
            center3x3_mask = center3x3_mask & (
                center3x3_bbox_target.min(dim=3)[0] >= 0)  # n x L x K

            center3x3_inds.append(center3x3_idx)
            center3x3_masks.append(center3x3_mask)
            center3x3_bbox_targets.append(center3x3_bbox_target)

        if len(inside_fpn_level_masks) > 0:
            cls_labels = torch.cat(cls_labels, dim=0)
            inside_fpn_level_masks = torch.cat(inside_fpn_level_masks, dim=0)
            center3x3_inds = torch.cat(center3x3_inds, dim=0).long()
            center3x3_bbox_targets = torch.cat(center3x3_bbox_targets, dim=0)
            center3x3_masks = torch.cat(center3x3_masks, dim=0)
        else:
            cls_labels = shapes_per_level.new_zeros(0).long()
            inside_fpn_level_masks = shapes_per_level.new_zeros(
                (0, total_levels)).bool()
            center3x3_inds = shapes_per_level.new_zeros(
                (0, total_levels, K)).long()
            center3x3_bbox_targets = shapes_per_level.new_zeros(
                (0, total_levels, K, 4)).float()
            center3x3_masks = shapes_per_level.new_zeros(
                (0, total_levels, K)).bool()
        return cls_labels, inside_fpn_level_masks, center3x3_inds, \
            center3x3_bbox_targets, center3x3_masks
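

# A minimal usage sketch, assuming the default AnchorFreeHead settings
# (feat_channels=256, stacked_convs=4) and five FPN strides; shapes and
# values are illustrative only:
#
#   head = CenterNetUpdateHead(
#       num_classes=80, in_channels=256, strides=(8, 16, 32, 64, 128))
#   feats = [torch.rand(2, 256, 128 // s, 160 // s)
#            for s in (8, 16, 32, 64, 128)]
#   cls_scores, bbox_preds = head(feats)  # 5 score maps and 5 bbox maps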