corner_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from logging import warning
from math import ceil, log
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import CornerPool, batched_nms
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, OptMultiConfig)
from ..utils import (gather_feat, gaussian_radius, gen_gaussian_target,
                     get_local_maximum, get_topk_from_heatmap, multi_apply,
                     transpose_and_gather_feat)
from .base_dense_head import BaseDenseHead


class BiCornerPool(BaseModule):
    """Bidirectional Corner Pooling Module (TopLeft, BottomRight, etc.)

    Args:
        in_channels (int): Input channels of module.
        directions (list[str]): Directions of two CornerPools.
        out_channels (int): Output channels of module.
        feat_channels (int): Feature channels of module.
        norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct
            and config the norm layer.
        init_cfg (:obj:`ConfigDict` or dict, optional): The config to
            control the initialization.
    """

    def __init__(self,
                 in_channels: int,
                 directions: List[str],
                 feat_channels: int = 128,
                 out_channels: int = 128,
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        self.direction1_conv = ConvModule(
            in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
        self.direction2_conv = ConvModule(
            in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)

        self.aftpool_conv = ConvModule(
            feat_channels,
            out_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.conv1 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
        self.conv2 = ConvModule(
            in_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg)

        self.direction1_pool = CornerPool(directions[0])
        self.direction2_pool = CornerPool(directions[1])
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (tensor): Input feature of BiCornerPool.

        Returns:
            conv2 (tensor): Output feature of BiCornerPool.
        """
        direction1_conv = self.direction1_conv(x)
        direction2_conv = self.direction2_conv(x)
        direction1_feat = self.direction1_pool(direction1_conv)
        direction2_feat = self.direction2_pool(direction2_conv)
        aftpool_conv = self.aftpool_conv(direction1_feat + direction2_feat)
        conv1 = self.conv1(x)
        relu = self.relu(aftpool_conv + conv1)
        conv2 = self.conv2(relu)
        return conv2
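
# A minimal usage sketch of BiCornerPool (illustration only, not part of the
# original module). It shows that the block keeps the spatial size and maps
# `in_channels` to `out_channels`; the tensor shapes below are assumptions
# chosen for the example, mirroring how `CornerHead` instantiates the block.
#
#   pool = BiCornerPool(256, ['top', 'left'], out_channels=256)
#   feat = torch.randn(2, 256, 128, 128)
#   out = pool(feat)  # shape: (2, 256, 128, 128)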


@MODELS.register_module()
class CornerHead(BaseDenseHead):
    """Head of CornerNet: Detecting Objects as Paired Keypoints.

    Code is modified from the `official github repo
    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/
    kp.py#L73>`_ .

    More details can be found in the `paper
    <https://arxiv.org/abs/1808.01244>`_ .

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_feat_levels (int): Levels of feature from the previous module.
            2 for HourglassNet-104 and 1 for HourglassNet-52, because
            HourglassNet-104 outputs both the final feature and an
            intermediate supervision feature, while HourglassNet-52 only
            outputs the final feature. Defaults to 2.
        corner_emb_channels (int): Channel of embedding vector. Defaults to 1.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            Unused in CornerHead, but kept for compatibility with
            SingleStageDetector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            CornerHead.
        loss_heatmap (:obj:`ConfigDict` or dict): Config of corner heatmap
            loss. Defaults to GaussianFocalLoss.
        loss_embedding (:obj:`ConfigDict` or dict): Config of corner embedding
            loss. Defaults to AssociativeEmbeddingLoss.
        loss_offset (:obj:`ConfigDict` or dict): Config of corner offset loss.
            Defaults to SmoothL1Loss.
        init_cfg (:obj:`ConfigDict` or dict, optional): The config to control
            the initialization.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_feat_levels: int = 2,
                 corner_emb_channels: int = 1,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 loss_heatmap: ConfigType = dict(
                     type='GaussianFocalLoss',
                     alpha=2.0,
                     gamma=4.0,
                     loss_weight=1),
                 loss_embedding: ConfigType = dict(
                     type='AssociativeEmbeddingLoss',
                     pull_weight=0.25,
                     push_weight=0.25),
                 loss_offset: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1),
                 init_cfg: OptMultiConfig = None) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.corner_emb_channels = corner_emb_channels
        self.with_corner_emb = self.corner_emb_channels > 0
        self.corner_offset_channels = 2
        self.num_feat_levels = num_feat_levels
        self.loss_heatmap = MODELS.build(
            loss_heatmap) if loss_heatmap is not None else None
        self.loss_embedding = MODELS.build(
            loss_embedding) if loss_embedding is not None else None
        self.loss_offset = MODELS.build(
            loss_offset) if loss_offset is not None else None
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self._init_layers()
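
    # Illustrative sketch of how this head is typically declared in a
    # detector config and built through the registry. The loss settings shown
    # are just the documented defaults above; `num_classes=80` and
    # `in_channels=256` are assumptions for the example, not values taken
    # from an official CornerNet config.
    #
    #   bbox_head = dict(
    #       type='CornerHead',
    #       num_classes=80,
    #       in_channels=256,
    #       num_feat_levels=2,
    #       corner_emb_channels=1,
    #       loss_heatmap=dict(
    #           type='GaussianFocalLoss', alpha=2.0, gamma=4.0, loss_weight=1),
    #       loss_embedding=dict(
    #           type='AssociativeEmbeddingLoss',
    #           pull_weight=0.25,
    #           push_weight=0.25),
    #       loss_offset=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1))
    #   head = MODELS.build(bbox_head)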

    def _make_layers(self,
                     out_channels: int,
                     in_channels: int = 256,
                     feat_channels: int = 256) -> nn.Sequential:
        """Initialize conv sequential for CornerHead."""
        return nn.Sequential(
            ConvModule(in_channels, feat_channels, 3, padding=1),
            ConvModule(
                feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None))

    def _init_corner_kpt_layers(self) -> None:
        """Initialize corner keypoint layers.

        Including corner heatmap branch and corner offset branch. Each branch
        has two parts: prefix `tl_` for top-left and `br_` for bottom-right.
        """
        self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList()
        self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList()
        self.tl_off, self.br_off = nn.ModuleList(), nn.ModuleList()

        for _ in range(self.num_feat_levels):
            self.tl_pool.append(
                BiCornerPool(
                    self.in_channels, ['top', 'left'],
                    out_channels=self.in_channels))
            self.br_pool.append(
                BiCornerPool(
                    self.in_channels, ['bottom', 'right'],
                    out_channels=self.in_channels))

            self.tl_heat.append(
                self._make_layers(
                    out_channels=self.num_classes,
                    in_channels=self.in_channels))
            self.br_heat.append(
                self._make_layers(
                    out_channels=self.num_classes,
                    in_channels=self.in_channels))

            self.tl_off.append(
                self._make_layers(
                    out_channels=self.corner_offset_channels,
                    in_channels=self.in_channels))
            self.br_off.append(
                self._make_layers(
                    out_channels=self.corner_offset_channels,
                    in_channels=self.in_channels))

    def _init_corner_emb_layers(self) -> None:
        """Initialize corner embedding layers.

        Only include corner embedding branch with two parts: prefix `tl_` for
        top-left and `br_` for bottom-right.
        """
        self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList()

        for _ in range(self.num_feat_levels):
            self.tl_emb.append(
                self._make_layers(
                    out_channels=self.corner_emb_channels,
                    in_channels=self.in_channels))
            self.br_emb.append(
                self._make_layers(
                    out_channels=self.corner_emb_channels,
                    in_channels=self.in_channels))

    def _init_layers(self) -> None:
        """Initialize layers for CornerHead.

        Including two parts: corner keypoint layers and corner embedding
        layers.
        """
        self._init_corner_kpt_layers()
        if self.with_corner_emb:
            self._init_corner_emb_layers()

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        super().init_weights()
        bias_init = bias_init_with_prob(0.1)
        for i in range(self.num_feat_levels):
            # The initialization of parameters is different between
            # nn.Conv2d and ConvModule. Our experiments show that
            # using the original initialization of nn.Conv2d increases
            # the final mAP by about 0.2%.
            self.tl_heat[i][-1].conv.reset_parameters()
            self.tl_heat[i][-1].conv.bias.data.fill_(bias_init)
            self.br_heat[i][-1].conv.reset_parameters()
            self.br_heat[i][-1].conv.bias.data.fill_(bias_init)
            self.tl_off[i][-1].conv.reset_parameters()
            self.br_off[i][-1].conv.reset_parameters()
            if self.with_corner_emb:
                self.tl_emb[i][-1].conv.reset_parameters()
                self.br_emb[i][-1].conv.reset_parameters()

    def forward(self, feats: Tuple[Tensor]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of corner heatmaps, offset heatmaps and
            embedding heatmaps.

                - tl_heats (list[Tensor]): Top-left corner heatmaps for all
                  levels, each is a 4D-tensor, the channels number is
                  num_classes.
                - br_heats (list[Tensor]): Bottom-right corner heatmaps for
                  all levels, each is a 4D-tensor, the channels number is
                  num_classes.
                - tl_embs (list[Tensor] | list[None]): Top-left embedding
                  heatmaps for all levels, each is a 4D-tensor or None.
                  If not None, the channels number is corner_emb_channels.
                - br_embs (list[Tensor] | list[None]): Bottom-right embedding
                  heatmaps for all levels, each is a 4D-tensor or None.
                  If not None, the channels number is corner_emb_channels.
                - tl_offs (list[Tensor]): Top-left offset heatmaps for all
                  levels, each is a 4D-tensor. The channels number is
                  corner_offset_channels.
                - br_offs (list[Tensor]): Bottom-right offset heatmaps for all
                  levels, each is a 4D-tensor. The channels number is
                  corner_offset_channels.
        """
        lvl_ind = list(range(self.num_feat_levels))
        return multi_apply(self.forward_single, feats, lvl_ind)
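
    # Illustrative sketch of the forward output layout (not executed here;
    # the feature shapes are assumptions). With num_feat_levels=2 and two
    # hourglass outputs `feats = (feat0, feat1)`, each of shape (N, 256, H, W):
    #
    #   tl_heats, br_heats, tl_embs, br_embs, tl_offs, br_offs = head(feats)
    #   # tl_heats[i]: (N, num_classes, H, W)           corner heatmap logits
    #   # tl_embs[i]:  (N, corner_emb_channels, H, W)   or None if disabled
    #   # tl_offs[i]:  (N, 2, H, W)                     sub-pixel corner offsets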

    def forward_single(self,
                       x: Tensor,
                       lvl_ind: int,
                       return_pool: bool = False) -> List[Tensor]:
        """Forward feature of a single level.

        Args:
            x (Tensor): Feature of a single level.
            lvl_ind (int): Level index of current feature.
            return_pool (bool): Return corner pool feature or not.
                Defaults to False.

        Returns:
            tuple[Tensor]: A tuple of CornerHead's output for current feature
            level. Containing the following Tensors:

                - tl_heat (Tensor): Predicted top-left corner heatmap.
                - br_heat (Tensor): Predicted bottom-right corner heatmap.
                - tl_emb (Tensor | None): Predicted top-left embedding
                  heatmap. None for `self.with_corner_emb == False`.
                - br_emb (Tensor | None): Predicted bottom-right embedding
                  heatmap. None for `self.with_corner_emb == False`.
                - tl_off (Tensor): Predicted top-left offset heatmap.
                - br_off (Tensor): Predicted bottom-right offset heatmap.
                - tl_pool (Tensor): Top-left corner pool feature. Only
                  returned when `return_pool` is True.
                - br_pool (Tensor): Bottom-right corner pool feature. Only
                  returned when `return_pool` is True.
        """
        tl_pool = self.tl_pool[lvl_ind](x)
        tl_heat = self.tl_heat[lvl_ind](tl_pool)
        br_pool = self.br_pool[lvl_ind](x)
        br_heat = self.br_heat[lvl_ind](br_pool)

        tl_emb, br_emb = None, None
        if self.with_corner_emb:
            tl_emb = self.tl_emb[lvl_ind](tl_pool)
            br_emb = self.br_emb[lvl_ind](br_pool)

        tl_off = self.tl_off[lvl_ind](tl_pool)
        br_off = self.br_off[lvl_ind](br_pool)

        result_list = [tl_heat, br_heat, tl_emb, br_emb, tl_off, br_off]
        if return_pool:
            result_list.append(tl_pool)
            result_list.append(br_pool)

        return result_list

    def get_targets(self,
                    gt_bboxes: List[Tensor],
                    gt_labels: List[Tensor],
                    feat_shape: Sequence[int],
                    img_shape: Sequence[int],
                    with_corner_emb: bool = False,
                    with_guiding_shift: bool = False,
                    with_centripetal_shift: bool = False) -> dict:
        """Generate corner targets.

        Including corner heatmap, corner offset.

        Optional: corner embedding, corner guiding shift, centripetal shift.

        For CornerNet, we generate corner heatmap, corner offset and corner
        embedding from this function.

        For CentripetalNet, we generate corner heatmap, corner offset, guiding
        shift and centripetal shift from this function.

        Args:
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image, each
                has shape (num_gt, 4).
            gt_labels (list[Tensor]): Ground truth labels of each box, each
                has shape (num_gt, ).
            feat_shape (Sequence[int]): Shape of output feature,
                [batch, channel, height, width].
            img_shape (Sequence[int]): Shape of input image,
                [height, width, channel].
            with_corner_emb (bool): Generate corner embedding target or not.
                Defaults to False.
            with_guiding_shift (bool): Generate guiding shift target or not.
                Defaults to False.
            with_centripetal_shift (bool): Generate centripetal shift target
                or not. Defaults to False.

        Returns:
            dict: Ground truth of corner heatmap, corner offset, corner
            embedding, guiding shift and centripetal shift. Containing the
            following keys:

                - topleft_heatmap (Tensor): Ground truth top-left corner
                  heatmap.
                - bottomright_heatmap (Tensor): Ground truth bottom-right
                  corner heatmap.
                - topleft_offset (Tensor): Ground truth top-left corner
                  offset.
                - bottomright_offset (Tensor): Ground truth bottom-right
                  corner offset.
                - corner_embedding (list[list[list[int]]]): Ground truth
                  corner embedding. Only returned when `with_corner_emb`
                  is True.
                - topleft_guiding_shift (Tensor): Ground truth top-left corner
                  guiding shift. Only returned when `with_guiding_shift`
                  is True.
                - bottomright_guiding_shift (Tensor): Ground truth
                  bottom-right corner guiding shift. Only returned when
                  `with_guiding_shift` is True.
                - topleft_centripetal_shift (Tensor): Ground truth top-left
                  corner centripetal shift. Only returned when
                  `with_centripetal_shift` is True.
                - bottomright_centripetal_shift (Tensor): Ground truth
                  bottom-right corner centripetal shift. Only returned when
                  `with_centripetal_shift` is True.
        """
        batch_size, _, height, width = feat_shape
        img_h, img_w = img_shape[:2]

        width_ratio = float(width / img_w)
        height_ratio = float(height / img_h)

        gt_tl_heatmap = gt_bboxes[-1].new_zeros(
            [batch_size, self.num_classes, height, width])
        gt_br_heatmap = gt_bboxes[-1].new_zeros(
            [batch_size, self.num_classes, height, width])
        gt_tl_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
        gt_br_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])

        if with_corner_emb:
            match = []

        # Guiding shift is a kind of offset, from center to corner
        if with_guiding_shift:
            gt_tl_guiding_shift = gt_bboxes[-1].new_zeros(
                [batch_size, 2, height, width])
            gt_br_guiding_shift = gt_bboxes[-1].new_zeros(
                [batch_size, 2, height, width])
        # Centripetal shift is also a kind of offset, from center to corner
        # and normalized by log.
        if with_centripetal_shift:
            gt_tl_centripetal_shift = gt_bboxes[-1].new_zeros(
                [batch_size, 2, height, width])
            gt_br_centripetal_shift = gt_bboxes[-1].new_zeros(
                [batch_size, 2, height, width])

        for batch_id in range(batch_size):
            # Ground truth of corner embedding per image is a list of coord set
            corner_match = []
            for box_id in range(len(gt_labels[batch_id])):
                left, top, right, bottom = gt_bboxes[batch_id][box_id]
                center_x = (left + right) / 2.0
                center_y = (top + bottom) / 2.0
                label = gt_labels[batch_id][box_id]

                # Use coords in the feature level to generate ground truth
                scale_left = left * width_ratio
                scale_right = right * width_ratio
                scale_top = top * height_ratio
                scale_bottom = bottom * height_ratio
                scale_center_x = center_x * width_ratio
                scale_center_y = center_y * height_ratio

                # Int coords on feature map/ground truth tensor
                left_idx = int(min(scale_left, width - 1))
                right_idx = int(min(scale_right, width - 1))
                top_idx = int(min(scale_top, height - 1))
                bottom_idx = int(min(scale_bottom, height - 1))

                # Generate gaussian heatmap
                scale_box_width = ceil(scale_right - scale_left)
                scale_box_height = ceil(scale_bottom - scale_top)
                radius = gaussian_radius((scale_box_height, scale_box_width),
                                         min_overlap=0.3)
                radius = max(0, int(radius))
                gt_tl_heatmap[batch_id, label] = gen_gaussian_target(
                    gt_tl_heatmap[batch_id, label], [left_idx, top_idx],
                    radius)
                gt_br_heatmap[batch_id, label] = gen_gaussian_target(
                    gt_br_heatmap[batch_id, label], [right_idx, bottom_idx],
                    radius)

                # Generate corner offset
                left_offset = scale_left - left_idx
                top_offset = scale_top - top_idx
                right_offset = scale_right - right_idx
                bottom_offset = scale_bottom - bottom_idx
                gt_tl_offset[batch_id, 0, top_idx, left_idx] = left_offset
                gt_tl_offset[batch_id, 1, top_idx, left_idx] = top_offset
                gt_br_offset[batch_id, 0, bottom_idx, right_idx] = right_offset
                gt_br_offset[batch_id, 1, bottom_idx,
                             right_idx] = bottom_offset

                # Generate corner embedding
                if with_corner_emb:
                    corner_match.append([[top_idx, left_idx],
                                         [bottom_idx, right_idx]])
                # Generate guiding shift
                if with_guiding_shift:
                    gt_tl_guiding_shift[batch_id, 0, top_idx,
                                        left_idx] = scale_center_x - left_idx
                    gt_tl_guiding_shift[batch_id, 1, top_idx,
                                        left_idx] = scale_center_y - top_idx
                    gt_br_guiding_shift[batch_id, 0, bottom_idx,
                                        right_idx] = right_idx - scale_center_x
                    gt_br_guiding_shift[
                        batch_id, 1, bottom_idx,
                        right_idx] = bottom_idx - scale_center_y
                # Generate centripetal shift
                if with_centripetal_shift:
                    gt_tl_centripetal_shift[batch_id, 0, top_idx,
                                            left_idx] = log(scale_center_x -
                                                            scale_left)
                    gt_tl_centripetal_shift[batch_id, 1, top_idx,
                                            left_idx] = log(scale_center_y -
                                                            scale_top)
                    gt_br_centripetal_shift[batch_id, 0, bottom_idx,
                                            right_idx] = log(scale_right -
                                                             scale_center_x)
                    gt_br_centripetal_shift[batch_id, 1, bottom_idx,
                                            right_idx] = log(scale_bottom -
                                                             scale_center_y)

            if with_corner_emb:
                match.append(corner_match)

        target_result = dict(
            topleft_heatmap=gt_tl_heatmap,
            topleft_offset=gt_tl_offset,
            bottomright_heatmap=gt_br_heatmap,
            bottomright_offset=gt_br_offset)

        if with_corner_emb:
            target_result.update(corner_embedding=match)
        if with_guiding_shift:
            target_result.update(
                topleft_guiding_shift=gt_tl_guiding_shift,
                bottomright_guiding_shift=gt_br_guiding_shift)
        if with_centripetal_shift:
            target_result.update(
                topleft_centripetal_shift=gt_tl_centripetal_shift,
                bottomright_centripetal_shift=gt_br_centripetal_shift)

        return target_result
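
    # Worked example of the coordinate mapping used above (the numbers are
    # assumptions chosen only for illustration). With an input image of
    # 511 x 511 and an output feature map of 128 x 128, width_ratio is about
    # 0.25, so a ground-truth top-left corner at x = 100.0 maps to
    # scale_left ~= 25.05, is quantized to left_idx = 25, and the remaining
    # sub-pixel part ~0.05 is stored as the offset target at
    # gt_tl_offset[batch_id, 0, top_idx, 25].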

    def loss_by_feat(
            self,
            tl_heats: List[Tensor],
            br_heats: List[Tensor],
            tl_embs: List[Tensor],
            br_embs: List[Tensor],
            tl_offs: List[Tensor],
            br_offs: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            tl_heats (list[Tensor]): Top-left corner heatmaps for each level
                with shape (N, num_classes, H, W).
            br_heats (list[Tensor]): Bottom-right corner heatmaps for each
                level with shape (N, num_classes, H, W).
            tl_embs (list[Tensor]): Top-left corner embeddings for each level
                with shape (N, corner_emb_channels, H, W).
            br_embs (list[Tensor]): Bottom-right corner embeddings for each
                level with shape (N, corner_emb_channels, H, W).
            tl_offs (list[Tensor]): Top-left corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            br_offs (list[Tensor]): Bottom-right corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Specify which bounding boxes can be ignored when computing
                the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. Containing the
            following losses:

                - det_loss (list[Tensor]): Corner keypoint losses of all
                  feature levels.
                - pull_loss (list[Tensor]): Part one of AssociativeEmbedding
                  losses of all feature levels.
                - push_loss (list[Tensor]): Part two of AssociativeEmbedding
                  losses of all feature levels.
                - off_loss (list[Tensor]): Corner offset losses of all feature
                  levels.
        """
        gt_bboxes = [
            gt_instances.bboxes for gt_instances in batch_gt_instances
        ]
        gt_labels = [
            gt_instances.labels for gt_instances in batch_gt_instances
        ]

        targets = self.get_targets(
            gt_bboxes,
            gt_labels,
            tl_heats[-1].shape,
            batch_img_metas[0]['batch_input_shape'],
            with_corner_emb=self.with_corner_emb)
        mlvl_targets = [targets for _ in range(self.num_feat_levels)]
        det_losses, pull_losses, push_losses, off_losses = multi_apply(
            self.loss_by_feat_single, tl_heats, br_heats, tl_embs, br_embs,
            tl_offs, br_offs, mlvl_targets)
        loss_dict = dict(det_loss=det_losses, off_loss=off_losses)
        if self.with_corner_emb:
            loss_dict.update(pull_loss=pull_losses, push_loss=push_losses)
        return loss_dict

    def loss_by_feat_single(self, tl_hmp: Tensor, br_hmp: Tensor,
                            tl_emb: Optional[Tensor],
                            br_emb: Optional[Tensor], tl_off: Tensor,
                            br_off: Tensor,
                            targets: dict) -> Tuple[Tensor, ...]:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            tl_hmp (Tensor): Top-left corner heatmap for current level with
                shape (N, num_classes, H, W).
            br_hmp (Tensor): Bottom-right corner heatmap for current level
                with shape (N, num_classes, H, W).
            tl_emb (Tensor, optional): Top-left corner embedding for current
                level with shape (N, corner_emb_channels, H, W).
            br_emb (Tensor, optional): Bottom-right corner embedding for
                current level with shape (N, corner_emb_channels, H, W).
            tl_off (Tensor): Top-left corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            br_off (Tensor): Bottom-right corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            targets (dict): Corner target generated by `get_targets`.

        Returns:
            tuple[torch.Tensor]: Losses of the head's different branches
            containing the following losses:

                - det_loss (Tensor): Corner keypoint loss.
                - pull_loss (Tensor): Part one of AssociativeEmbedding loss.
                - push_loss (Tensor): Part two of AssociativeEmbedding loss.
                - off_loss (Tensor): Corner offset loss.
        """
        gt_tl_hmp = targets['topleft_heatmap']
        gt_br_hmp = targets['bottomright_heatmap']
        gt_tl_off = targets['topleft_offset']
        gt_br_off = targets['bottomright_offset']
        gt_embedding = targets['corner_embedding']

        # Detection loss
        tl_det_loss = self.loss_heatmap(
            tl_hmp.sigmoid(),
            gt_tl_hmp,
            avg_factor=max(1,
                           gt_tl_hmp.eq(1).sum()))
        br_det_loss = self.loss_heatmap(
            br_hmp.sigmoid(),
            gt_br_hmp,
            avg_factor=max(1,
                           gt_br_hmp.eq(1).sum()))
        det_loss = (tl_det_loss + br_det_loss) / 2.0

        # AssociativeEmbedding loss
        if self.with_corner_emb and self.loss_embedding is not None:
            pull_loss, push_loss = self.loss_embedding(tl_emb, br_emb,
                                                       gt_embedding)
        else:
            pull_loss, push_loss = None, None

        # Offset loss
        # We only compute the offset loss at the real corner position.
        # The value of real corner would be 1 in heatmap ground truth.
        # The mask is computed in class agnostic mode and its shape is
        # batch * 1 * height * width.
        tl_off_mask = gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
            gt_tl_hmp)
        br_off_mask = gt_br_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
            gt_br_hmp)
        tl_off_loss = self.loss_offset(
            tl_off,
            gt_tl_off,
            tl_off_mask,
            avg_factor=max(1, tl_off_mask.sum()))
        br_off_loss = self.loss_offset(
            br_off,
            gt_br_off,
            br_off_mask,
            avg_factor=max(1, br_off_mask.sum()))
        off_loss = (tl_off_loss + br_off_loss) / 2.0

        return det_loss, pull_loss, push_loss, off_loss
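
    # Illustration of the class-agnostic offset mask above (shapes assumed
    # for the example): for gt_tl_hmp of shape (N, num_classes, H, W), a
    # location counts as a real corner if any class channel equals 1 there, so
    #
    #   gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1)  # -> (N, 1, H, W) mask
    #
    # and the offset loss is averaged over the number of such positions.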

    def predict_by_feat(self,
                        tl_heats: List[Tensor],
                        br_heats: List[Tensor],
                        tl_embs: List[Tensor],
                        br_embs: List[Tensor],
                        tl_offs: List[Tensor],
                        br_offs: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            tl_heats (list[Tensor]): Top-left corner heatmaps for each level
                with shape (N, num_classes, H, W).
            br_heats (list[Tensor]): Bottom-right corner heatmaps for each
                level with shape (N, num_classes, H, W).
            tl_embs (list[Tensor]): Top-left corner embeddings for each level
                with shape (N, corner_emb_channels, H, W).
            br_embs (list[Tensor]): Bottom-right corner embeddings for each
                level with shape (N, corner_emb_channels, H, W).
            tl_offs (list[Tensor]): Top-left corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            br_offs (list[Tensor]): Bottom-right corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            batch_img_metas (list[dict], optional): Batch image meta info.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(
            batch_img_metas)
        result_list = []
        for img_id in range(len(batch_img_metas)):
            result_list.append(
                self._predict_by_feat_single(
                    tl_heats[-1][img_id:img_id + 1, :],
                    br_heats[-1][img_id:img_id + 1, :],
                    tl_offs[-1][img_id:img_id + 1, :],
                    br_offs[-1][img_id:img_id + 1, :],
                    batch_img_metas[img_id],
                    tl_emb=tl_embs[-1][img_id:img_id + 1, :],
                    br_emb=br_embs[-1][img_id:img_id + 1, :],
                    rescale=rescale,
                    with_nms=with_nms))

        return result_list

    def _predict_by_feat_single(self,
                                tl_heat: Tensor,
                                br_heat: Tensor,
                                tl_off: Tensor,
                                br_off: Tensor,
                                img_meta: dict,
                                tl_emb: Optional[Tensor] = None,
                                br_emb: Optional[Tensor] = None,
                                tl_centripetal_shift: Optional[Tensor] = None,
                                br_centripetal_shift: Optional[Tensor] = None,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            tl_heat (Tensor): Top-left corner heatmap for current level with
                shape (N, num_classes, H, W).
            br_heat (Tensor): Bottom-right corner heatmap for current level
                with shape (N, num_classes, H, W).
            tl_off (Tensor): Top-left corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            br_off (Tensor): Bottom-right corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            img_meta (dict): Meta information of current image, e.g.,
                image size, scaling factor, etc.
            tl_emb (Tensor, optional): Top-left corner embedding for current
                level with shape (N, corner_emb_channels, H, W).
            br_emb (Tensor, optional): Bottom-right corner embedding for
                current level with shape (N, corner_emb_channels, H, W).
            tl_centripetal_shift (Tensor, optional): Top-left corner's
                centripetal shift for current level with shape (N, 2, H, W).
            br_centripetal_shift (Tensor, optional): Bottom-right corner's
                centripetal shift for current level with shape (N, 2, H, W).
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        if isinstance(img_meta, (list, tuple)):
            img_meta = img_meta[0]

        batch_bboxes, batch_scores, batch_clses = self._decode_heatmap(
            tl_heat=tl_heat.sigmoid(),
            br_heat=br_heat.sigmoid(),
            tl_off=tl_off,
            br_off=br_off,
            tl_emb=tl_emb,
            br_emb=br_emb,
            tl_centripetal_shift=tl_centripetal_shift,
            br_centripetal_shift=br_centripetal_shift,
            img_meta=img_meta,
            k=self.test_cfg.corner_topk,
            kernel=self.test_cfg.local_maximum_kernel,
            distance_threshold=self.test_cfg.distance_threshold)

        if rescale and 'scale_factor' in img_meta:
            batch_bboxes /= batch_bboxes.new_tensor(
                img_meta['scale_factor']).repeat((1, 2))

        bboxes = batch_bboxes.view([-1, 4])
        scores = batch_scores.view(-1)
        clses = batch_clses.view(-1)

        det_bboxes = torch.cat([bboxes, scores.unsqueeze(-1)], -1)
        keepinds = (det_bboxes[:, -1] > -0.1)
        det_bboxes = det_bboxes[keepinds]
        det_labels = clses[keepinds]

        if with_nms:
            det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
                                                      self.test_cfg)

        results = InstanceData()
        results.bboxes = det_bboxes[..., :4]
        results.scores = det_bboxes[..., 4]
        results.labels = det_labels
        return results

    def _bboxes_nms(self, bboxes: Tensor, labels: Tensor,
                    cfg: ConfigDict) -> Tuple[Tensor, Tensor]:
        """bboxes nms."""
        if 'nms_cfg' in cfg:
            # `warning` is `logging.warning`, so call it directly.
            warning('nms_cfg in test_cfg will be deprecated. '
                    'Please rename it as nms')
            if 'nms' not in cfg:
                cfg.nms = cfg.nms_cfg

        if labels.numel() > 0:
            max_num = cfg.max_per_img
            bboxes, keep = batched_nms(bboxes[:, :4],
                                       bboxes[:, -1].contiguous(), labels,
                                       cfg.nms)
            if max_num > 0:
                bboxes = bboxes[:max_num]
                labels = labels[keep][:max_num]

        return bboxes, labels

    def _decode_heatmap(self,
                        tl_heat: Tensor,
                        br_heat: Tensor,
                        tl_off: Tensor,
                        br_off: Tensor,
                        tl_emb: Optional[Tensor] = None,
                        br_emb: Optional[Tensor] = None,
                        tl_centripetal_shift: Optional[Tensor] = None,
                        br_centripetal_shift: Optional[Tensor] = None,
                        img_meta: Optional[dict] = None,
                        k: int = 100,
                        kernel: int = 3,
                        distance_threshold: float = 0.5,
                        num_dets: int = 1000
                        ) -> Tuple[Tensor, Tensor, Tensor]:
        """Transform outputs into detections raw bbox prediction.

        Args:
            tl_heat (Tensor): Top-left corner heatmap for current level with
                shape (N, num_classes, H, W).
            br_heat (Tensor): Bottom-right corner heatmap for current level
                with shape (N, num_classes, H, W).
            tl_off (Tensor): Top-left corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            br_off (Tensor): Bottom-right corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            tl_emb (Tensor, optional): Top-left corner embedding for current
                level with shape (N, corner_emb_channels, H, W).
            br_emb (Tensor, optional): Bottom-right corner embedding for
                current level with shape (N, corner_emb_channels, H, W).
            tl_centripetal_shift (Tensor, optional): Top-left centripetal
                shift for current level with shape (N, 2, H, W).
            br_centripetal_shift (Tensor, optional): Bottom-right centripetal
                shift for current level with shape (N, 2, H, W).
            img_meta (dict): Meta information of current image, e.g.,
                image size, scaling factor, etc.
            k (int): Get top k corner keypoints from heatmap.
            kernel (int): Max pooling kernel for extracting local maximum
                pixels.
            distance_threshold (float): Distance threshold. Top-left and
                bottom-right corner keypoints with feature distance less than
                the threshold will be regarded as keypoints from same object.
            num_dets (int): Num of raw boxes before doing nms.

        Returns:
            tuple[torch.Tensor]: Decoded output of CornerHead, containing the
            following Tensors:

                - bboxes (Tensor): Coords of each box.
                - scores (Tensor): Scores of each box.
                - clses (Tensor): Categories of each box.
        """
        with_embedding = tl_emb is not None and br_emb is not None
        with_centripetal_shift = (
            tl_centripetal_shift is not None
            and br_centripetal_shift is not None)
        assert with_embedding + with_centripetal_shift == 1
        batch, _, height, width = tl_heat.size()
        if torch.onnx.is_in_onnx_export():
            inp_h, inp_w = img_meta['pad_shape_for_onnx'][:2]
        else:
            inp_h, inp_w = img_meta['batch_input_shape'][:2]

        # perform nms on heatmaps
        tl_heat = get_local_maximum(tl_heat, kernel=kernel)
        br_heat = get_local_maximum(br_heat, kernel=kernel)

        tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = get_topk_from_heatmap(
            tl_heat, k=k)
        br_scores, br_inds, br_clses, br_ys, br_xs = get_topk_from_heatmap(
            br_heat, k=k)

        # We use repeat instead of expand here because expand is a
        # shallow-copy function. Thus it could cause unexpected testing
        # results sometimes. Using expand will decrease about 10% mAP during
        # testing compared to repeat.
        tl_ys = tl_ys.view(batch, k, 1).repeat(1, 1, k)
        tl_xs = tl_xs.view(batch, k, 1).repeat(1, 1, k)
        br_ys = br_ys.view(batch, 1, k).repeat(1, k, 1)
        br_xs = br_xs.view(batch, 1, k).repeat(1, k, 1)

        tl_off = transpose_and_gather_feat(tl_off, tl_inds)
        tl_off = tl_off.view(batch, k, 1, 2)
        br_off = transpose_and_gather_feat(br_off, br_inds)
        br_off = br_off.view(batch, 1, k, 2)

        tl_xs = tl_xs + tl_off[..., 0]
        tl_ys = tl_ys + tl_off[..., 1]
        br_xs = br_xs + br_off[..., 0]
        br_ys = br_ys + br_off[..., 1]

        if with_centripetal_shift:
            tl_centripetal_shift = transpose_and_gather_feat(
                tl_centripetal_shift, tl_inds).view(batch, k, 1, 2).exp()
            br_centripetal_shift = transpose_and_gather_feat(
                br_centripetal_shift, br_inds).view(batch, 1, k, 2).exp()

            tl_ctxs = tl_xs + tl_centripetal_shift[..., 0]
            tl_ctys = tl_ys + tl_centripetal_shift[..., 1]
            br_ctxs = br_xs - br_centripetal_shift[..., 0]
            br_ctys = br_ys - br_centripetal_shift[..., 1]

        # all possible boxes based on top k corners (ignoring class)
        tl_xs *= (inp_w / width)
        tl_ys *= (inp_h / height)
        br_xs *= (inp_w / width)
        br_ys *= (inp_h / height)

        if with_centripetal_shift:
            tl_ctxs *= (inp_w / width)
            tl_ctys *= (inp_h / height)
            br_ctxs *= (inp_w / width)
            br_ctys *= (inp_h / height)

        x_off, y_off = 0, 0  # no crop
        if not torch.onnx.is_in_onnx_export():
            # Since `RandomCenterCropPad` is done on CPU with numpy and it is
            # not dynamically traceable when exporting to ONNX, 'border' does
            # not appear as a key in 'img_meta'. As a temporary solution, the
            # 'border' handling is moved to the post-processing that runs
            # after the ONNX export, which is handled in
            # `mmdet/core/export/model_wrappers.py`. Although the PyTorch and
            # exported ONNX models therefore differ slightly, the difference
            # can be ignored since they achieve comparable performance
            # (e.g. 40.4 vs 40.6 on COCO val2017, for CornerNet without
            # test-time flip).
            if 'border' in img_meta:
                x_off = img_meta['border'][2]
                y_off = img_meta['border'][0]

        tl_xs -= x_off
        tl_ys -= y_off
        br_xs -= x_off
        br_ys -= y_off

        zeros = tl_xs.new_zeros(*tl_xs.size())
        tl_xs = torch.where(tl_xs > 0.0, tl_xs, zeros)
        tl_ys = torch.where(tl_ys > 0.0, tl_ys, zeros)
        br_xs = torch.where(br_xs > 0.0, br_xs, zeros)
        br_ys = torch.where(br_ys > 0.0, br_ys, zeros)

        bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
        area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs()

        if with_centripetal_shift:
            tl_ctxs -= x_off
            tl_ctys -= y_off
            br_ctxs -= x_off
            br_ctys -= y_off

            tl_ctxs *= tl_ctxs.gt(0.0).type_as(tl_ctxs)
            tl_ctys *= tl_ctys.gt(0.0).type_as(tl_ctys)
            br_ctxs *= br_ctxs.gt(0.0).type_as(br_ctxs)
            br_ctys *= br_ctys.gt(0.0).type_as(br_ctys)

            ct_bboxes = torch.stack((tl_ctxs, tl_ctys, br_ctxs, br_ctys),
                                    dim=3)
            area_ct_bboxes = ((br_ctxs - tl_ctxs) * (br_ctys - tl_ctys)).abs()

            rcentral = torch.zeros_like(ct_bboxes)
            # magic nums from paper section 4.1
            mu = torch.ones_like(area_bboxes) / 2.4
            mu[area_bboxes > 3500] = 1 / 2.1  # large bboxes have smaller mu

            bboxes_center_x = (bboxes[..., 0] + bboxes[..., 2]) / 2
            bboxes_center_y = (bboxes[..., 1] + bboxes[..., 3]) / 2
            rcentral[..., 0] = bboxes_center_x - mu * (bboxes[..., 2] -
                                                       bboxes[..., 0]) / 2
            rcentral[..., 1] = bboxes_center_y - mu * (bboxes[..., 3] -
                                                       bboxes[..., 1]) / 2
            rcentral[..., 2] = bboxes_center_x + mu * (bboxes[..., 2] -
                                                       bboxes[..., 0]) / 2
            rcentral[..., 3] = bboxes_center_y + mu * (bboxes[..., 3] -
                                                       bboxes[..., 1]) / 2
            area_rcentral = ((rcentral[..., 2] - rcentral[..., 0]) *
                             (rcentral[..., 3] - rcentral[..., 1])).abs()
            dists = area_ct_bboxes / area_rcentral

            tl_ctx_inds = (ct_bboxes[..., 0] <= rcentral[..., 0]) | (
                ct_bboxes[..., 0] >= rcentral[..., 2])
            tl_cty_inds = (ct_bboxes[..., 1] <= rcentral[..., 1]) | (
                ct_bboxes[..., 1] >= rcentral[..., 3])
            br_ctx_inds = (ct_bboxes[..., 2] <= rcentral[..., 0]) | (
                ct_bboxes[..., 2] >= rcentral[..., 2])
            br_cty_inds = (ct_bboxes[..., 3] <= rcentral[..., 1]) | (
                ct_bboxes[..., 3] >= rcentral[..., 3])

        if with_embedding:
            tl_emb = transpose_and_gather_feat(tl_emb, tl_inds)
            tl_emb = tl_emb.view(batch, k, 1)
            br_emb = transpose_and_gather_feat(br_emb, br_inds)
            br_emb = br_emb.view(batch, 1, k)
            dists = torch.abs(tl_emb - br_emb)

        tl_scores = tl_scores.view(batch, k, 1).repeat(1, 1, k)
        br_scores = br_scores.view(batch, 1, k).repeat(1, k, 1)

        scores = (tl_scores + br_scores) / 2  # scores for all possible boxes

        # tl and br should have same class
        tl_clses = tl_clses.view(batch, k, 1).repeat(1, 1, k)
        br_clses = br_clses.view(batch, 1, k).repeat(1, k, 1)
        cls_inds = (tl_clses != br_clses)

        # reject boxes based on distances
        dist_inds = dists > distance_threshold

        # reject boxes based on widths and heights
        width_inds = (br_xs <= tl_xs)
        height_inds = (br_ys <= tl_ys)

        # We do not use `scores[cls_inds]` here; `torch.where` is used
        # instead, because only 1-D indices of type 'tensor(bool)' are
        # supported when exporting to ONNX, and any other bool index with
        # more dimensions (e.g. a 2-D bool tensor) as an input parameter of a
        # node is invalid.
        negative_scores = -1 * torch.ones_like(scores)
        scores = torch.where(cls_inds, negative_scores, scores)
        scores = torch.where(width_inds, negative_scores, scores)
        scores = torch.where(height_inds, negative_scores, scores)
        scores = torch.where(dist_inds, negative_scores, scores)

        if with_centripetal_shift:
            scores[tl_ctx_inds] = -1
            scores[tl_cty_inds] = -1
            scores[br_ctx_inds] = -1
            scores[br_cty_inds] = -1

        scores = scores.view(batch, -1)
        scores, inds = torch.topk(scores, num_dets)
        scores = scores.unsqueeze(2)

        bboxes = bboxes.view(batch, -1, 4)
        bboxes = gather_feat(bboxes, inds)

        clses = tl_clses.contiguous().view(batch, -1, 1)
        clses = gather_feat(clses, inds)
        return bboxes, scores, clses
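

# Illustrative end-to-end decode sketch (commented out, not part of the
# original module). The head arguments, test_cfg values and tensor shapes are
# assumptions for the example; only the keys read by this file
# (corner_topk, local_maximum_kernel, distance_threshold, nms, max_per_img)
# and the documented decode defaults (k=100, kernel=3, threshold=0.5) are
# taken from the code above.
#
#   head = CornerHead(num_classes=80, in_channels=256)
#   head.test_cfg = ConfigDict(
#       corner_topk=100,
#       local_maximum_kernel=3,
#       distance_threshold=0.5,
#       nms=dict(type='soft_nms', iou_threshold=0.5, method='gaussian'),
#       max_per_img=100)
#   feats = tuple(torch.randn(1, 256, 128, 128) for _ in range(2))
#   outs = head(feats)
#   img_metas = [dict(batch_input_shape=(511, 511), border=(0, 0, 0, 0))]
#   results = head.predict_by_feat(*outs, batch_img_metas=img_metas)
#   # results[0].bboxes: (num_instances, 4); also .scores and .labels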