sparse_roi_head.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Tuple
  3. import torch
  4. from mmengine.config import ConfigDict
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.models.task_modules.samplers import PseudoSampler
  8. from mmdet.registry import MODELS
  9. from mmdet.structures import SampleList
  10. from mmdet.structures.bbox import bbox2roi
  11. from mmdet.utils import ConfigType, InstanceList, OptConfigType
  12. from ..utils.misc import empty_instances, unpack_gt_instances
  13. from .cascade_roi_head import CascadeRoIHead
  @MODELS.register_module()
  class SparseRoIHead(CascadeRoIHead):
      r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with
      Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
      and `Instances as Queries <http://arxiv.org/abs/2105.01928>`_

      Args:
          num_stages (int): Number of stages in the whole iterative process.
              Defaults to 6.
          stage_loss_weights (Tuple[float]): The loss
              weight of each stage. Its length must equal ``num_stages``.
              By default all stages have the same weight 1.
          proposal_feature_channel (int): Stored unchanged as
              ``self.proposal_feature_channel``; presumably the channel
              number of the proposal features (confirm against the
              proposal generator). Defaults to 256.
          bbox_roi_extractor (:obj:`ConfigDict` or dict): Config of box
              roi extractor.
          mask_roi_extractor (:obj:`ConfigDict` or dict): Config of mask
              roi extractor.
          bbox_head (:obj:`ConfigDict` or dict): Config of box head.
          mask_head (:obj:`ConfigDict` or dict): Config of mask head.
          train_cfg (:obj:`ConfigDict` or dict, Optional): Configuration
              information in train stage. Defaults to None.
          test_cfg (:obj:`ConfigDict` or dict, Optional): Configuration
              information in test stage. Defaults to None.
          init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
              dict]): Initialization config dict. Defaults to None.
      """
  def __init__(self,
               num_stages: int = 6,
               stage_loss_weights: Tuple[float] = (1, 1, 1, 1, 1, 1),
               proposal_feature_channel: int = 256,
               bbox_roi_extractor: ConfigType = dict(
                   type='SingleRoIExtractor',
                   roi_layer=dict(
                       type='RoIAlign', output_size=7, sampling_ratio=2),
                   out_channels=256,
                   featmap_strides=[4, 8, 16, 32]),
               mask_roi_extractor: OptConfigType = None,
               bbox_head: ConfigType = dict(
                   type='DIIHead',
                   num_classes=80,
                   num_fcs=2,
                   num_heads=8,
                   num_cls_fcs=1,
                   num_reg_fcs=3,
                   feedforward_channels=2048,
                   hidden_channels=256,
                   dropout=0.0,
                   roi_feat_size=7,
                   ffn_act_cfg=dict(type='ReLU', inplace=True)),
               mask_head: OptConfigType = None,
               train_cfg: OptConfigType = None,
               test_cfg: OptConfigType = None,
               init_cfg: OptConfigType = None) -> None:
      """Check the stage settings and hand the configs to the parent head.

      The arguments are described in the class docstring. The box RoI
      extractor and box head are mandatory, and ``stage_loss_weights`` must
      provide exactly one weight per stage.
      """
      assert bbox_roi_extractor is not None
      assert bbox_head is not None
      assert len(stage_loss_weights) == num_stages

      # These plain attributes are recorded before the parent constructor
      # runs, exactly as in the original implementation.
      self.proposal_feature_channel = proposal_feature_channel
      self.stage_loss_weights = stage_loss_weights
      self.num_stages = num_stages

      super().__init__(
          num_stages=num_stages,
          stage_loss_weights=stage_loss_weights,
          bbox_roi_extractor=bbox_roi_extractor,
          mask_roi_extractor=mask_roi_extractor,
          bbox_head=bbox_head,
          mask_head=mask_head,
          train_cfg=train_cfg,
          test_cfg=test_cfg,
          init_cfg=init_cfg)

      if train_cfg is None:
          # train_cfg would be None when run the test.py, so there is no
          # sampler to validate.
          return

      for stage_idx in range(num_stages):
          stage_sampler = self.bbox_sampler[stage_idx]
          assert isinstance(stage_sampler, PseudoSampler), \
              'Sparse R-CNN and QueryInst only support `PseudoSampler`'
  def bbox_loss(self, stage: int, x: Tuple[Tensor],
                results_list: InstanceList, object_feats: Tensor,
                batch_img_metas: List[dict],
                batch_gt_instances: InstanceList) -> dict:
      """Run one stage of the bbox head and compute its loss.

      Args:
          stage (int): The current stage in iterative process.
          x (tuple[Tensor]): List of multi-level img features.
          results_list (List[:obj:`InstanceData`]): List of region
              proposals. Each item is read for ``bboxes`` and ``imgs_whwh``.
          object_feats (Tensor): The object feature extracted from
              the previous stage.
          batch_img_metas (list[dict]): Meta information of each image.
          batch_gt_instances (list[:obj:`InstanceData`]): Batch of
              gt_instance. It usually includes ``bboxes``, ``labels``, and
              ``masks`` attributes.

      Returns:
          dict: The output of ``_bbox_forward`` extended with:

          - `sampling_results` (list): Per-image sampling results.
          - the entries returned by the stage head's ``loss_and_target``
            (``loss`` reads ``loss_bbox`` from them).
          - `results_list` (list[:obj:`InstanceData`]): Per-image proposals
            for the next stage, carrying ``imgs_whwh`` and the detached
            refined ``bboxes``.
      """
      rois = bbox2roi([res.bboxes for res in results_list])
      bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                        batch_img_metas)
      imgs_whwh = torch.cat(
          [res.imgs_whwh[None, ...] for res in results_list])

      # Label assignment works on detached predictions.
      detached_scores = bbox_results['detached_cls_scores']
      detached_boxes = bbox_results['detached_proposals']
      sampling_results = []
      bbox_head = self.bbox_head[stage]
      for img_idx, img_meta in enumerate(batch_img_metas):
          # TODO: Enhance the logic
          # `bboxes` is consumed by the assigner, `priors` by the sampler.
          pred_instances = InstanceData(
              bboxes=detached_boxes[img_idx],
              scores=detached_scores[img_idx],
              priors=detached_boxes[img_idx])
          assign_result = self.bbox_assigner[stage].assign(
              pred_instances=pred_instances,
              gt_instances=batch_gt_instances[img_idx],
              gt_instances_ignore=None,
              img_meta=img_meta)
          sampling_results.append(self.bbox_sampler[stage].sample(
              assign_result, pred_instances, batch_gt_instances[img_idx]))
      bbox_results['sampling_results'] = sampling_results

      # Flatten the batch dimension before handing over to the loss.
      flat_cls_score = bbox_results['cls_score']
      flat_cls_score = flat_cls_score.view(-1, flat_cls_score.size(-1))
      flat_bboxes = bbox_results['decoded_bboxes'].view(-1, 4)
      bbox_results.update(
          bbox_head.loss_and_target(
              flat_cls_score,
              flat_bboxes,
              sampling_results,
              self.train_cfg[stage],
              imgs_whwh=imgs_whwh,
              concat=True))

      # propose for the new proposal_list
      bbox_results['results_list'] = [
          InstanceData(
              imgs_whwh=results_list[img_idx].imgs_whwh,
              bboxes=bbox_results['detached_proposals'][img_idx])
          for img_idx in range(len(batch_img_metas))
      ]
      return bbox_results
  def _bbox_forward(self, stage: int, x: Tuple[Tensor], rois: Tensor,
                    object_feats: Tensor,
                    batch_img_metas: List[dict]) -> dict:
      """Box head forward function used in both training and testing.

      Args:
          stage (int): The current stage in iterative process.
          x (tuple[Tensor]): List of multi-level img features.
          rois (Tensor): RoIs with the shape (n, 5) where the first
              column indicates batch id of each RoI.
              Each dimension means (img_index, x1, y1, x2, y2).
          object_feats (Tensor): The object feature extracted from
              the previous stage.
          batch_img_metas (list[dict]): Meta information of each image.

      Returns:
          dict: A dictionary of bbox head outputs with the keys:

          - cls_score (Tensor): Classification output of the stage head,
            indexed per image along the first dimension.
          - decoded_bboxes (Tensor): The per-image refined boxes
            concatenated along the first dimension.
          - object_feats (Tensor): The object feature produced by the
            current stage.
          - attn_feats (Tensor): The fourth output of the stage head.
          - detached_cls_scores (list[Tensor]): ``cls_score[i].detach()``
            for every image, used in label assignment.
          - detached_proposals (list[Tensor]): Detached refined boxes,
            one tensor per image.
      """
      num_imgs = len(batch_img_metas)
      roi_extractor = self.bbox_roi_extractor[stage]
      stage_head = self.bbox_head[stage]

      roi_feats = roi_extractor(x[:roi_extractor.num_inputs], rois)
      cls_score, bbox_pred, object_feats, attn_feats = stage_head(
          roi_feats, object_feats)

      # `refine_bboxes` takes bbox/sampling results as inputs, so
      # placeholder ("fake") ones are assembled from the raw predictions.
      fake_bbox_results = {
          'rois': rois,
          'bbox_targets': (rois.new_zeros(len(rois), dtype=torch.long),
                           None),
          'bbox_pred': bbox_pred.view(-1, bbox_pred.size(-1)),
          'cls_score': cls_score.view(-1, cls_score.size(-1)),
      }
      fake_sampling_results = [
          InstanceData(pos_is_gt=rois.new_zeros(object_feats.size(1)))
          for _ in range(len(batch_img_metas))
      ]
      refined_results = stage_head.refine_bboxes(
          sampling_results=fake_sampling_results,
          bbox_results=fake_bbox_results,
          batch_img_metas=batch_img_metas)
      refined_boxes = [res.bboxes for res in refined_results]

      return {
          'cls_score': cls_score,
          'decoded_bboxes': torch.cat(refined_boxes),
          'object_feats': object_feats,
          'attn_feats': attn_feats,
          # detach then use it in label assign
          'detached_cls_scores': [
              cls_score[img_idx].detach() for img_idx in range(num_imgs)
          ],
          'detached_proposals': [boxes.detach() for boxes in refined_boxes],
      }
  224. def _mask_forward(self, stage: int, x: Tuple[Tensor], rois: Tensor,
  225. attn_feats) -> dict:
  226. """Mask head forward function used in both training and testing.
  227. Args:
  228. stage (int): The current stage in Cascade RoI Head.
  229. x (tuple[Tensor]): Tuple of multi-level img features.
  230. rois (Tensor): RoIs with the shape (n, 5) where the first
  231. column indicates batch id of each RoI.
  232. attn_feats (Tensot): Intermediate feature get from the last
  233. diihead, has shape
  234. (batch_size*num_proposals, feature_dimensions)
  235. Returns:
  236. dict: Usually returns a dictionary with keys:
  237. - `mask_preds` (Tensor): Mask prediction.
  238. """
  239. mask_roi_extractor = self.mask_roi_extractor[stage]
  240. mask_head = self.mask_head[stage]
  241. mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
  242. rois)
  243. # do not support caffe_c4 model anymore
  244. mask_preds = mask_head(mask_feats, attn_feats)
  245. mask_results = dict(mask_preds=mask_preds)
  246. return mask_results
  def mask_loss(self, stage: int, x: Tuple[Tensor], bbox_results: dict,
                batch_gt_instances: InstanceList,
                rcnn_train_cfg: ConfigDict) -> dict:
      """Run forward function and calculate loss for mask head in training.

      Args:
          stage (int): The current stage in Cascade RoI Head.
          x (tuple[Tensor]): Tuple of multi-level img features.
          bbox_results (dict): Results obtained from `bbox_loss`. The keys
              ``attn_feats`` and ``sampling_results`` are read.
          batch_gt_instances (list[:obj:`InstanceData`]): Batch of
              gt_instance. It usually includes ``bboxes``, ``labels``, and
              ``masks`` attributes.
          rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

      Returns:
          dict: Usually returns a dictionary with keys:

          - `mask_preds` (Tensor): Mask prediction.
          - `loss_mask` (dict): A dictionary of mask loss components.
      """
      sampling_results = bbox_results['sampling_results']
      per_img_attn_feats = bbox_results['attn_feats']

      # Only positive samples take part in the mask branch.
      pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
      pos_attn_feats = torch.cat([
          img_feats[res.pos_inds]
          for img_feats, res in zip(per_img_attn_feats, sampling_results)
      ])

      mask_results = self._mask_forward(stage, x, pos_rois, pos_attn_feats)
      mask_results.update(self.mask_head[stage].loss_and_target(
          mask_preds=mask_results['mask_preds'],
          sampling_results=sampling_results,
          batch_gt_instances=batch_gt_instances,
          rcnn_train_cfg=rcnn_train_cfg))
      return mask_results
  def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
           batch_data_samples: SampleList) -> dict:
      """Perform forward propagation and loss calculation of the detection
      roi on the features of the upstream network.

      Args:
          x (tuple[Tensor]): List of multi-level img features.
          rpn_results_list (List[:obj:`InstanceData`]): List of region
              proposals. The ``features`` field is popped from every item.
          batch_data_samples (list[:obj:`DetDataSample`]): The batch
              data samples. It usually includes information such
              as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

      Returns:
          dict: a dictionary of loss components of all stage, keyed as
          ``s{stage}.{name}``.
      """
      batch_gt_instances, _, batch_img_metas = unpack_gt_instances(
          batch_data_samples)

      object_feats = torch.cat(
          [res.pop('features')[None, ...] for res in rpn_results_list])
      results_list = rpn_results_list
      losses = {}
      for stage in range(self.num_stages):
          weight = self.stage_loss_weights[stage]

          # bbox head forward and loss
          bbox_results = self.bbox_loss(
              stage=stage,
              x=x,
              object_feats=object_feats,
              results_list=results_list,
              batch_img_metas=batch_img_metas,
              batch_gt_instances=batch_gt_instances)
          # Only entries whose name contains 'loss' are scaled by the
          # stage weight; the others are stored as they are.
          losses.update({
              f's{stage}.{name}': value * weight if 'loss' in name else value
              for name, value in bbox_results['loss_bbox'].items()
          })

          if self.with_mask:
              mask_results = self.mask_loss(
                  stage=stage,
                  x=x,
                  bbox_results=bbox_results,
                  batch_gt_instances=batch_gt_instances,
                  rcnn_train_cfg=self.train_cfg[stage])
              losses.update({
                  f's{stage}.{name}':
                  value * weight if 'loss' in name else value
                  for name, value in mask_results['loss_mask'].items()
              })

          # Feed the refined state into the next stage.
          object_feats = bbox_results['object_feats']
          results_list = bbox_results['results_list']
      return losses
  def predict_bbox(self,
                   x: Tuple[Tensor],
                   batch_img_metas: List[dict],
                   rpn_results_list: InstanceList,
                   rcnn_test_cfg: ConfigType,
                   rescale: bool = False) -> InstanceList:
      """Perform forward propagation of the bbox head and predict detection
      results on the features of the upstream network.

      Args:
          x(tuple[Tensor]): Feature maps of all scale level.
          batch_img_metas (list[dict]): List of image information.
          rpn_results_list (list[:obj:`InstanceData`]): List of region
              proposals. The ``features`` field is popped from every item.
          rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. Its
              ``max_per_img`` bounds the number of detections per image.
              ``self.test_cfg`` is used when it is None.
          rescale (bool): If True, return boxes in original image space.
              Defaults to False.

      Returns:
          list[:obj:`InstanceData`]: Detection results of each image
          after the post process.
          Each item usually contains following keys.

          - scores (Tensor): Classification scores, has a shape
            (num_instance, )
          - labels (Tensor): Labels of bboxes, has a shape
            (num_instances, ).
          - bboxes (Tensor): Has a shape (num_instances, 4),
            the last dimension 4 arrange as (x1, y1, x2, y2).

          When the head has a mask branch, ``proposals``, ``topk_inds``
          and ``attn_feats`` are attached as well for ``predict_mask``.
      """
      proposal_list = [res.bboxes for res in rpn_results_list]
      object_feats = torch.cat(
          [res.pop('features')[None, ...] for res in rpn_results_list])
      if all(proposal.shape[0] == 0 for proposal in proposal_list):
          # There is no proposal in the whole batch
          return empty_instances(
              batch_img_metas, x[0].device, task_type='bbox')

      # Iteratively refine the proposals and object features.
      for stage in range(self.num_stages):
          rois = bbox2roi(proposal_list)
          bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                            batch_img_metas)
          object_feats = bbox_results['object_feats']
          cls_score = bbox_results['cls_score']
          proposal_list = bbox_results['detached_proposals']

      last_bbox_head = self.bbox_head[-1]
      num_classes = last_bbox_head.num_classes
      if last_bbox_head.loss_cls.use_sigmoid:
          cls_score = cls_score.sigmoid()
      else:
          # Drop the last column of the softmax output.
          cls_score = cls_score.softmax(-1)[..., :-1]

      # Honour the config passed by the caller; fall back to the head's own
      # test config so callers passing None keep working.
      test_cfg = self.test_cfg if rcnn_test_cfg is None else rcnn_test_cfg
      max_per_img = test_cfg.max_per_img

      topk_inds_list = []
      results_list = []
      for img_id, img_meta in enumerate(batch_img_metas):
          flat_scores = cls_score[img_id].flatten(0, 1)
          # `topk` raises when k exceeds the number of candidates, so k is
          # clamped to the number of (proposal, class) pairs.
          num_kept = min(max_per_img, flat_scores.numel())
          scores_per_img, topk_inds = flat_scores.topk(
              num_kept, sorted=False)
          # The flattened index encodes proposal * num_classes + label.
          labels_per_img = topk_inds % num_classes
          bboxes_per_img = proposal_list[img_id][topk_inds // num_classes]
          topk_inds_list.append(topk_inds)

          if rescale and bboxes_per_img.size(0) > 0:
              assert img_meta.get('scale_factor') is not None
              scale_factor = bboxes_per_img.new_tensor(
                  img_meta['scale_factor']).repeat((1, 2))
              bboxes_per_img = (
                  bboxes_per_img.view(bboxes_per_img.size(0), -1, 4) /
                  scale_factor).view(bboxes_per_img.size()[0], -1)

          results = InstanceData()
          results.bboxes = bboxes_per_img
          results.scores = scores_per_img
          results.labels = labels_per_img
          results_list.append(results)

      if self.with_mask:
          for img_id in range(len(batch_img_metas)):
              # add positive information in InstanceData to predict
              # mask results in `mask_head`.
              # NOTE(review): these fields are stored next to the top-k
              # detections; presumably their lengths must agree - confirm.
              proposals = bbox_results['detached_proposals'][img_id]
              topk_inds = topk_inds_list[img_id]
              attn_feats = bbox_results['attn_feats'][img_id]
              results_list[img_id].proposals = proposals
              results_list[img_id].topk_inds = topk_inds
              results_list[img_id].attn_feats = attn_feats
      return results_list
  def predict_mask(self,
                   x: Tuple[Tensor],
                   batch_img_metas: List[dict],
                   results_list: InstanceList,
                   rescale: bool = False) -> InstanceList:
      """Perform forward propagation of the mask head and predict detection
      results on the features of the upstream network.

      Args:
          x (tuple[Tensor]): Feature maps of all scale level.
          batch_img_metas (list[dict]): List of image information.
          results_list (list[:obj:`InstanceData`]): Detection results of
              each image. Each item usually contains following keys:

              - scores (Tensor): Classification scores, has a shape
                (num_instance, )
              - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
              - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
              - proposals (Tensor): Bboxes predicted from bbox_head,
                has a shape (num_instances, 4).
              - topk_inds (Tensor): Topk indices of each image, has
                shape (num_instances, )
              - attn_feats (Tensor): Intermediate feature get from the last
                diihead, has shape (num_instances, feature_dimensions)

              ``proposals``, ``topk_inds`` and ``attn_feats`` are popped
              from every item.
          rescale (bool): If True, return boxes in original image space.
              Defaults to False.

      Returns:
          list[:obj:`InstanceData`]: Detection results of each image
          after the post process.
          Each item usually contains following keys.

          - scores (Tensor): Classification scores, has a shape
            (num_instance, )
          - labels (Tensor): Labels of bboxes, has a shape
            (num_instances, ).
          - bboxes (Tensor): Has a shape (num_instances, 4),
            the last dimension 4 arrange as (x1, y1, x2, y2).
          - masks (Tensor): Has a shape (num_instances, H, W).
      """
      # Take back the helper fields stored by `predict_bbox`.
      proposal_list = [res.pop('proposals') for res in results_list]
      topk_inds_list = [res.pop('topk_inds') for res in results_list]
      attn_feats = torch.cat(
          [res.pop('attn_feats')[None, ...] for res in results_list])

      rois = bbox2roi(proposal_list)
      if rois.shape[0] == 0:
          return empty_instances(
              batch_img_metas,
              rois.device,
              task_type='mask',
              instance_results=results_list,
              mask_thr_binary=self.test_cfg.mask_thr_binary)

      # Only the final stage's mask head is used at test time.
      mask_results = self._mask_forward(self.num_stages - 1, x, rois,
                                        attn_feats)

      num_imgs = len(batch_img_metas)
      flat_preds = mask_results['mask_preds']
      # Restore the leading image dimension.
      batched_preds = flat_preds.reshape(num_imgs, -1,
                                         *flat_preds.size()[1:])
      num_classes = self.bbox_head[-1].num_classes

      mask_preds = []
      for img_id in range(num_imgs):
          selected = batched_preds[img_id].flatten(
              0, 1)[topk_inds_list[img_id]]
          # Insert a class dimension and repeat the mask for every class.
          mask_preds.append(selected[:, None, ...].repeat(
              1, num_classes, 1, 1))

      return self.mask_head[-1].predict_by_feat(
          mask_preds,
          results_list,
          batch_img_metas,
          rcnn_test_cfg=self.test_cfg,
          rescale=rescale)
  # TODO: Need to refactor later
  def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
              batch_data_samples: SampleList) -> tuple:
      """Network forward process. Usually includes backbone, neck and head
      forward without any post-processing.

      Args:
          x (List[Tensor]): Multi-level features that may have different
              resolutions.
          rpn_results_list (List[:obj:`InstanceData`]): List of region
              proposals. The ``features`` field is popped from every item.
          batch_data_samples (list[:obj:`DetDataSample`]): The batch
              data samples. It usually includes information such
              as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

      Returns:
          tuple: A tuple of features from ``bbox_head`` and ``mask_head``
          forward. It holds one entry per stage: a 1-tuple with the bbox
          results dict, extended with the mask results dict when the head
          has a mask branch. The tuple is empty when there is no bbox
          branch.
      """
      batch_gt_instances, _, batch_img_metas = unpack_gt_instances(
          batch_data_samples)

      all_stage_bbox_results = []
      object_feats = torch.cat(
          [res.pop('features')[None, ...] for res in rpn_results_list])
      results_list = rpn_results_list
      if self.with_bbox:
          for stage in range(self.num_stages):
              bbox_results = self.bbox_loss(
                  stage=stage,
                  x=x,
                  results_list=results_list,
                  object_feats=object_feats,
                  batch_img_metas=batch_img_metas,
                  batch_gt_instances=batch_gt_instances)
              bbox_results.pop('loss_bbox')
              # torch.jit does not support obj:SamplingResult
              next_results_list = bbox_results.pop('results_list')
              bbox_res = bbox_results.copy()
              bbox_res.pop('sampling_results')
              all_stage_bbox_results.append((bbox_res, ))

              if self.with_mask:
                  sampling_results = bbox_results['sampling_results']
                  # Only positive samples take part in the mask branch.
                  pos_rois = bbox2roi(
                      [res.pos_priors for res in sampling_results])
                  attn_feats = torch.cat([
                      feats[res.pos_inds] for feats, res in zip(
                          bbox_results['attn_feats'], sampling_results)
                  ])
                  mask_results = self._mask_forward(stage, x, pos_rois,
                                                    attn_feats)
                  all_stage_bbox_results[-1] += (mask_results, )

              # Feed the refined proposals and object features into the
              # next stage, as `loss` and `predict_bbox` do. Without this
              # every stage would re-process the initial proposals.
              object_feats = bbox_results['object_feats']
              results_list = next_results_list
      return tuple(all_stage_bbox_results)
  525. attn_feats)
  526. all_stage_bbox_results[-1] += (mask_results, )
  527. return tuple(all_stage_bbox_results)