# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.config import ConfigDict
from mmengine.model import bias_init_with_prob
from torch import Tensor

from mmdet.models.losses import accuracy
from mmdet.models.task_modules import SamplingResult
from mmdet.models.utils import multi_apply
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, reduce_mean
from .bbox_head import BBoxHead

@MODELS.register_module()
class DIIHead(BBoxHead):
    r"""Dynamic Instance Interactive Head for `Sparse R-CNN: End-to-End
    Object Detection with Learnable Proposals
    <https://arxiv.org/abs/2011.12450>`_

    Args:
        num_classes (int): Number of classes in the dataset.
            Defaults to 80.
        num_ffn_fcs (int): The number of fully-connected
            layers in FFNs. Defaults to 2.
        num_heads (int): The number of attention heads in
            MultiheadAttention. Defaults to 8.
        num_cls_fcs (int): The number of fully-connected
            layers in the classification subnet. Defaults to 1.
        num_reg_fcs (int): The number of fully-connected
            layers in the regression subnet. Defaults to 3.
        feedforward_channels (int): The hidden dimension
            of FFNs. Defaults to 2048.
        in_channels (int): Hidden channels of MultiheadAttention.
            Defaults to 256.
        dropout (float): Probability of dropping a channel.
            Defaults to 0.0.
        ffn_act_cfg (:obj:`ConfigDict` or dict): The activation config
            for FFNs.
        dynamic_conv_cfg (:obj:`ConfigDict` or dict): The convolution
            config for DynamicConv.
        loss_iou (:obj:`ConfigDict` or dict): The config for IoU or
            GIoU loss.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 num_classes: int = 80,
                 num_ffn_fcs: int = 2,
                 num_heads: int = 8,
                 num_cls_fcs: int = 1,
                 num_reg_fcs: int = 3,
                 feedforward_channels: int = 2048,
                 in_channels: int = 256,
                 dropout: float = 0.0,
                 ffn_act_cfg: ConfigType = dict(type='ReLU', inplace=True),
                 dynamic_conv_cfg: ConfigType = dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     input_feat_shape=7,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN')),
                 loss_iou: ConfigType = dict(type='GIoULoss', loss_weight=2.0),
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(
            num_classes=num_classes,
            reg_decoded_bbox=True,
            reg_class_agnostic=True,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_iou = MODELS.build(loss_iou)
        self.in_channels = in_channels
        self.fp16_enabled = False
        self.attention = MultiheadAttention(in_channels, num_heads, dropout)
        self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.instance_interactive_conv = MODELS.build(dynamic_conv_cfg)
        self.instance_interactive_conv_dropout = nn.Dropout(dropout)
        self.instance_interactive_conv_norm = build_norm_layer(
            dict(type='LN'), in_channels)[1]

        self.ffn = FFN(
            in_channels,
            feedforward_channels,
            num_ffn_fcs,
            act_cfg=ffn_act_cfg,
            dropout=dropout)
        self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.cls_fcs = nn.ModuleList()
        for _ in range(num_cls_fcs):
            self.cls_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.cls_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.cls_fcs.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))
        # overload the self.fc_cls in BBoxHead
        if self.loss_cls.use_sigmoid:
            self.fc_cls = nn.Linear(in_channels, self.num_classes)
        else:
            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)

        self.reg_fcs = nn.ModuleList()
        for _ in range(num_reg_fcs):
            self.reg_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.reg_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.reg_fcs.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))
        # overload the self.fc_reg in BBoxHead
        self.fc_reg = nn.Linear(in_channels, 4)

        assert self.reg_class_agnostic, 'DIIHead only ' \
                                        'supports `reg_class_agnostic=True`'
        assert self.reg_decoded_bbox, 'DIIHead only ' \
                                      'supports `reg_decoded_bbox=True`'

    def init_weights(self) -> None:
        """Use xavier initialization for all weight parameters and set the
        classification head bias to a specific value when using focal loss."""
        super().init_weights()
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                # adopt the default initialization for
                # the weight and bias of the layer norm
                pass
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)
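
    # A quick worked example of the bias init above (hypothetical numbers):
    # bias_init_with_prob(0.01) solves sigmoid(b) = 0.01 for the bias, i.e.
    # b = -log((1 - 0.01) / 0.01) ≈ -4.595, so the sigmoid-activated class
    # logits start out predicting ~1% foreground probability.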

    def forward(self, roi_feat: Tensor, proposal_feat: Tensor) -> tuple:
        """Forward function of the Dynamic Instance Interactive Head.

        Args:
            roi_feat (Tensor): RoI-pooling features with shape
                (batch_size*num_proposals, feature_dimensions,
                pooling_h, pooling_w).
            proposal_feat (Tensor): Intermediate feature obtained from the
                DIIHead of the last stage, has shape
                (batch_size, num_proposals, feature_dimensions).

        Returns:
            tuple[Tensor]: Usually a tuple of classification scores,
            bbox predictions and an intermediate feature.

            - cls_scores (Tensor): Classification scores for
              all proposals, has shape
              (batch_size, num_proposals, num_classes).
            - bbox_preds (Tensor): Box energies / deltas for
              all proposals, has shape
              (batch_size, num_proposals, 4).
            - obj_feat (Tensor): Object feature before the classification
              and regression subnets, has shape
              (batch_size, num_proposals, feature_dimensions).
            - attn_feats (Tensor): Intermediate feature.
        """
        N, num_proposals = proposal_feat.shape[:2]

        # Self attention
        proposal_feat = proposal_feat.permute(1, 0, 2)
        proposal_feat = self.attention_norm(self.attention(proposal_feat))
        attn_feats = proposal_feat.permute(1, 0, 2)

        # instance interactive
        proposal_feat = attn_feats.reshape(-1, self.in_channels)
        proposal_feat_iic = self.instance_interactive_conv(
            proposal_feat, roi_feat)
        proposal_feat = proposal_feat + self.instance_interactive_conv_dropout(
            proposal_feat_iic)
        obj_feat = self.instance_interactive_conv_norm(proposal_feat)

        # FFN
        obj_feat = self.ffn_norm(self.ffn(obj_feat))

        cls_feat = obj_feat
        reg_feat = obj_feat

        for cls_layer in self.cls_fcs:
            cls_feat = cls_layer(cls_feat)
        for reg_layer in self.reg_fcs:
            reg_feat = reg_layer(reg_feat)

        cls_score = self.fc_cls(cls_feat).view(
            N, num_proposals, self.num_classes
            if self.loss_cls.use_sigmoid else self.num_classes + 1)
        bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, 4)

        return cls_score, bbox_delta, obj_feat.view(
            N, num_proposals, self.in_channels), attn_feats
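
    # A minimal shape walk-through of `forward` (hypothetical sizes, assuming
    # the defaults in_channels=256 and a 7x7 RoI feature map):
    #
    #   roi_feat:      (batch*num_props, 256, 7, 7), e.g. (200, 256, 7, 7)
    #   proposal_feat: (batch, num_props, 256),      e.g. (2, 100, 256)
    #   cls_score:     (2, 100, 81) with the default cross-entropy loss
    #                  (num_classes + 1), or (2, 100, 80) with a sigmoid loss
    #   bbox_delta:    (2, 100, 4)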

    def loss_and_target(self,
                        cls_score: Tensor,
                        bbox_pred: Tensor,
                        sampling_results: List[SamplingResult],
                        rcnn_train_cfg: ConfigType,
                        imgs_whwh: Tensor,
                        concat: bool = True,
                        reduction_override: str = None) -> dict:
        """Calculate the loss based on the features extracted by the DIIHead.

        Args:
            cls_score (Tensor): Classification prediction
                results of all classes, has shape
                (batch_size * num_proposals_single_image, num_classes).
            bbox_pred (Tensor): Regression prediction results, has shape
                (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            imgs_whwh (Tensor): Tensor with shape
                (batch_size, num_proposals, 4), the last dimension means
                [img_width, img_height, img_width, img_height].
            concat (bool): Whether to concatenate the results of all
                the images in a single batch. Defaults to True.
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None.

        Returns:
            dict: A dictionary of loss and targets components.
            The targets are only used for cascade rcnn.
        """
        cls_reg_targets = self.get_targets(
            sampling_results=sampling_results,
            rcnn_train_cfg=rcnn_train_cfg,
            concat=concat)
        (labels, label_weights, bbox_targets, bbox_weights) = cls_reg_targets

        losses = dict()
        bg_class_ind = self.num_classes
        # note in sparse rcnn num_gt == num_pos
        pos_inds = (labels >= 0) & (labels < bg_class_ind)
        num_pos = pos_inds.sum().float()
        avg_factor = reduce_mean(num_pos)
        if cls_score is not None:
            if cls_score.numel() > 0:
                losses['loss_cls'] = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                losses['pos_acc'] = accuracy(cls_score[pos_inds],
                                             labels[pos_inds])
        if bbox_pred is not None:
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                pos_bbox_pred = bbox_pred.reshape(
                    bbox_pred.size(0), 4)[pos_inds.type(torch.bool)]
                imgs_whwh = imgs_whwh.reshape(
                    bbox_pred.size(0), 4)[pos_inds.type(torch.bool)]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred / imgs_whwh,
                    bbox_targets[pos_inds.type(torch.bool)] / imgs_whwh,
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=avg_factor)
                losses['loss_iou'] = self.loss_iou(
                    pos_bbox_pred,
                    bbox_targets[pos_inds.type(torch.bool)],
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=avg_factor)
            else:
                losses['loss_bbox'] = bbox_pred.sum() * 0
                losses['loss_iou'] = bbox_pred.sum() * 0
        return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)
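
    # A hedged note on `avg_factor` above: `reduce_mean(num_pos)` averages the
    # positive-sample count across ranks under distributed training, so every
    # GPU normalizes its loss by the same factor; e.g. with two ranks seeing
    # 6 and 10 positives, both divide by (6 + 10) / 2 = 8.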

    def _get_targets_single(self, pos_inds: Tensor, neg_inds: Tensor,
                            pos_priors: Tensor, neg_priors: Tensor,
                            pos_gt_bboxes: Tensor, pos_gt_labels: Tensor,
                            cfg: ConfigDict) -> tuple:
        """Calculate the ground truth for proposals in a single image
        according to the sampling results.

        Almost the same as the implementation in `bbox_head`; we add
        pos_inds and neg_inds to select positive and negative samples
        instead of taking the first num_pos proposals as positive samples.

        Args:
            pos_inds (Tensor): Indices of the positive samples in the
                original proposal set; its length equals the number of
                positive samples.
            neg_inds (Tensor): Indices of the negative samples in the
                original proposal set; its length equals the number of
                negative samples.
            pos_priors (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_priors (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains gt_boxes for
                all positive samples, has shape (num_pos, 4),
                the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains gt_labels for
                all positive samples, has shape (num_pos, ).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a single image.
            Containing the following Tensors:

            - labels (Tensor): Gt_labels for all proposals, has
              shape (num_proposals,).
            - label_weights (Tensor): Label weights for all proposals, has
              shape (num_proposals,).
            - bbox_targets (Tensor): Regression targets for all proposals,
              has shape (num_proposals, 4), the last dimension 4
              represents [tl_x, tl_y, br_x, br_y].
            - bbox_weights (Tensor): Regression weights for all proposals,
              has shape (num_proposals, 4).
        """
        num_pos = pos_priors.size(0)
        num_neg = neg_priors.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_priors.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = pos_priors.new_zeros(num_samples)
        bbox_targets = pos_priors.new_zeros(num_samples, 4)
        bbox_weights = pos_priors.new_zeros(num_samples, 4)
        if num_pos > 0:
            labels[pos_inds] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[pos_inds] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_priors, pos_gt_bboxes)
            else:
                pos_bbox_targets = pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1
        if num_neg > 0:
            label_weights[neg_inds] = 1.0
        return labels, label_weights, bbox_targets, bbox_weights
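
    # A tiny worked example (hypothetical numbers): with num_classes=80,
    # 5 proposals, pos_inds=[0, 3] and pos_gt_labels=[12, 7], this yields
    #   labels       = [12, 80, 80, 7, 80]   (80 marks background)
    #   bbox_weights = 1 on rows 0 and 3, 0 elsewhere.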

    def get_targets(self,
                    sampling_results: List[SamplingResult],
                    rcnn_train_cfg: ConfigDict,
                    concat: bool = True) -> tuple:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Almost the same as the implementation in bbox_head; we pass the
        additional parameters pos_inds_list and neg_inds_list to the
        `_get_targets_single` function.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

            - labels (list[Tensor], Tensor): Gt_labels for all
              proposals in a batch, each tensor in the list has
              shape (num_proposals,) when `concat=False`, otherwise just
              a single tensor of shape (num_all_proposals,).
            - label_weights (list[Tensor]): Label weights for
              all proposals in a batch, each tensor in the list has shape
              (num_proposals,) when `concat=False`, otherwise just a
              single tensor of shape (num_all_proposals,).
            - bbox_targets (list[Tensor], Tensor): Regression targets
              for all proposals in a batch, each tensor in the list has
              shape (num_proposals, 4) when `concat=False`, otherwise
              just a single tensor of shape (num_all_proposals, 4),
              the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
            - bbox_weights (list[Tensor], Tensor): Regression weights for
              all proposals in a batch, each tensor in the list has shape
              (num_proposals, 4) when `concat=False`, otherwise just a
              single tensor of shape (num_all_proposals, 4).
        """
        pos_inds_list = [res.pos_inds for res in sampling_results]
        neg_inds_list = [res.neg_inds for res in sampling_results]
        pos_priors_list = [res.pos_priors for res in sampling_results]
        neg_priors_list = [res.neg_priors for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_targets_single,
            pos_inds_list,
            neg_inds_list,
            pos_priors_list,
            neg_priors_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)
        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights
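

if __name__ == '__main__':
    # A minimal forward smoke test, not part of the original module: a sketch
    # assuming mmdet and its model registry are importable, so that
    # `DynamicConv` and `GIoULoss` can be built. All sizes are hypothetical.
    head = DIIHead(num_classes=80)
    head.init_weights()

    batch_size, num_proposals = 2, 100
    roi_feat = torch.rand(batch_size * num_proposals, 256, 7, 7)
    proposal_feat = torch.rand(batch_size, num_proposals, 256)

    cls_score, bbox_delta, obj_feat, attn_feats = head(
        roi_feat, proposal_feat)
    # With the default cross-entropy loss (use_sigmoid=False), the class
    # dimension is num_classes + 1 to include the background class.
    assert cls_score.shape == (batch_size, num_proposals, 81)
    assert bbox_delta.shape == (batch_size, num_proposals, 4)
    assert obj_feat.shape == (batch_size, num_proposals, 256)
    assert attn_feats.shape == (batch_size, num_proposals, 256)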