rtmdet_ins_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, is_norm
from mmcv.ops import batched_nms
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
                            normal_init)
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
                                select_single_mlvl, sigmoid_geometric_mean)
from mmdet.registry import MODELS
from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor,
                                   get_box_wh, scale_boxes)
from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean

from .rtmdet_head import RTMDetHead


@MODELS.register_module()
class RTMDetInsHead(RTMDetHead):
    """Detection Head of RTMDet-Ins.

    Args:
        num_prototypes (int): Number of mask prototype features extracted
            from the mask head. Defaults to 8.
        dyconv_channels (int): Channel of the dynamic conv layers.
            Defaults to 8.
        num_dyconvs (int): Number of the dynamic convolution layers.
            Defaults to 3.
        mask_loss_stride (int): Down sample stride of the masks for loss
            computation. Defaults to 4.
        loss_mask (:obj:`ConfigDict` or dict): Config dict for mask loss.
    """

    def __init__(self,
                 *args,
                 num_prototypes: int = 8,
                 dyconv_channels: int = 8,
                 num_dyconvs: int = 3,
                 mask_loss_stride: int = 4,
                 loss_mask=dict(
                     type='DiceLoss',
                     loss_weight=2.0,
                     eps=5e-6,
                     reduction='mean'),
                 **kwargs) -> None:
        self.num_prototypes = num_prototypes
        self.num_dyconvs = num_dyconvs
        self.dyconv_channels = dyconv_channels
        self.mask_loss_stride = mask_loss_stride
        super().__init__(*args, **kwargs)
        self.loss_mask = MODELS.build(loss_mask)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        super()._init_layers()
        # a branch to predict kernels of dynamic convs
        self.kernel_convs = nn.ModuleList()
        # calculate num dynamic parameters
        weight_nums, bias_nums = [], []
        for i in range(self.num_dyconvs):
            if i == 0:
                weight_nums.append(
                    # mask prototype and coordinate features
                    (self.num_prototypes + 2) * self.dyconv_channels)
                bias_nums.append(self.dyconv_channels * 1)
            elif i == self.num_dyconvs - 1:
                weight_nums.append(self.dyconv_channels * 1)
                bias_nums.append(1)
            else:
                weight_nums.append(self.dyconv_channels * self.dyconv_channels)
                bias_nums.append(self.dyconv_channels * 1)
        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        self.num_gen_params = sum(weight_nums) + sum(bias_nums)
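        # With the defaults (num_prototypes=8, dyconv_channels=8,
        # num_dyconvs=3) this yields weight_nums = [80, 64, 8] and
        # bias_nums = [8, 8, 1], i.e. num_gen_params = 169 dynamic
        # parameters predicted per instance.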
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.kernel_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        pred_pad_size = self.pred_kernel_size // 2
        self.rtm_kernel = nn.Conv2d(
            self.feat_channels,
            self.num_gen_params,
            self.pred_kernel_size,
            padding=pred_pad_size)
        self.mask_head = MaskFeatModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            stacked_convs=4,
            num_levels=len(self.prior_generator.strides),
            num_prototypes=self.num_prototypes,
            act_cfg=self.act_cfg,
            norm_cfg=self.norm_cfg)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox
            prediction

            - cls_scores (list[Tensor]): Classification scores for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
            - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
              levels, each is a 4D-tensor, the channels number is
              num_gen_params.
            - mask_feat (Tensor): Output feature of the mask head, a
              4D-tensor whose channels number is num_prototypes.
        """
        mask_feat = self.mask_head(feats)

        cls_scores = []
        bbox_preds = []
        kernel_preds = []
        for idx, (x, scale, stride) in enumerate(
                zip(feats, self.scales, self.prior_generator.strides)):
            cls_feat = x
            reg_feat = x
            kernel_feat = x

            for cls_layer in self.cls_convs:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls(cls_feat)

            for kernel_layer in self.kernel_convs:
                kernel_feat = kernel_layer(kernel_feat)
            kernel_pred = self.rtm_kernel(kernel_feat)

            for reg_layer in self.reg_convs:
                reg_feat = reg_layer(reg_feat)

            if self.with_objectness:
                objectness = self.rtm_obj(reg_feat)
                cls_score = inverse_sigmoid(
                    sigmoid_geometric_mean(cls_score, objectness))

            reg_dist = scale(self.rtm_reg(reg_feat)) * stride[0]

            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
            kernel_preds.append(kernel_pred)
        return tuple(cls_scores), tuple(bbox_preds), tuple(
            kernel_preds), mask_feat
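
    # Shape sketch (an illustration, assuming a 640x640 input, strides
    # (8, 16, 32) and the default head widths): forward() returns
    #   cls_scores:   [(N, num_classes, 80, 80), (N, num_classes, 40, 40),
    #                  (N, num_classes, 20, 20)]
    #   bbox_preds:   [(N, 4, 80, 80), (N, 4, 40, 40), (N, 4, 20, 20)]
    #   kernel_preds: [(N, 169, 80, 80), ...] with 169 = num_gen_params
    #   mask_feat:    (N, num_prototypes, 80, 80)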

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        kernel_preds: List[Tensor],
                        mask_feat: Tensor,
                        score_factors: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigType] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it to obtain the real scores used in NMS, such as the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            kernel_preds (list[Tensor]): Kernel predictions of dynamic
                convs for all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_params, H, W).
            mask_feat (Tensor): Mask prototype features extracted from the
                mask head, has shape (batch_size, num_prototypes, H, W).
            score_factors (list[Tensor], optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            batch_img_metas (list[dict], optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing configuration.
                If None, test_cfg would be used. Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before returning boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains the following
            keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, h, w).
        """
        assert len(cls_scores) == len(bbox_preds)

        if score_factors is None:
            # e.g. Retina, FreeAnchor, Foveabox, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
            with_score_factors = True
            assert len(cls_scores) == len(score_factors)

        num_levels = len(cls_scores)

        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            cls_score_list = select_single_mlvl(
                cls_scores, img_id, detach=True)
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            kernel_pred_list = select_single_mlvl(
                kernel_preds, img_id, detach=True)
            if with_score_factors:
                score_factor_list = select_single_mlvl(
                    score_factors, img_id, detach=True)
            else:
                score_factor_list = [None for _ in range(num_levels)]

            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                kernel_pred_list=kernel_pred_list,
                mask_feat=mask_feat[img_id],
                score_factor_list=score_factor_list,
                mlvl_priors=mlvl_priors,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                kernel_pred_list: List[Tensor],
                                mask_feat: Tensor,
                                score_factor_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigType,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox and mask results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            kernel_pred_list (list[Tensor]): Kernel predictions of dynamic
                convs from all scale levels of a single image, each item
                has shape (num_params, H, W).
            mask_feat (Tensor): Mask prototype features of a single image
                extracted from the mask head, has shape
                (num_prototypes, H, W).
            score_factor_list (list[Tensor]): Score factors from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration.
                If None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before returning boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image after the
            post process. Each item usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, h, w).
        """
        if score_factor_list[0] is None:
            # e.g. Retina, FreeAnchor, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, etc.
            with_score_factors = True

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bbox_preds = []
        mlvl_kernels = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_labels = []
        if with_score_factors:
            mlvl_score_factors = []
        else:
            mlvl_score_factors = None

        for level_idx, (cls_score, bbox_pred, kernel_pred,
                        score_factor, priors) in \
                enumerate(zip(cls_score_list, bbox_pred_list,
                              kernel_pred_list, score_factor_list,
                              mlvl_priors)):

            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            dim = self.bbox_coder.encode_size
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
            if with_score_factors:
                score_factor = score_factor.permute(1, 2,
                                                    0).reshape(-1).sigmoid()
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            kernel_pred = kernel_pred.permute(1, 2, 0).reshape(
                -1, self.num_gen_params)

            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                scores = cls_score.softmax(-1)[:, :-1]

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            score_thr = cfg.get('score_thr', 0)

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(
                    bbox_pred=bbox_pred,
                    priors=priors,
                    kernel_pred=kernel_pred))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']
            kernel_pred = filtered_results['kernel_pred']

            if with_score_factors:
                score_factor = score_factor[keep_idxs]

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)
            mlvl_kernels.append(kernel_pred)

            if with_score_factors:
                mlvl_score_factors.append(score_factor)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_priors)
        bboxes = self.bbox_coder.decode(
            priors[..., :2], bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.priors = priors
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        results.kernels = torch.cat(mlvl_kernels)
        if with_score_factors:
            results.score_factors = torch.cat(mlvl_score_factors)

        return self._bbox_mask_post_process(
            results=results,
            mask_feat=mask_feat,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

    def _bbox_mask_post_process(
            self,
            results: InstanceData,
            mask_feat: Tensor,
            cfg: ConfigType,
            rescale: bool = False,
            with_nms: bool = True,
            img_meta: Optional[dict] = None) -> InstanceData:
        """bbox and mask post-processing method.

        The boxes are rescaled to the original image scale and passed
        through NMS. Usually `with_nms` is False when used for aug test.

        Args:
            results (:obj:`InstanceData`): Detection instance results,
                each item has shape (num_bboxes, ).
            mask_feat (Tensor): Mask prototype features of a single image,
                has shape (num_prototypes, H, W).
            cfg (ConfigDict): Test / postprocessing configuration.
                If None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before returning boxes.
                Defaults to True.
            img_meta (dict, optional): Image meta info. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image after the
            post process. Each item usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, h, w).
        """
        stride = self.prior_generator.strides[0][0]
        if rescale:
            assert img_meta.get('scale_factor') is not None
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            results.bboxes = scale_boxes(results.bboxes, scale_factor)

        if hasattr(results, 'score_factors'):
            # TODO: Add sqrt operation in order to be consistent with
            #  the paper.
            score_factors = results.pop('score_factors')
            results.scores = results.scores * score_factors

        # filter small size bboxes
        if cfg.get('min_bbox_size', -1) >= 0:
            w, h = get_box_wh(results.bboxes)
            valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
            if not valid_mask.all():
                results = results[valid_mask]

        # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
        assert with_nms, 'with_nms must be True for RTMDet-Ins'
        if results.bboxes.numel() > 0:
            bboxes = get_box_tensor(results.bboxes)
            det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
                                                results.labels, cfg.nms)
            results = results[keep_idxs]
            # some nms operations may reweight the score, such as SoftNMS
            results.scores = det_bboxes[:, -1]
            results = results[:cfg.max_per_img]

            # process masks
            mask_logits = self._mask_predict_by_feat_single(
                mask_feat, results.kernels, results.priors)

            mask_logits = F.interpolate(
                mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
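            # mask_logits live on the first (finest) feature level's grid;
            # upsampling by `stride` (the stride of that level) brings them
            # back to the padded input-image resolution before thresholding.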
            if rescale:
                ori_h, ori_w = img_meta['ori_shape'][:2]
                mask_logits = F.interpolate(
                    mask_logits,
                    size=[
                        math.ceil(mask_logits.shape[-2] * scale_factor[0]),
                        math.ceil(mask_logits.shape[-1] * scale_factor[1])
                    ],
                    mode='bilinear',
                    align_corners=False)[..., :ori_h, :ori_w]
            masks = mask_logits.sigmoid().squeeze(0)
            masks = masks > cfg.mask_thr_binary
            results.masks = masks
        else:
            h, w = img_meta['ori_shape'][:2] if rescale else img_meta[
                'img_shape'][:2]
            results.masks = torch.zeros(
                size=(results.bboxes.shape[0], h, w),
                dtype=torch.bool,
                device=results.bboxes.device)
        return results

    def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple:
        """Split the flattened kernel head predictions into per-layer conv
        weights and biases."""
        n_inst = flatten_kernels.size(0)
        n_layers = len(self.weight_nums)
        params_splits = list(
            torch.split_with_sizes(
                flatten_kernels, self.weight_nums + self.bias_nums, dim=1))
        weight_splits = params_splits[:n_layers]
        bias_splits = params_splits[n_layers:]
        for i in range(n_layers):
            if i < n_layers - 1:
                weight_splits[i] = weight_splits[i].reshape(
                    n_inst * self.dyconv_channels, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(
                    n_inst * self.dyconv_channels)
            else:
                weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(n_inst)

        return weight_splits, bias_splits
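
    # For example (with the default 169-dim kernels and n_inst instances):
    # flatten_kernels (n_inst, 169) is split into chunks of sizes
    # [80, 64, 8] (weights) and [8, 8, 1] (biases), then reshaped into
    # 1x1 grouped-conv parameters, e.g. the first weight becomes
    # (n_inst * 8, 10, 1, 1) so that each instance convolves its own
    # 10-channel input (8 prototypes + 2 coordinates).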

    def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
                                     priors: Tensor) -> Tensor:
        """Generate mask logits from mask features with dynamic convs.

        Args:
            mask_feat (Tensor): Mask prototype features.
                Has shape (num_prototypes, H, W).
            kernels (Tensor): Kernel parameters for each instance.
                Has shape (num_instance, num_params).
            priors (Tensor): Center priors for each instance.
                Has shape (num_instance, 4).

        Returns:
            Tensor: Instance segmentation masks for each instance.
                Has shape (num_instance, H, W).
        """
        num_inst = priors.shape[0]
        h, w = mask_feat.size()[-2:]
        if num_inst < 1:
            return torch.empty(
                size=(num_inst, h, w),
                dtype=mask_feat.dtype,
                device=mask_feat.device)
        if len(mask_feat.shape) < 4:
            mask_feat = mask_feat.unsqueeze(0)
        coord = self.prior_generator.single_level_grid_priors(
            (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
        points = priors[:, :2].reshape(-1, 1, 2)
        strides = priors[:, 2:].reshape(-1, 1, 2)
        relative_coord = (points - coord).permute(0, 2, 1) / (
            strides[..., 0].reshape(-1, 1, 1) * 8)
        relative_coord = relative_coord.reshape(num_inst, 2, h, w)
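        # relative_coord holds, for every instance, the (x, y) offsets of
        # each feature-map location from the instance's prior center,
        # normalized by 8x the prior stride; it serves as the position cue
        # that the dynamic convs consume alongside the prototypes.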
        mask_feat = torch.cat(
            [relative_coord,
             mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
        weights, biases = self.parse_dynamic_params(kernels)

        n_layers = len(weights)
        x = mask_feat.reshape(1, -1, h, w)
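        # A single grouped 1x1 conv with groups=num_inst applies each
        # instance's private dynamic conv to its own
        # (num_prototypes + 2)-channel slice in one batched call.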
        for i, (weight, bias) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
            if i < n_layers - 1:
                x = F.relu(x)
        x = x.reshape(num_inst, h, w)
        return x

    def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
                          sampling_results_list: list,
                          batch_gt_instances: InstanceList) -> Tensor:
        """Compute instance segmentation loss.

        Args:
            mask_feats (Tensor): Mask prototype features extracted from the
                mask head. Has shape (N, num_prototypes, H, W).
            flatten_kernels (Tensor): Kernels of the dynamic conv layers.
                Has shape (N, num_instances, num_params).
            sampling_results_list (list[:obj:`SamplingResults`]): Batch of
                assignment results.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            Tensor: The mask loss tensor.
        """
        batch_pos_mask_logits = []
        pos_gt_masks = []
        for idx, (mask_feat, kernels, sampling_results,
                  gt_instances) in enumerate(
                      zip(mask_feats, flatten_kernels, sampling_results_list,
                          batch_gt_instances)):
            pos_priors = sampling_results.pos_priors
            pos_inds = sampling_results.pos_inds
            pos_kernels = kernels[pos_inds]  # (n_pos, num_gen_params)
            pos_mask_logits = self._mask_predict_by_feat_single(
                mask_feat, pos_kernels, pos_priors)
            if gt_instances.masks.numel() == 0:
                gt_masks = torch.empty_like(gt_instances.masks)
            else:
                gt_masks = gt_instances.masks[
                    sampling_results.pos_assigned_gt_inds, :]
            batch_pos_mask_logits.append(pos_mask_logits)
            pos_gt_masks.append(gt_masks)

        pos_gt_masks = torch.cat(pos_gt_masks, 0)
        batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)

        # avg_factor
        num_pos = batch_pos_mask_logits.shape[0]
        num_pos = reduce_mean(mask_feats.new_tensor([num_pos
                                                     ])).clamp_(min=1).item()

        if batch_pos_mask_logits.shape[0] == 0:
            return mask_feats.sum() * 0

        scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
        # upsample pred masks
        batch_pos_mask_logits = F.interpolate(
            batch_pos_mask_logits.unsqueeze(0),
            scale_factor=scale,
            mode='bilinear',
            align_corners=False).squeeze(0)
        # downsample gt masks
        pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
                                    2::self.mask_loss_stride,
                                    self.mask_loss_stride //
                                    2::self.mask_loss_stride]
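        # Both sides now live on a stride-4 grid (with the defaults):
        # predictions are upsampled from stride 8 by a factor of 2, while
        # gt masks are subsampled every mask_loss_stride pixels starting at
        # offset mask_loss_stride // 2, matching the pixel centers of the
        # upsampled predictions.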
        loss_mask = self.loss_mask(
            batch_pos_mask_logits,
            pos_gt_masks,
            weight=None,
            avg_factor=num_pos)

        return loss_mask

    def loss_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     kernel_preds: List[Tensor],
                     mask_feat: Tensor,
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict],
                     batch_gt_instances_ignore: OptInstanceList = None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Decoded box for each scale
                level with shape (N, num_anchors * 4, H, W) in
                [tl_x, tl_y, br_x, br_y] format.
            kernel_preds (list[Tensor]): Kernel predictions of dynamic
                convs for each scale level. Has shape
                (N, num_gen_params, H, W).
            mask_feat (Tensor): Mask prototype features extracted from the
                mask head. Has shape (N, num_prototypes, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1)
        flatten_kernels = torch.cat([
            kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                    self.num_gen_params)
            for kernel_pred in kernel_preds
        ], 1)
        decoded_bboxes = []
        for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
            anchor = anchor.reshape(-1, 4)
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            bbox_pred = distance2bbox(anchor, bbox_pred)
            decoded_bboxes.append(bbox_pred)

        flatten_bboxes = torch.cat(decoded_bboxes, 1)
        for gt_instances in batch_gt_instances:
            gt_instances.masks = gt_instances.masks.to_tensor(
                dtype=torch.bool, device=device)

        cls_reg_targets = self.get_targets(
            flatten_cls_scores,
            flatten_bboxes,
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         assign_metrics_list, sampling_results_list) = cls_reg_targets

        losses_cls, losses_bbox, \
            cls_avg_factors, bbox_avg_factors = multi_apply(
                self.loss_by_feat_single,
                cls_scores,
                decoded_bboxes,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                assign_metrics_list,
                self.prior_generator.strides)

        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
        losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))

        bbox_avg_factor = reduce_mean(
            sum(bbox_avg_factors)).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))

        loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
                                           sampling_results_list,
                                           batch_gt_instances)
        loss = dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)
        return loss


class MaskFeatModule(BaseModule):
    """Mask feature head used in RTMDet-Ins.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        num_levels (int): Number of feature map levels that are fused to
            predict the mask feature map.
        num_prototypes (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask feature map
            that is to be dynamically convolved with the predicted kernel.
        stacked_convs (int): Number of convs in mask feature branch.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Config dict for normalization layer. Default: None.
    """

    def __init__(
        self,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        num_levels: int = 3,
        num_prototypes: int = 8,
        act_cfg: ConfigType = dict(type='ReLU', inplace=True),
        norm_cfg: ConfigType = dict(type='BN')
    ) -> None:
        super().__init__(init_cfg=None)
        self.num_levels = num_levels
        self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1)
        convs = []
        for i in range(stacked_convs):
            in_c = in_channels if i == 0 else feat_channels
            convs.append(
                ConvModule(
                    in_c,
                    feat_channels,
                    3,
                    padding=1,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg))
        self.stacked_convs = nn.Sequential(*convs)
        self.projection = nn.Conv2d(
            feat_channels, num_prototypes, kernel_size=1)

    def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
        # multi-level feature fusion: upsample all levels to the size of
        # the first (finest) level and concatenate along channels
        fusion_feats = [features[0]]
        size = features[0].shape[-2:]
        for i in range(1, self.num_levels):
            f = F.interpolate(features[i], size=size, mode='bilinear')
            fusion_feats.append(f)
        fusion_feats = torch.cat(fusion_feats, dim=1)
        fusion_feats = self.fusion_conv(fusion_feats)
        # pred mask feats
        mask_features = self.stacked_convs(fusion_feats)
        mask_features = self.projection(mask_features)
        return mask_features


@MODELS.register_module()
class RTMDetInsSepBNHead(RTMDetInsHead):
    """Detection Head of RTMDet-Ins with sep-bn layers.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        share_conv (bool): Whether to share conv layers between stages.
            Defaults to True.
        with_objectness (bool): Whether to add an objectness branch.
            Defaults to False.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Defaults to dict(type='SiLU', inplace=True).
        pred_kernel_size (int): Kernel size of prediction layer.
            Defaults to 1.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 share_conv: bool = True,
                 with_objectness: bool = False,
                 norm_cfg: ConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 pred_kernel_size: int = 1,
                 **kwargs) -> None:
        self.share_conv = share_conv
        super().__init__(
            num_classes,
            in_channels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            pred_kernel_size=pred_kernel_size,
            with_objectness=with_objectness,
            **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()

        self.rtm_cls = nn.ModuleList()
        self.rtm_reg = nn.ModuleList()
        self.rtm_kernel = nn.ModuleList()
        self.rtm_obj = nn.ModuleList()

        # calculate num dynamic parameters
        weight_nums, bias_nums = [], []
        for i in range(self.num_dyconvs):
            if i == 0:
                weight_nums.append(
                    (self.num_prototypes + 2) * self.dyconv_channels)
                bias_nums.append(self.dyconv_channels)
            elif i == self.num_dyconvs - 1:
                weight_nums.append(self.dyconv_channels)
                bias_nums.append(1)
            else:
                weight_nums.append(self.dyconv_channels * self.dyconv_channels)
                bias_nums.append(self.dyconv_channels)
        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        self.num_gen_params = sum(weight_nums) + sum(bias_nums)
        pred_pad_size = self.pred_kernel_size // 2

        for n in range(len(self.prior_generator.strides)):
            cls_convs = nn.ModuleList()
            reg_convs = nn.ModuleList()
            kernel_convs = nn.ModuleList()
            for i in range(self.stacked_convs):
                chn = self.in_channels if i == 0 else self.feat_channels
                cls_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                reg_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                kernel_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.cls_convs.append(cls_convs)
            self.reg_convs.append(reg_convs)
            self.kernel_convs.append(kernel_convs)

            self.rtm_cls.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * self.cls_out_channels,
                    self.pred_kernel_size,
                    padding=pred_pad_size))
            self.rtm_reg.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * 4,
                    self.pred_kernel_size,
                    padding=pred_pad_size))
            self.rtm_kernel.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_gen_params,
                    self.pred_kernel_size,
                    padding=pred_pad_size))
            if self.with_objectness:
                self.rtm_obj.append(
                    nn.Conv2d(
                        self.feat_channels,
                        1,
                        self.pred_kernel_size,
                        padding=pred_pad_size))

        if self.share_conv:
            for n in range(len(self.prior_generator.strides)):
                for i in range(self.stacked_convs):
                    self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
                    self.reg_convs[n][i].conv = self.reg_convs[0][i].conv

        self.mask_head = MaskFeatModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            stacked_convs=4,
            num_levels=len(self.prior_generator.strides),
            num_prototypes=self.num_prototypes,
            act_cfg=self.act_cfg,
            norm_cfg=self.norm_cfg)

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)
        bias_cls = bias_init_with_prob(0.01)
        for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg,
                                                self.rtm_kernel):
            normal_init(rtm_cls, std=0.01, bias=bias_cls)
            normal_init(rtm_reg, std=0.01, bias=1)
        if self.with_objectness:
            for rtm_obj in self.rtm_obj:
                normal_init(rtm_obj, std=0.01, bias=bias_cls)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox
            prediction

            - cls_scores (list[Tensor]): Classification scores for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
            - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
              levels, each is a 4D-tensor, the channels number is
              num_gen_params.
            - mask_feat (Tensor): Output feature of the mask head, a
              4D-tensor whose channels number is num_prototypes.
        """
        mask_feat = self.mask_head(feats)

        cls_scores = []
        bbox_preds = []
        kernel_preds = []
        for idx, (x, stride) in enumerate(
                zip(feats, self.prior_generator.strides)):
            cls_feat = x
            reg_feat = x
            kernel_feat = x

            for cls_layer in self.cls_convs[idx]:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls[idx](cls_feat)

            for kernel_layer in self.kernel_convs[idx]:
                kernel_feat = kernel_layer(kernel_feat)
            kernel_pred = self.rtm_kernel[idx](kernel_feat)

            for reg_layer in self.reg_convs[idx]:
                reg_feat = reg_layer(reg_feat)

            if self.with_objectness:
                objectness = self.rtm_obj[idx](reg_feat)
                cls_score = inverse_sigmoid(
                    sigmoid_geometric_mean(cls_score, objectness))

            reg_dist = F.relu(self.rtm_reg[idx](reg_feat)) * stride[0]

            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
            kernel_preds.append(kernel_pred)
        return tuple(cls_scores), tuple(bbox_preds), tuple(
            kernel_preds), mask_feat
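

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). The loss / coder settings below
# mirror the public RTMDet-Ins configs but are assumptions here, not part of
# this file:
#
#   import torch
#   from mmdet.utils import register_all_modules
#
#   register_all_modules()
#   head = RTMDetInsSepBNHead(
#       num_classes=80,
#       in_channels=256,
#       stacked_convs=2,
#       feat_channels=256,
#       anchor_generator=dict(
#           type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
#       bbox_coder=dict(type='DistancePointBBoxCoder'),
#       loss_cls=dict(
#           type='QualityFocalLoss',
#           use_sigmoid=True,
#           beta=2.0,
#           loss_weight=1.0),
#       loss_bbox=dict(type='GIoULoss', loss_weight=2.0))
#
#   # three FPN levels for a 640x640 input
#   feats = tuple(
#       torch.rand(1, 256, 640 // s, 640 // s) for s in (8, 16, 32))
#   cls_scores, bbox_preds, kernel_preds, mask_feat = head(feats)
# ---------------------------------------------------------------------------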