centernet_head.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.ops import batched_nms
  6. from mmengine.config import ConfigDict
  7. from mmengine.model import bias_init_with_prob, normal_init
  8. from mmengine.structures import InstanceData
  9. from torch import Tensor
  10. from mmdet.registry import MODELS
  11. from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
  12. OptInstanceList, OptMultiConfig)
  13. from ..utils import (gaussian_radius, gen_gaussian_target, get_local_maximum,
  14. get_topk_from_heatmap, multi_apply,
  15. transpose_and_gather_feat)
  16. from .base_dense_head import BaseDenseHead
  17. @MODELS.register_module()
  18. class CenterNetHead(BaseDenseHead):
  19. """Objects as Points Head. CenterHead use center_point to indicate object's
  20. position. Paper link <https://arxiv.org/abs/1904.07850>
  21. Args:
  22. in_channels (int): Number of channel in the input feature map.
  23. feat_channels (int): Number of channel in the intermediate feature map.
  24. num_classes (int): Number of categories excluding the background
  25. category.
  26. loss_center_heatmap (:obj:`ConfigDict` or dict): Config of center
  27. heatmap loss. Defaults to
  28. dict(type='GaussianFocalLoss', loss_weight=1.0)
  29. loss_wh (:obj:`ConfigDict` or dict): Config of wh loss. Defaults to
  30. dict(type='L1Loss', loss_weight=0.1).
  31. loss_offset (:obj:`ConfigDict` or dict): Config of offset loss.
  32. Defaults to dict(type='L1Loss', loss_weight=1.0).
  33. train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
  34. Useless in CenterNet, but we keep this variable for
  35. SingleStageDetector.
  36. test_cfg (:obj:`ConfigDict` or dict, optional): Testing config
  37. of CenterNet.
  38. init_cfg (:obj:`ConfigDict` or dict or list[dict] or
  39. list[:obj:`ConfigDict`], optional): Initialization
  40. config dict.
  41. """
  42. def __init__(self,
  43. in_channels: int,
  44. feat_channels: int,
  45. num_classes: int,
  46. loss_center_heatmap: ConfigType = dict(
  47. type='GaussianFocalLoss', loss_weight=1.0),
  48. loss_wh: ConfigType = dict(type='L1Loss', loss_weight=0.1),
  49. loss_offset: ConfigType = dict(
  50. type='L1Loss', loss_weight=1.0),
  51. train_cfg: OptConfigType = None,
  52. test_cfg: OptConfigType = None,
  53. init_cfg: OptMultiConfig = None) -> None:
  54. super().__init__(init_cfg=init_cfg)
  55. self.num_classes = num_classes
  56. self.heatmap_head = self._build_head(in_channels, feat_channels,
  57. num_classes)
  58. self.wh_head = self._build_head(in_channels, feat_channels, 2)
  59. self.offset_head = self._build_head(in_channels, feat_channels, 2)
  60. self.loss_center_heatmap = MODELS.build(loss_center_heatmap)
  61. self.loss_wh = MODELS.build(loss_wh)
  62. self.loss_offset = MODELS.build(loss_offset)
  63. self.train_cfg = train_cfg
  64. self.test_cfg = test_cfg
  65. self.fp16_enabled = False
  66. def _build_head(self, in_channels: int, feat_channels: int,
  67. out_channels: int) -> nn.Sequential:
  68. """Build head for each branch."""
  69. layer = nn.Sequential(
  70. nn.Conv2d(in_channels, feat_channels, kernel_size=3, padding=1),
  71. nn.ReLU(inplace=True),
  72. nn.Conv2d(feat_channels, out_channels, kernel_size=1))
  73. return layer
  74. def init_weights(self) -> None:
  75. """Initialize weights of the head."""
  76. bias_init = bias_init_with_prob(0.1)
  77. self.heatmap_head[-1].bias.data.fill_(bias_init)
  78. for head in [self.wh_head, self.offset_head]:
  79. for m in head.modules():
  80. if isinstance(m, nn.Conv2d):
  81. normal_init(m, std=0.001)
  82. def forward(self, x: Tuple[Tensor, ...]) -> Tuple[List[Tensor]]:
  83. """Forward features. Notice CenterNet head does not use FPN.
  84. Args:
  85. x (tuple[Tensor]): Features from the upstream network, each is
  86. a 4D-tensor.
  87. Returns:
  88. center_heatmap_preds (list[Tensor]): center predict heatmaps for
  89. all levels, the channels number is num_classes.
  90. wh_preds (list[Tensor]): wh predicts for all levels, the channels
  91. number is 2.
  92. offset_preds (list[Tensor]): offset predicts for all levels, the
  93. channels number is 2.
  94. """
  95. return multi_apply(self.forward_single, x)
  96. def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]:
  97. """Forward feature of a single level.
  98. Args:
  99. x (Tensor): Feature of a single level.
  100. Returns:
  101. center_heatmap_pred (Tensor): center predict heatmaps, the
  102. channels number is num_classes.
  103. wh_pred (Tensor): wh predicts, the channels number is 2.
  104. offset_pred (Tensor): offset predicts, the channels number is 2.
  105. """
  106. center_heatmap_pred = self.heatmap_head(x).sigmoid()
  107. wh_pred = self.wh_head(x)
  108. offset_pred = self.offset_head(x)
  109. return center_heatmap_pred, wh_pred, offset_pred
  110. def loss_by_feat(
  111. self,
  112. center_heatmap_preds: List[Tensor],
  113. wh_preds: List[Tensor],
  114. offset_preds: List[Tensor],
  115. batch_gt_instances: InstanceList,
  116. batch_img_metas: List[dict],
  117. batch_gt_instances_ignore: OptInstanceList = None) -> dict:
  118. """Compute losses of the head.
  119. Args:
  120. center_heatmap_preds (list[Tensor]): center predict heatmaps for
  121. all levels with shape (B, num_classes, H, W).
  122. wh_preds (list[Tensor]): wh predicts for all levels with
  123. shape (B, 2, H, W).
  124. offset_preds (list[Tensor]): offset predicts for all levels
  125. with shape (B, 2, H, W).
  126. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  127. gt_instance. It usually includes ``bboxes`` and ``labels``
  128. attributes.
  129. batch_img_metas (list[dict]): Meta information of each image, e.g.,
  130. image size, scaling factor, etc.
  131. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
  132. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
  133. data that is ignored during training and testing.
  134. Defaults to None.
  135. Returns:
  136. dict[str, Tensor]: which has components below:
  137. - loss_center_heatmap (Tensor): loss of center heatmap.
  138. - loss_wh (Tensor): loss of hw heatmap
  139. - loss_offset (Tensor): loss of offset heatmap.
  140. """
  141. assert len(center_heatmap_preds) == len(wh_preds) == len(
  142. offset_preds) == 1
  143. center_heatmap_pred = center_heatmap_preds[0]
  144. wh_pred = wh_preds[0]
  145. offset_pred = offset_preds[0]
  146. gt_bboxes = [
  147. gt_instances.bboxes for gt_instances in batch_gt_instances
  148. ]
  149. gt_labels = [
  150. gt_instances.labels for gt_instances in batch_gt_instances
  151. ]
  152. img_shape = batch_img_metas[0]['batch_input_shape']
  153. target_result, avg_factor = self.get_targets(gt_bboxes, gt_labels,
  154. center_heatmap_pred.shape,
  155. img_shape)
  156. center_heatmap_target = target_result['center_heatmap_target']
  157. wh_target = target_result['wh_target']
  158. offset_target = target_result['offset_target']
  159. wh_offset_target_weight = target_result['wh_offset_target_weight']
  160. # Since the channel of wh_target and offset_target is 2, the avg_factor
  161. # of loss_center_heatmap is always 1/2 of loss_wh and loss_offset.
  162. loss_center_heatmap = self.loss_center_heatmap(
  163. center_heatmap_pred, center_heatmap_target, avg_factor=avg_factor)
  164. loss_wh = self.loss_wh(
  165. wh_pred,
  166. wh_target,
  167. wh_offset_target_weight,
  168. avg_factor=avg_factor * 2)
  169. loss_offset = self.loss_offset(
  170. offset_pred,
  171. offset_target,
  172. wh_offset_target_weight,
  173. avg_factor=avg_factor * 2)
  174. return dict(
  175. loss_center_heatmap=loss_center_heatmap,
  176. loss_wh=loss_wh,
  177. loss_offset=loss_offset)
  178. def get_targets(self, gt_bboxes: List[Tensor], gt_labels: List[Tensor],
  179. feat_shape: tuple, img_shape: tuple) -> Tuple[dict, int]:
  180. """Compute regression and classification targets in multiple images.
  181. Args:
  182. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  183. shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  184. gt_labels (list[Tensor]): class indices corresponding to each box.
  185. feat_shape (tuple): feature map shape with value [B, _, H, W]
  186. img_shape (tuple): image shape.
  187. Returns:
  188. tuple[dict, float]: The float value is mean avg_factor, the dict
  189. has components below:
  190. - center_heatmap_target (Tensor): targets of center heatmap, \
  191. shape (B, num_classes, H, W).
  192. - wh_target (Tensor): targets of wh predict, shape \
  193. (B, 2, H, W).
  194. - offset_target (Tensor): targets of offset predict, shape \
  195. (B, 2, H, W).
  196. - wh_offset_target_weight (Tensor): weights of wh and offset \
  197. predict, shape (B, 2, H, W).
  198. """
  199. img_h, img_w = img_shape[:2]
  200. bs, _, feat_h, feat_w = feat_shape
  201. width_ratio = float(feat_w / img_w)
  202. height_ratio = float(feat_h / img_h)
  203. center_heatmap_target = gt_bboxes[-1].new_zeros(
  204. [bs, self.num_classes, feat_h, feat_w])
  205. wh_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
  206. offset_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
  207. wh_offset_target_weight = gt_bboxes[-1].new_zeros(
  208. [bs, 2, feat_h, feat_w])
  209. for batch_id in range(bs):
  210. gt_bbox = gt_bboxes[batch_id]
  211. gt_label = gt_labels[batch_id]
  212. center_x = (gt_bbox[:, [0]] + gt_bbox[:, [2]]) * width_ratio / 2
  213. center_y = (gt_bbox[:, [1]] + gt_bbox[:, [3]]) * height_ratio / 2
  214. gt_centers = torch.cat((center_x, center_y), dim=1)
  215. for j, ct in enumerate(gt_centers):
  216. ctx_int, cty_int = ct.int()
  217. ctx, cty = ct
  218. scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
  219. scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
  220. radius = gaussian_radius([scale_box_h, scale_box_w],
  221. min_overlap=0.3)
  222. radius = max(0, int(radius))
  223. ind = gt_label[j]
  224. gen_gaussian_target(center_heatmap_target[batch_id, ind],
  225. [ctx_int, cty_int], radius)
  226. wh_target[batch_id, 0, cty_int, ctx_int] = scale_box_w
  227. wh_target[batch_id, 1, cty_int, ctx_int] = scale_box_h
  228. offset_target[batch_id, 0, cty_int, ctx_int] = ctx - ctx_int
  229. offset_target[batch_id, 1, cty_int, ctx_int] = cty - cty_int
  230. wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1
  231. avg_factor = max(1, center_heatmap_target.eq(1).sum())
  232. target_result = dict(
  233. center_heatmap_target=center_heatmap_target,
  234. wh_target=wh_target,
  235. offset_target=offset_target,
  236. wh_offset_target_weight=wh_offset_target_weight)
  237. return target_result, avg_factor
  238. def predict_by_feat(self,
  239. center_heatmap_preds: List[Tensor],
  240. wh_preds: List[Tensor],
  241. offset_preds: List[Tensor],
  242. batch_img_metas: Optional[List[dict]] = None,
  243. rescale: bool = True,
  244. with_nms: bool = False) -> InstanceList:
  245. """Transform network output for a batch into bbox predictions.
  246. Args:
  247. center_heatmap_preds (list[Tensor]): Center predict heatmaps for
  248. all levels with shape (B, num_classes, H, W).
  249. wh_preds (list[Tensor]): WH predicts for all levels with
  250. shape (B, 2, H, W).
  251. offset_preds (list[Tensor]): Offset predicts for all levels
  252. with shape (B, 2, H, W).
  253. batch_img_metas (list[dict], optional): Batch image meta info.
  254. Defaults to None.
  255. rescale (bool): If True, return boxes in original image space.
  256. Defaults to True.
  257. with_nms (bool): If True, do nms before return boxes.
  258. Defaults to False.
  259. Returns:
  260. list[:obj:`InstanceData`]: Instance segmentation
  261. results of each image after the post process.
  262. Each item usually contains following keys.
  263. - scores (Tensor): Classification scores, has a shape
  264. (num_instance, )
  265. - labels (Tensor): Labels of bboxes, has a shape
  266. (num_instances, ).
  267. - bboxes (Tensor): Has a shape (num_instances, 4),
  268. the last dimension 4 arrange as (x1, y1, x2, y2).
  269. """
  270. assert len(center_heatmap_preds) == len(wh_preds) == len(
  271. offset_preds) == 1
  272. result_list = []
  273. for img_id in range(len(batch_img_metas)):
  274. result_list.append(
  275. self._predict_by_feat_single(
  276. center_heatmap_preds[0][img_id:img_id + 1, ...],
  277. wh_preds[0][img_id:img_id + 1, ...],
  278. offset_preds[0][img_id:img_id + 1, ...],
  279. batch_img_metas[img_id],
  280. rescale=rescale,
  281. with_nms=with_nms))
  282. return result_list
  283. def _predict_by_feat_single(self,
  284. center_heatmap_pred: Tensor,
  285. wh_pred: Tensor,
  286. offset_pred: Tensor,
  287. img_meta: dict,
  288. rescale: bool = True,
  289. with_nms: bool = False) -> InstanceData:
  290. """Transform outputs of a single image into bbox results.
  291. Args:
  292. center_heatmap_pred (Tensor): Center heatmap for current level with
  293. shape (1, num_classes, H, W).
  294. wh_pred (Tensor): WH heatmap for current level with shape
  295. (1, num_classes, H, W).
  296. offset_pred (Tensor): Offset for current level with shape
  297. (1, corner_offset_channels, H, W).
  298. img_meta (dict): Meta information of current image, e.g.,
  299. image size, scaling factor, etc.
  300. rescale (bool): If True, return boxes in original image space.
  301. Defaults to True.
  302. with_nms (bool): If True, do nms before return boxes.
  303. Defaults to False.
  304. Returns:
  305. :obj:`InstanceData`: Detection results of each image
  306. after the post process.
  307. Each item usually contains following keys.
  308. - scores (Tensor): Classification scores, has a shape
  309. (num_instance, )
  310. - labels (Tensor): Labels of bboxes, has a shape
  311. (num_instances, ).
  312. - bboxes (Tensor): Has a shape (num_instances, 4),
  313. the last dimension 4 arrange as (x1, y1, x2, y2).
  314. """
  315. batch_det_bboxes, batch_labels = self._decode_heatmap(
  316. center_heatmap_pred,
  317. wh_pred,
  318. offset_pred,
  319. img_meta['batch_input_shape'],
  320. k=self.test_cfg.topk,
  321. kernel=self.test_cfg.local_maximum_kernel)
  322. det_bboxes = batch_det_bboxes.view([-1, 5])
  323. det_labels = batch_labels.view(-1)
  324. batch_border = det_bboxes.new_tensor(img_meta['border'])[...,
  325. [2, 0, 2, 0]]
  326. det_bboxes[..., :4] -= batch_border
  327. if rescale and 'scale_factor' in img_meta:
  328. det_bboxes[..., :4] /= det_bboxes.new_tensor(
  329. img_meta['scale_factor']).repeat((1, 2))
  330. if with_nms:
  331. det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
  332. self.test_cfg)
  333. results = InstanceData()
  334. results.bboxes = det_bboxes[..., :4]
  335. results.scores = det_bboxes[..., 4]
  336. results.labels = det_labels
  337. return results
  338. def _decode_heatmap(self,
  339. center_heatmap_pred: Tensor,
  340. wh_pred: Tensor,
  341. offset_pred: Tensor,
  342. img_shape: tuple,
  343. k: int = 100,
  344. kernel: int = 3) -> Tuple[Tensor, Tensor]:
  345. """Transform outputs into detections raw bbox prediction.
  346. Args:
  347. center_heatmap_pred (Tensor): center predict heatmap,
  348. shape (B, num_classes, H, W).
  349. wh_pred (Tensor): wh predict, shape (B, 2, H, W).
  350. offset_pred (Tensor): offset predict, shape (B, 2, H, W).
  351. img_shape (tuple): image shape in hw format.
  352. k (int): Get top k center keypoints from heatmap. Defaults to 100.
  353. kernel (int): Max pooling kernel for extract local maximum pixels.
  354. Defaults to 3.
  355. Returns:
  356. tuple[Tensor]: Decoded output of CenterNetHead, containing
  357. the following Tensors:
  358. - batch_bboxes (Tensor): Coords of each box with shape (B, k, 5)
  359. - batch_topk_labels (Tensor): Categories of each box with \
  360. shape (B, k)
  361. """
  362. height, width = center_heatmap_pred.shape[2:]
  363. inp_h, inp_w = img_shape
  364. center_heatmap_pred = get_local_maximum(
  365. center_heatmap_pred, kernel=kernel)
  366. *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
  367. center_heatmap_pred, k=k)
  368. batch_scores, batch_index, batch_topk_labels = batch_dets
  369. wh = transpose_and_gather_feat(wh_pred, batch_index)
  370. offset = transpose_and_gather_feat(offset_pred, batch_index)
  371. topk_xs = topk_xs + offset[..., 0]
  372. topk_ys = topk_ys + offset[..., 1]
  373. tl_x = (topk_xs - wh[..., 0] / 2) * (inp_w / width)
  374. tl_y = (topk_ys - wh[..., 1] / 2) * (inp_h / height)
  375. br_x = (topk_xs + wh[..., 0] / 2) * (inp_w / width)
  376. br_y = (topk_ys + wh[..., 1] / 2) * (inp_h / height)
  377. batch_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=2)
  378. batch_bboxes = torch.cat((batch_bboxes, batch_scores[..., None]),
  379. dim=-1)
  380. return batch_bboxes, batch_topk_labels
  381. def _bboxes_nms(self, bboxes: Tensor, labels: Tensor,
  382. cfg: ConfigDict) -> Tuple[Tensor, Tensor]:
  383. """bboxes nms."""
  384. if labels.numel() > 0:
  385. max_num = cfg.max_per_img
  386. bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:,
  387. -1].contiguous(),
  388. labels, cfg.nms)
  389. if max_num > 0:
  390. bboxes = bboxes[:max_num]
  391. labels = labels[keep][:max_num]
  392. return bboxes, labels