# centripetal_head.py (21 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmcv.ops import DeformConv2d
  6. from mmengine.model import normal_init
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
  10. OptMultiConfig)
  11. from ..utils import multi_apply
  12. from .corner_head import CornerHead
  13. @MODELS.register_module()
  14. class CentripetalHead(CornerHead):
  15. """Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
  16. Detection.
  17. CentripetalHead inherits from :class:`CornerHead`. It removes the
  18. embedding branch and adds guiding shift and centripetal shift branches.
  19. More details can be found in the `paper
  20. <https://arxiv.org/abs/2003.09119>`_ .
  21. Args:
  22. num_classes (int): Number of categories excluding the background
  23. category.
  24. in_channels (int): Number of channels in the input feature map.
  25. num_feat_levels (int): Levels of feature from the previous module.
  26. 2 for HourglassNet-104 and 1 for HourglassNet-52. HourglassNet-104
  27. outputs the final feature and intermediate supervision feature and
  28. HourglassNet-52 only outputs the final feature. Defaults to 2.
  29. corner_emb_channels (int): Channel of embedding vector. Defaults to 1.
  30. train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
  31. Useless in CornerHead, but we keep this variable for
  32. SingleStageDetector.
  33. test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
  34. CornerHead.
  35. loss_heatmap (:obj:`ConfigDict` or dict): Config of corner heatmap
  36. loss. Defaults to GaussianFocalLoss.
  37. loss_embedding (:obj:`ConfigDict` or dict): Config of corner embedding
  38. loss. Defaults to AssociativeEmbeddingLoss.
  39. loss_offset (:obj:`ConfigDict` or dict): Config of corner offset loss.
  40. Defaults to SmoothL1Loss.
  41. loss_guiding_shift (:obj:`ConfigDict` or dict): Config of
  42. guiding shift loss. Defaults to SmoothL1Loss.
  43. loss_centripetal_shift (:obj:`ConfigDict` or dict): Config of
  44. centripetal shift loss. Defaults to SmoothL1Loss.
  45. init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
  46. the initialization.
  47. """
  48. def __init__(self,
  49. *args,
  50. centripetal_shift_channels: int = 2,
  51. guiding_shift_channels: int = 2,
  52. feat_adaption_conv_kernel: int = 3,
  53. loss_guiding_shift: ConfigType = dict(
  54. type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
  55. loss_centripetal_shift: ConfigType = dict(
  56. type='SmoothL1Loss', beta=1.0, loss_weight=1),
  57. init_cfg: OptMultiConfig = None,
  58. **kwargs) -> None:
  59. assert init_cfg is None, 'To prevent abnormal initialization ' \
  60. 'behavior, init_cfg is not allowed to be set'
  61. assert centripetal_shift_channels == 2, (
  62. 'CentripetalHead only support centripetal_shift_channels == 2')
  63. self.centripetal_shift_channels = centripetal_shift_channels
  64. assert guiding_shift_channels == 2, (
  65. 'CentripetalHead only support guiding_shift_channels == 2')
  66. self.guiding_shift_channels = guiding_shift_channels
  67. self.feat_adaption_conv_kernel = feat_adaption_conv_kernel
  68. super().__init__(*args, init_cfg=init_cfg, **kwargs)
  69. self.loss_guiding_shift = MODELS.build(loss_guiding_shift)
  70. self.loss_centripetal_shift = MODELS.build(loss_centripetal_shift)
  71. def _init_centripetal_layers(self) -> None:
  72. """Initialize centripetal layers.
  73. Including feature adaption deform convs (feat_adaption), deform offset
  74. prediction convs (dcn_off), guiding shift (guiding_shift) and
  75. centripetal shift ( centripetal_shift). Each branch has two parts:
  76. prefix `tl_` for top-left and `br_` for bottom-right.
  77. """
  78. self.tl_feat_adaption = nn.ModuleList()
  79. self.br_feat_adaption = nn.ModuleList()
  80. self.tl_dcn_offset = nn.ModuleList()
  81. self.br_dcn_offset = nn.ModuleList()
  82. self.tl_guiding_shift = nn.ModuleList()
  83. self.br_guiding_shift = nn.ModuleList()
  84. self.tl_centripetal_shift = nn.ModuleList()
  85. self.br_centripetal_shift = nn.ModuleList()
  86. for _ in range(self.num_feat_levels):
  87. self.tl_feat_adaption.append(
  88. DeformConv2d(self.in_channels, self.in_channels,
  89. self.feat_adaption_conv_kernel, 1, 1))
  90. self.br_feat_adaption.append(
  91. DeformConv2d(self.in_channels, self.in_channels,
  92. self.feat_adaption_conv_kernel, 1, 1))
  93. self.tl_guiding_shift.append(
  94. self._make_layers(
  95. out_channels=self.guiding_shift_channels,
  96. in_channels=self.in_channels))
  97. self.br_guiding_shift.append(
  98. self._make_layers(
  99. out_channels=self.guiding_shift_channels,
  100. in_channels=self.in_channels))
  101. self.tl_dcn_offset.append(
  102. ConvModule(
  103. self.guiding_shift_channels,
  104. self.feat_adaption_conv_kernel**2 *
  105. self.guiding_shift_channels,
  106. 1,
  107. bias=False,
  108. act_cfg=None))
  109. self.br_dcn_offset.append(
  110. ConvModule(
  111. self.guiding_shift_channels,
  112. self.feat_adaption_conv_kernel**2 *
  113. self.guiding_shift_channels,
  114. 1,
  115. bias=False,
  116. act_cfg=None))
  117. self.tl_centripetal_shift.append(
  118. self._make_layers(
  119. out_channels=self.centripetal_shift_channels,
  120. in_channels=self.in_channels))
  121. self.br_centripetal_shift.append(
  122. self._make_layers(
  123. out_channels=self.centripetal_shift_channels,
  124. in_channels=self.in_channels))
  125. def _init_layers(self) -> None:
  126. """Initialize layers for CentripetalHead.
  127. Including two parts: CornerHead layers and CentripetalHead layers
  128. """
  129. super()._init_layers() # using _init_layers in CornerHead
  130. self._init_centripetal_layers()
  131. def init_weights(self) -> None:
  132. super().init_weights()
  133. for i in range(self.num_feat_levels):
  134. normal_init(self.tl_feat_adaption[i], std=0.01)
  135. normal_init(self.br_feat_adaption[i], std=0.01)
  136. normal_init(self.tl_dcn_offset[i].conv, std=0.1)
  137. normal_init(self.br_dcn_offset[i].conv, std=0.1)
  138. _ = [x.conv.reset_parameters() for x in self.tl_guiding_shift[i]]
  139. _ = [x.conv.reset_parameters() for x in self.br_guiding_shift[i]]
  140. _ = [
  141. x.conv.reset_parameters() for x in self.tl_centripetal_shift[i]
  142. ]
  143. _ = [
  144. x.conv.reset_parameters() for x in self.br_centripetal_shift[i]
  145. ]
  146. def forward_single(self, x: Tensor, lvl_ind: int) -> List[Tensor]:
  147. """Forward feature of a single level.
  148. Args:
  149. x (Tensor): Feature of a single level.
  150. lvl_ind (int): Level index of current feature.
  151. Returns:
  152. tuple[Tensor]: A tuple of CentripetalHead's output for current
  153. feature level. Containing the following Tensors:
  154. - tl_heat (Tensor): Predicted top-left corner heatmap.
  155. - br_heat (Tensor): Predicted bottom-right corner heatmap.
  156. - tl_off (Tensor): Predicted top-left offset heatmap.
  157. - br_off (Tensor): Predicted bottom-right offset heatmap.
  158. - tl_guiding_shift (Tensor): Predicted top-left guiding shift
  159. heatmap.
  160. - br_guiding_shift (Tensor): Predicted bottom-right guiding
  161. shift heatmap.
  162. - tl_centripetal_shift (Tensor): Predicted top-left centripetal
  163. shift heatmap.
  164. - br_centripetal_shift (Tensor): Predicted bottom-right
  165. centripetal shift heatmap.
  166. """
  167. tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super(
  168. ).forward_single(
  169. x, lvl_ind, return_pool=True)
  170. tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool)
  171. br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool)
  172. tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach())
  173. br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach())
  174. tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool,
  175. tl_dcn_offset)
  176. br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool,
  177. br_dcn_offset)
  178. tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind](
  179. tl_feat_adaption)
  180. br_centripetal_shift = self.br_centripetal_shift[lvl_ind](
  181. br_feat_adaption)
  182. result_list = [
  183. tl_heat, br_heat, tl_off, br_off, tl_guiding_shift,
  184. br_guiding_shift, tl_centripetal_shift, br_centripetal_shift
  185. ]
  186. return result_list
  187. def loss_by_feat(
  188. self,
  189. tl_heats: List[Tensor],
  190. br_heats: List[Tensor],
  191. tl_offs: List[Tensor],
  192. br_offs: List[Tensor],
  193. tl_guiding_shifts: List[Tensor],
  194. br_guiding_shifts: List[Tensor],
  195. tl_centripetal_shifts: List[Tensor],
  196. br_centripetal_shifts: List[Tensor],
  197. batch_gt_instances: InstanceList,
  198. batch_img_metas: List[dict],
  199. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  200. """Calculate the loss based on the features extracted by the detection
  201. head.
  202. Args:
  203. tl_heats (list[Tensor]): Top-left corner heatmaps for each level
  204. with shape (N, num_classes, H, W).
  205. br_heats (list[Tensor]): Bottom-right corner heatmaps for each
  206. level with shape (N, num_classes, H, W).
  207. tl_offs (list[Tensor]): Top-left corner offsets for each level
  208. with shape (N, corner_offset_channels, H, W).
  209. br_offs (list[Tensor]): Bottom-right corner offsets for each level
  210. with shape (N, corner_offset_channels, H, W).
  211. tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
  212. level with shape (N, guiding_shift_channels, H, W).
  213. br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
  214. each level with shape (N, guiding_shift_channels, H, W).
  215. tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
  216. for each level with shape (N, centripetal_shift_channels, H,
  217. W).
  218. br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
  219. shifts for each level with shape (N,
  220. centripetal_shift_channels, H, W).
  221. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  222. gt_instance. It usually includes ``bboxes`` and ``labels``
  223. attributes.
  224. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  225. image size, scaling factor, etc.
  226. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  227. Specify which bounding boxes can be ignored when computing
  228. the loss.
  229. Returns:
  230. dict[str, Tensor]: A dictionary of loss components. Containing the
  231. following losses:
  232. - det_loss (list[Tensor]): Corner keypoint losses of all
  233. feature levels.
  234. - off_loss (list[Tensor]): Corner offset losses of all feature
  235. levels.
  236. - guiding_loss (list[Tensor]): Guiding shift losses of all
  237. feature levels.
  238. - centripetal_loss (list[Tensor]): Centripetal shift losses of
  239. all feature levels.
  240. """
  241. gt_bboxes = [
  242. gt_instances.bboxes for gt_instances in batch_gt_instances
  243. ]
  244. gt_labels = [
  245. gt_instances.labels for gt_instances in batch_gt_instances
  246. ]
  247. targets = self.get_targets(
  248. gt_bboxes,
  249. gt_labels,
  250. tl_heats[-1].shape,
  251. batch_img_metas[0]['batch_input_shape'],
  252. with_corner_emb=self.with_corner_emb,
  253. with_guiding_shift=True,
  254. with_centripetal_shift=True)
  255. mlvl_targets = [targets for _ in range(self.num_feat_levels)]
  256. [det_losses, off_losses, guiding_losses, centripetal_losses
  257. ] = multi_apply(self.loss_by_feat_single, tl_heats, br_heats, tl_offs,
  258. br_offs, tl_guiding_shifts, br_guiding_shifts,
  259. tl_centripetal_shifts, br_centripetal_shifts,
  260. mlvl_targets)
  261. loss_dict = dict(
  262. det_loss=det_losses,
  263. off_loss=off_losses,
  264. guiding_loss=guiding_losses,
  265. centripetal_loss=centripetal_losses)
  266. return loss_dict
  267. def loss_by_feat_single(self, tl_hmp: Tensor, br_hmp: Tensor,
  268. tl_off: Tensor, br_off: Tensor,
  269. tl_guiding_shift: Tensor, br_guiding_shift: Tensor,
  270. tl_centripetal_shift: Tensor,
  271. br_centripetal_shift: Tensor,
  272. targets: dict) -> Tuple[Tensor, ...]:
  273. """Calculate the loss of a single scale level based on the features
  274. extracted by the detection head.
  275. Args:
  276. tl_hmp (Tensor): Top-left corner heatmap for current level with
  277. shape (N, num_classes, H, W).
  278. br_hmp (Tensor): Bottom-right corner heatmap for current level with
  279. shape (N, num_classes, H, W).
  280. tl_off (Tensor): Top-left corner offset for current level with
  281. shape (N, corner_offset_channels, H, W).
  282. br_off (Tensor): Bottom-right corner offset for current level with
  283. shape (N, corner_offset_channels, H, W).
  284. tl_guiding_shift (Tensor): Top-left guiding shift for current level
  285. with shape (N, guiding_shift_channels, H, W).
  286. br_guiding_shift (Tensor): Bottom-right guiding shift for current
  287. level with shape (N, guiding_shift_channels, H, W).
  288. tl_centripetal_shift (Tensor): Top-left centripetal shift for
  289. current level with shape (N, centripetal_shift_channels, H, W).
  290. br_centripetal_shift (Tensor): Bottom-right centripetal shift for
  291. current level with shape (N, centripetal_shift_channels, H, W).
  292. targets (dict): Corner target generated by `get_targets`.
  293. Returns:
  294. tuple[torch.Tensor]: Losses of the head's different branches
  295. containing the following losses:
  296. - det_loss (Tensor): Corner keypoint loss.
  297. - off_loss (Tensor): Corner offset loss.
  298. - guiding_loss (Tensor): Guiding shift loss.
  299. - centripetal_loss (Tensor): Centripetal shift loss.
  300. """
  301. targets['corner_embedding'] = None
  302. det_loss, _, _, off_loss = super().loss_by_feat_single(
  303. tl_hmp, br_hmp, None, None, tl_off, br_off, targets)
  304. gt_tl_guiding_shift = targets['topleft_guiding_shift']
  305. gt_br_guiding_shift = targets['bottomright_guiding_shift']
  306. gt_tl_centripetal_shift = targets['topleft_centripetal_shift']
  307. gt_br_centripetal_shift = targets['bottomright_centripetal_shift']
  308. gt_tl_heatmap = targets['topleft_heatmap']
  309. gt_br_heatmap = targets['bottomright_heatmap']
  310. # We only compute the offset loss at the real corner position.
  311. # The value of real corner would be 1 in heatmap ground truth.
  312. # The mask is computed in class agnostic mode and its shape is
  313. # batch * 1 * width * height.
  314. tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
  315. gt_tl_heatmap)
  316. br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
  317. gt_br_heatmap)
  318. # Guiding shift loss
  319. tl_guiding_loss = self.loss_guiding_shift(
  320. tl_guiding_shift,
  321. gt_tl_guiding_shift,
  322. tl_mask,
  323. avg_factor=tl_mask.sum())
  324. br_guiding_loss = self.loss_guiding_shift(
  325. br_guiding_shift,
  326. gt_br_guiding_shift,
  327. br_mask,
  328. avg_factor=br_mask.sum())
  329. guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0
  330. # Centripetal shift loss
  331. tl_centripetal_loss = self.loss_centripetal_shift(
  332. tl_centripetal_shift,
  333. gt_tl_centripetal_shift,
  334. tl_mask,
  335. avg_factor=tl_mask.sum())
  336. br_centripetal_loss = self.loss_centripetal_shift(
  337. br_centripetal_shift,
  338. gt_br_centripetal_shift,
  339. br_mask,
  340. avg_factor=br_mask.sum())
  341. centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0
  342. return det_loss, off_loss, guiding_loss, centripetal_loss
  343. def predict_by_feat(self,
  344. tl_heats: List[Tensor],
  345. br_heats: List[Tensor],
  346. tl_offs: List[Tensor],
  347. br_offs: List[Tensor],
  348. tl_guiding_shifts: List[Tensor],
  349. br_guiding_shifts: List[Tensor],
  350. tl_centripetal_shifts: List[Tensor],
  351. br_centripetal_shifts: List[Tensor],
  352. batch_img_metas: Optional[List[dict]] = None,
  353. rescale: bool = False,
  354. with_nms: bool = True) -> InstanceList:
  355. """Transform a batch of output features extracted from the head into
  356. bbox results.
  357. Args:
  358. tl_heats (list[Tensor]): Top-left corner heatmaps for each level
  359. with shape (N, num_classes, H, W).
  360. br_heats (list[Tensor]): Bottom-right corner heatmaps for each
  361. level with shape (N, num_classes, H, W).
  362. tl_offs (list[Tensor]): Top-left corner offsets for each level
  363. with shape (N, corner_offset_channels, H, W).
  364. br_offs (list[Tensor]): Bottom-right corner offsets for each level
  365. with shape (N, corner_offset_channels, H, W).
  366. tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
  367. level with shape (N, guiding_shift_channels, H, W). Useless in
  368. this function, we keep this arg because it's the raw output
  369. from CentripetalHead.
  370. br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
  371. each level with shape (N, guiding_shift_channels, H, W).
  372. Useless in this function, we keep this arg because it's the
  373. raw output from CentripetalHead.
  374. tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
  375. for each level with shape (N, centripetal_shift_channels, H,
  376. W).
  377. br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
  378. shifts for each level with shape (N,
  379. centripetal_shift_channels, H, W).
  380. batch_img_metas (list[dict], optional): Batch image meta info.
  381. Defaults to None.
  382. rescale (bool): If True, return boxes in original image space.
  383. Defaults to False.
  384. with_nms (bool): If True, do nms before return boxes.
  385. Defaults to True.
  386. Returns:
  387. list[:obj:`InstanceData`]: Object detection results of each image
  388. after the post process. Each item usually contains following keys.
  389. - scores (Tensor): Classification scores, has a shape
  390. (num_instance, )
  391. - labels (Tensor): Labels of bboxes, has a shape
  392. (num_instances, ).
  393. - bboxes (Tensor): Has a shape (num_instances, 4),
  394. the last dimension 4 arrange as (x1, y1, x2, y2).
  395. """
  396. assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(
  397. batch_img_metas)
  398. result_list = []
  399. for img_id in range(len(batch_img_metas)):
  400. result_list.append(
  401. self._predict_by_feat_single(
  402. tl_heats[-1][img_id:img_id + 1, :],
  403. br_heats[-1][img_id:img_id + 1, :],
  404. tl_offs[-1][img_id:img_id + 1, :],
  405. br_offs[-1][img_id:img_id + 1, :],
  406. batch_img_metas[img_id],
  407. tl_emb=None,
  408. br_emb=None,
  409. tl_centripetal_shift=tl_centripetal_shifts[-1][
  410. img_id:img_id + 1, :],
  411. br_centripetal_shift=br_centripetal_shifts[-1][
  412. img_id:img_id + 1, :],
  413. rescale=rescale,
  414. with_nms=with_nms))
  415. return result_list