# rtmdet_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Scale, is_norm
from mmengine.model import bias_init_with_prob, constant_init, normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
from ..layers.transformer import inverse_sigmoid
from ..task_modules import anchor_inside_flags
from ..utils import (images_to_levels, multi_apply, sigmoid_geometric_mean,
                     unmap)
from .atss_head import ATSSHead


@MODELS.register_module()
class RTMDetHead(ATSSHead):
    """Detection Head of RTMDet.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        with_objectness (bool): Whether to add an objectness branch.
            Defaults to True.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Defaults to dict(type='ReLU').
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 with_objectness: bool = True,
                 act_cfg: ConfigType = dict(type='ReLU'),
                 **kwargs) -> None:
        self.act_cfg = act_cfg
        self.with_objectness = with_objectness
        super().__init__(num_classes, in_channels, **kwargs)
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])

    def _init_layers(self):
        """Initialize layers of the head."""
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        pred_pad_size = self.pred_kernel_size // 2
        self.rtm_cls = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.cls_out_channels,
            self.pred_kernel_size,
            padding=pred_pad_size)
        self.rtm_reg = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * 4,
            self.pred_kernel_size,
            padding=pred_pad_size)
        if self.with_objectness:
            self.rtm_obj = nn.Conv2d(
                self.feat_channels,
                1,
                self.pred_kernel_size,
                padding=pred_pad_size)
        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.rtm_cls, std=0.01, bias=bias_cls)
        normal_init(self.rtm_reg, std=0.01)
        if self.with_objectness:
            normal_init(self.rtm_obj, std=0.01, bias=bias_cls)
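
    # Note on the initialization above: bias_init_with_prob(0.01) solves
    # sigmoid(b) = 0.01, i.e. b = -log((1 - 0.01) / 0.01) ≈ -4.6, so every
    # location starts out predicting ~1% foreground confidence, which keeps
    # the classification loss stable early in training.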

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox
            prediction.

            - cls_scores (list[Tensor]): Classification scores for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
        """
        cls_scores = []
        bbox_preds = []
        for idx, (x, scale, stride) in enumerate(
                zip(feats, self.scales, self.prior_generator.strides)):
            cls_feat = x
            reg_feat = x

            for cls_layer in self.cls_convs:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls(cls_feat)

            for reg_layer in self.reg_convs:
                reg_feat = reg_layer(reg_feat)

            if self.with_objectness:
                objectness = self.rtm_obj(reg_feat)
                # Fuse objectness into the classification logits: the fused
                # score is sqrt(sigmoid(cls) * sigmoid(obj)), mapped back to
                # logit space so downstream code can keep applying sigmoid.
                cls_score = inverse_sigmoid(
                    sigmoid_geometric_mean(cls_score, objectness))
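            # The regression branch predicts per-side distances: .exp()
            # keeps them positive, the learnable per-level Scale calibrates
            # their magnitude, and multiplying by the stride converts them
            # from feature-grid units to input-image pixels.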
            reg_dist = scale(self.rtm_reg(reg_feat).exp()).float() * stride[0]
            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
        return tuple(cls_scores), tuple(bbox_preds)
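
    # Shape sketch (hypothetical sizes): with strides (8, 16, 32), a 640x640
    # input, num_classes=80 and a single prior per location, each cls_score
    # is (N, 80, H, W) and each bbox_pred is (N, 4, H, W) over grids of
    # 80x80, 40x40 and 20x20, the four channels holding the left/top/right/
    # bottom distances in image pixels.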

    def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                            labels: Tensor, label_weights: Tensor,
                            bbox_targets: Tensor, assign_metrics: Tensor,
                            stride: List[int]):
        """Compute loss of a single scale level.

        Args:
            cls_score (Tensor): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Decoded bboxes for each scale
                level with shape (N, num_anchors * 4, H, W).
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            assign_metrics (Tensor): Assign metrics with shape
                (N, num_total_anchors).
            stride (List[int]): Downsample stride of the feature map.

        Returns:
            tuple[Tensor]: The classification loss, the bbox loss, and the
            sums of assign metrics and positive bbox weights that serve as
            the per-level classification and bbox average factors.
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels).contiguous()
        bbox_pred = bbox_pred.reshape(-1, 4)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        assign_metrics = assign_metrics.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # The (labels, metrics) tuple follows the quality-focal-style target
        # convention: the assign metric acts as a soft classification target.
        targets = (labels, assign_metrics)
        loss_cls = self.loss_cls(
            cls_score, targets, label_weights, avg_factor=1.0)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]

            pos_decode_bbox_pred = pos_bbox_pred
            pos_decode_bbox_targets = pos_bbox_targets

            # regression loss, weighted by the alignment quality
            pos_bbox_weight = assign_metrics[pos_inds]

            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=pos_bbox_weight,
                avg_factor=1.0)
        else:
            loss_bbox = bbox_pred.sum() * 0
            pos_bbox_weight = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, assign_metrics.sum(), pos_bbox_weight.sum()
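
    # Worked example (hypothetical numbers): a positive anchor matched to a
    # GT of label 3 with assign metric 0.7 receives a soft classification
    # target of 0.7 on channel 3 and zero elsewhere, and its bbox loss is
    # weighted by the same 0.7, so weakly aligned positives contribute less
    # to both branches.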

    def loss_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict],
                     batch_gt_instances_ignore: OptInstanceList = None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Decoded box for each scale
                level with shape (N, num_anchors * 4, H, W) in
                [tl_x, tl_y, br_x, br_y] format.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1)
        decoded_bboxes = []
        for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
            anchor = anchor.reshape(-1, 4)
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            bbox_pred = distance2bbox(anchor, bbox_pred)
            decoded_bboxes.append(bbox_pred)

        flatten_bboxes = torch.cat(decoded_bboxes, 1)

        cls_reg_targets = self.get_targets(
            flatten_cls_scores,
            flatten_bboxes,
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         assign_metrics_list, sampling_results_list) = cls_reg_targets

        losses_cls, losses_bbox, \
            cls_avg_factors, bbox_avg_factors = multi_apply(
                self.loss_by_feat_single,
                cls_scores,
                decoded_bboxes,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                assign_metrics_list,
                self.prior_generator.strides)

        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
        losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))

        bbox_avg_factor = reduce_mean(
            sum(bbox_avg_factors)).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
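
    # Normalization sketch: reduce_mean all-reduces the summed assign
    # metrics across ranks, so e.g. per-GPU factors of 4.8 and 5.2 both
    # normalize by their mean 5.0, keeping loss scales consistent under
    # distributed training; clamp(min=1) guards against batches without any
    # positive assignment.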

    def get_targets(self,
                    cls_scores: Tensor,
                    bbox_preds: Tensor,
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            cls_scores (Tensor): Classification predictions of images,
                a 3D-Tensor with shape [num_imgs, num_priors, num_classes].
            bbox_preds (Tensor): Decoded bbox predictions of images,
                a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x,
                tl_y, br_x, br_y] format.
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each
                element of the inner list is a tensor of shape
                (num_anchors, ).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the
                original set of anchors. Defaults to True.

        Returns:
            tuple: a tuple containing learning targets.

            - anchors_list (list[list[Tensor]]): Anchors of each level.
            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each
              level.
            - bbox_targets_list (list[Tensor]): BBox targets of each level.
            - assign_metrics_list (list[Tensor]): Alignment metrics of each
              level.
            - sampling_results_list (list[:obj:`SamplingResult`]): Sampling
              results of each image.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = torch.cat(anchor_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs
        # anchor_list: list(b * [-1, 4])
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_assign_metrics, sampling_results_list) = multi_apply(
             self._get_targets_single,
             cls_scores.detach(),
             bbox_preds.detach(),
             anchor_list,
             valid_flag_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             unmap_outputs=unmap_outputs)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None

        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        assign_metrics_list = images_to_levels(all_assign_metrics,
                                               num_level_anchors)

        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, assign_metrics_list, sampling_results_list)
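
    # Layout sketch (hypothetical counts): for a batch of 2 images and
    # per-level prior counts (6400, 1600, 400), labels_list comes back as
    # [Tensor(2, 6400), Tensor(2, 1600), Tensor(2, 400)], and the same
    # image-first, per-level split applies to the weights, bbox targets and
    # assign metrics.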

    def _get_targets_single(self,
                            cls_scores: Tensor,
                            bbox_preds: Tensor,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            cls_scores (list(Tensor)): Box scores for each image.
            bbox_preds (list(Tensor)): Box energies / deltas for each image.
            flat_anchors (Tensor): Multi-level anchors of the image, which
                are concatenated into a single tensor of shape
                (num_anchors, 4).
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of shape
                (num_anchors,).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the
                original set of anchors. Defaults to True.

        Returns:
            tuple: N is the number of total anchors in the image.

            - anchors (Tensor): All anchors in the image with shape (N, 4).
            - labels (Tensor): Labels of all anchors in the image with
              shape (N,).
            - label_weights (Tensor): Label weights of all anchors in the
              image with shape (N,).
            - bbox_targets (Tensor): BBox targets of all anchors in the
              image with shape (N, 4).
            - assign_metrics (Tensor): Alignment metrics of all priors in
              the image with shape (N,).
            - sampling_result (:obj:`SamplingResult`): Sampling result of
              the image.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg['allowed_border'])
        if not inside_flags.any():
            # one None per returned element
            return (None, ) * 6
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        pred_instances = InstanceData(
            scores=cls_scores[inside_flags, :],
            bboxes=bbox_preds[inside_flags, :],
            priors=anchors)

        assign_result = self.assigner.assign(pred_instances, gt_instances,
                                             gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)
        assign_metrics = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            # point-based: the decoded GT boxes are regressed directly
            pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets

            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        class_assigned_gt_inds = torch.unique(
            sampling_result.pos_assigned_gt_inds)
        for gt_inds in class_assigned_gt_inds:
            gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds ==
                                     gt_inds]
            assign_metrics[gt_class_inds] = assign_result.max_overlaps[
                gt_class_inds]

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors,
                                 inside_flags)
            assign_metrics = unmap(assign_metrics, num_total_anchors,
                                   inside_flags)
        return (anchors, labels, label_weights, bbox_targets, assign_metrics,
                sampling_result)
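
    # Note on the per-GT loop above: it copies the assigner's max_overlaps,
    # i.e. the alignment quality between each positive prior and its matched
    # GT, into assign_metrics; loss_by_feat_single later reuses that value
    # as both the soft classification target and the bbox loss weight.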

    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    batch_img_metas: List[dict],
                    device: Union[torch.device, str] = 'cuda') \
            -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            batch_img_metas (list[dict]): Image meta info.
            device (torch.device or str): Device for returned tensors.
                Defaults to 'cuda'.

        Returns:
            tuple:

            - anchor_list (list[list[Tensor]]): Anchors of each image.
            - valid_flag_list (list[list[Tensor]]): Valid flags of each
              image.
        """
        num_imgs = len(batch_img_metas)

        # since feature map sizes of all images are the same, we only
        # compute anchors once
        multi_level_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device, with_stride=True)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            multi_level_flags = self.prior_generator.valid_flags(
                featmap_sizes, img_meta['pad_shape'], device)
            valid_flag_list.append(multi_level_flags)
        return anchor_list, valid_flag_list
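
    # Note: with with_stride=True and the MlvlPointGenerator used by the
    # reference RTMDet configs, each prior row is (x, y, stride_w, stride_h);
    # distance2bbox in loss_by_feat only reads the first two columns, i.e.
    # the point coordinates.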


@MODELS.register_module()
class RTMDetSepBNHead(RTMDetHead):
    """RTMDetHead with separated BN layers and shared conv layers.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        share_conv (bool): Whether to share conv layers across scale levels.
            Defaults to True.
        use_depthwise (bool): Whether to use depthwise separable convolution
            in the head. Defaults to False.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Defaults to dict(type='SiLU').
        pred_kernel_size (int): Kernel size of prediction layer.
            Defaults to 1.
        exp_on_reg (bool): Whether to apply an exponential to the raw
            regression output before multiplying by the stride.
            Defaults to False.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 share_conv: bool = True,
                 use_depthwise: bool = False,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU'),
                 pred_kernel_size: int = 1,
                 exp_on_reg: bool = False,
                 **kwargs) -> None:
        self.share_conv = share_conv
        self.exp_on_reg = exp_on_reg
        self.use_depthwise = use_depthwise
        super().__init__(
            num_classes,
            in_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            pred_kernel_size=pred_kernel_size,
            **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        conv = DepthwiseSeparableConvModule \
            if self.use_depthwise else ConvModule
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()

        self.rtm_cls = nn.ModuleList()
        self.rtm_reg = nn.ModuleList()
        if self.with_objectness:
            self.rtm_obj = nn.ModuleList()
        for n in range(len(self.prior_generator.strides)):
            cls_convs = nn.ModuleList()
            reg_convs = nn.ModuleList()
            for i in range(self.stacked_convs):
                chn = self.in_channels if i == 0 else self.feat_channels
                cls_convs.append(
                    conv(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                reg_convs.append(
                    conv(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.cls_convs.append(cls_convs)
            self.reg_convs.append(reg_convs)

            self.rtm_cls.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * self.cls_out_channels,
                    self.pred_kernel_size,
                    padding=self.pred_kernel_size // 2))
            self.rtm_reg.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * 4,
                    self.pred_kernel_size,
                    padding=self.pred_kernel_size // 2))
            if self.with_objectness:
                self.rtm_obj.append(
                    nn.Conv2d(
                        self.feat_channels,
                        1,
                        self.pred_kernel_size,
                        padding=self.pred_kernel_size // 2))

        if self.share_conv:
            for n in range(len(self.prior_generator.strides)):
                for i in range(self.stacked_convs):
                    self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
                    self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
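        # The sharing above ties every level's conv kernels to those of
        # level 0 while each level keeps its own norm layers: the "shared
        # conv, separated BN" design this head is named after.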

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)
        bias_cls = bias_init_with_prob(0.01)
        for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
            normal_init(rtm_cls, std=0.01, bias=bias_cls)
            normal_init(rtm_reg, std=0.01)
        if self.with_objectness:
            for rtm_obj in self.rtm_obj:
                normal_init(rtm_obj, std=0.01, bias=bias_cls)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox
            prediction.

            - cls_scores (tuple[Tensor]): Classification scores for all
              scale levels, each is a 4D-tensor, the channels number is
              num_anchors * num_classes.
            - bbox_preds (tuple[Tensor]): Box energies / deltas for all
              scale levels, each is a 4D-tensor, the channels number is
              num_anchors * 4.
        """
        cls_scores = []
        bbox_preds = []
        for idx, (x, stride) in enumerate(
                zip(feats, self.prior_generator.strides)):
            cls_feat = x
            reg_feat = x

            for cls_layer in self.cls_convs[idx]:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls[idx](cls_feat)

            for reg_layer in self.reg_convs[idx]:
                reg_feat = reg_layer(reg_feat)

            if self.with_objectness:
                objectness = self.rtm_obj[idx](reg_feat)
                cls_score = inverse_sigmoid(
                    sigmoid_geometric_mean(cls_score, objectness))
            # unlike the parent head, no learnable Scale is applied here;
            # distances are optionally exp-ed, then scaled by the stride
            if self.exp_on_reg:
                reg_dist = self.rtm_reg[idx](reg_feat).exp() * stride[0]
            else:
                reg_dist = self.rtm_reg[idx](reg_feat) * stride[0]
            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
        return tuple(cls_scores), tuple(bbox_preds)
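

# A minimal smoke test, meant for a separate script rather than this module
# (the relative imports above prevent running this file directly). The
# config values are assumptions modeled on typical RTMDet settings and may
# need adjusting to the installed mmdet version:
#
#     import torch
#     from mmdet.registry import MODELS
#     from mmdet.utils import register_all_modules
#
#     register_all_modules()  # populate the default registries
#     head = MODELS.build(
#         dict(
#             type='RTMDetSepBNHead',
#             num_classes=80,
#             in_channels=96,
#             feat_channels=96,
#             stacked_convs=2,
#             exp_on_reg=True,
#             anchor_generator=dict(
#                 type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
#             bbox_coder=dict(type='DistancePointBBoxCoder'),
#             loss_cls=dict(
#                 type='QualityFocalLoss', use_sigmoid=True, beta=2.0,
#                 loss_weight=1.0),
#             loss_bbox=dict(type='GIoULoss', loss_weight=2.0)))
#     feats = tuple(
#         torch.rand(2, 96, 640 // s, 640 // s) for s in (8, 16, 32))
#     cls_scores, bbox_preds = head(feats)
#     # expect (2, 80, 80, 80)/(2, 4, 80, 80), then 40x40 and 20x20 levels
#     for cls_score, bbox_pred in zip(cls_scores, bbox_preds):
#         print(cls_score.shape, bbox_pred.shape)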