scnet_roi_head.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine.structures import InstanceData
  6. from torch import Tensor
  7. from mmdet.registry import MODELS
  8. from mmdet.structures import SampleList
  9. from mmdet.structures.bbox import bbox2roi
  10. from mmdet.utils import ConfigType, InstanceList, OptConfigType
  11. from ..layers import adaptive_avg_pool2d
  12. from ..task_modules.samplers import SamplingResult
  13. from ..utils import empty_instances, unpack_gt_instances
  14. from .cascade_roi_head import CascadeRoIHead
  15. @MODELS.register_module()
  16. class SCNetRoIHead(CascadeRoIHead):
  17. """RoIHead for `SCNet <https://arxiv.org/abs/2012.10150>`_.
  18. Args:
  19. num_stages (int): number of cascade stages.
  20. stage_loss_weights (list): loss weight of cascade stages.
  21. semantic_roi_extractor (dict): config to init semantic roi extractor.
  22. semantic_head (dict): config to init semantic head.
  23. feat_relay_head (dict): config to init feature_relay_head.
  24. glbctx_head (dict): config to init global context head.
  25. """
  26. def __init__(self,
  27. num_stages: int,
  28. stage_loss_weights: List[float],
  29. semantic_roi_extractor: OptConfigType = None,
  30. semantic_head: OptConfigType = None,
  31. feat_relay_head: OptConfigType = None,
  32. glbctx_head: OptConfigType = None,
  33. **kwargs) -> None:
  34. super().__init__(
  35. num_stages=num_stages,
  36. stage_loss_weights=stage_loss_weights,
  37. **kwargs)
  38. assert self.with_bbox and self.with_mask
  39. assert not self.with_shared_head # shared head is not supported
  40. if semantic_head is not None:
  41. self.semantic_roi_extractor = MODELS.build(semantic_roi_extractor)
  42. self.semantic_head = MODELS.build(semantic_head)
  43. if feat_relay_head is not None:
  44. self.feat_relay_head = MODELS.build(feat_relay_head)
  45. if glbctx_head is not None:
  46. self.glbctx_head = MODELS.build(glbctx_head)
  47. def init_mask_head(self, mask_roi_extractor: ConfigType,
  48. mask_head: ConfigType) -> None:
  49. """Initialize ``mask_head``"""
  50. if mask_roi_extractor is not None:
  51. self.mask_roi_extractor = MODELS.build(mask_roi_extractor)
  52. self.mask_head = MODELS.build(mask_head)
  53. # TODO move to base_roi_head later
  54. @property
  55. def with_semantic(self) -> bool:
  56. """bool: whether the head has semantic head"""
  57. return hasattr(self,
  58. 'semantic_head') and self.semantic_head is not None
  59. @property
  60. def with_feat_relay(self) -> bool:
  61. """bool: whether the head has feature relay head"""
  62. return (hasattr(self, 'feat_relay_head')
  63. and self.feat_relay_head is not None)
  64. @property
  65. def with_glbctx(self) -> bool:
  66. """bool: whether the head has global context head"""
  67. return hasattr(self, 'glbctx_head') and self.glbctx_head is not None
  68. def _fuse_glbctx(self, roi_feats: Tensor, glbctx_feat: Tensor,
  69. rois: Tensor) -> Tensor:
  70. """Fuse global context feats with roi feats.
  71. Args:
  72. roi_feats (Tensor): RoI features.
  73. glbctx_feat (Tensor): Global context feature..
  74. rois (Tensor): RoIs with the shape (n, 5) where the first
  75. column indicates batch id of each RoI.
  76. Returns:
  77. Tensor: Fused feature.
  78. """
  79. assert roi_feats.size(0) == rois.size(0)
  80. # RuntimeError: isDifferentiableType(variable.scalar_type())
  81. # INTERNAL ASSERT FAILED if detach() is not used when calling
  82. # roi_head.predict().
  83. img_inds = torch.unique(rois[:, 0].detach().cpu(), sorted=True).long()
  84. fused_feats = torch.zeros_like(roi_feats)
  85. for img_id in img_inds:
  86. inds = (rois[:, 0] == img_id.item())
  87. fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id]
  88. return fused_feats
  89. def _slice_pos_feats(self, feats: Tensor,
  90. sampling_results: List[SamplingResult]) -> Tensor:
  91. """Get features from pos rois.
  92. Args:
  93. feats (Tensor): Input features.
  94. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  95. Returns:
  96. Tensor: Sliced features.
  97. """
  98. num_rois = [res.priors.size(0) for res in sampling_results]
  99. num_pos_rois = [res.pos_priors.size(0) for res in sampling_results]
  100. inds = torch.zeros(sum(num_rois), dtype=torch.bool)
  101. start = 0
  102. for i in range(len(num_rois)):
  103. start = 0 if i == 0 else start + num_rois[i - 1]
  104. stop = start + num_pos_rois[i]
  105. inds[start:stop] = 1
  106. sliced_feats = feats[inds]
  107. return sliced_feats
  108. def _bbox_forward(self,
  109. stage: int,
  110. x: Tuple[Tensor],
  111. rois: Tensor,
  112. semantic_feat: Optional[Tensor] = None,
  113. glbctx_feat: Optional[Tensor] = None) -> dict:
  114. """Box head forward function used in both training and testing.
  115. Args:
  116. stage (int): The current stage in Cascade RoI Head.
  117. x (tuple[Tensor]): List of multi-level img features.
  118. rois (Tensor): RoIs with the shape (n, 5) where the first
  119. column indicates batch id of each RoI.
  120. semantic_feat (Tensor): Semantic feature. Defaults to None.
  121. glbctx_feat (Tensor): Global context feature. Defaults to None.
  122. Returns:
  123. dict[str, Tensor]: Usually returns a dictionary with keys:
  124. - `cls_score` (Tensor): Classification scores.
  125. - `bbox_pred` (Tensor): Box energies / deltas.
  126. - `bbox_feats` (Tensor): Extract bbox RoI features.
  127. """
  128. bbox_roi_extractor = self.bbox_roi_extractor[stage]
  129. bbox_head = self.bbox_head[stage]
  130. bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
  131. rois)
  132. if self.with_semantic and semantic_feat is not None:
  133. bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  134. rois)
  135. if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
  136. bbox_semantic_feat = adaptive_avg_pool2d(
  137. bbox_semantic_feat, bbox_feats.shape[-2:])
  138. bbox_feats += bbox_semantic_feat
  139. if self.with_glbctx and glbctx_feat is not None:
  140. bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois)
  141. cls_score, bbox_pred, relayed_feat = bbox_head(
  142. bbox_feats, return_shared_feat=True)
  143. bbox_results = dict(
  144. cls_score=cls_score,
  145. bbox_pred=bbox_pred,
  146. relayed_feat=relayed_feat)
  147. return bbox_results
  148. def _mask_forward(self,
  149. x: Tuple[Tensor],
  150. rois: Tensor,
  151. semantic_feat: Optional[Tensor] = None,
  152. glbctx_feat: Optional[Tensor] = None,
  153. relayed_feat: Optional[Tensor] = None) -> dict:
  154. """Mask head forward function used in both training and testing.
  155. Args:
  156. stage (int): The current stage in Cascade RoI Head.
  157. x (tuple[Tensor]): Tuple of multi-level img features.
  158. rois (Tensor): RoIs with the shape (n, 5) where the first
  159. column indicates batch id of each RoI.
  160. semantic_feat (Tensor): Semantic feature. Defaults to None.
  161. glbctx_feat (Tensor): Global context feature. Defaults to None.
  162. relayed_feat (Tensor): Relayed feature. Defaults to None.
  163. Returns:
  164. dict: Usually returns a dictionary with keys:
  165. - `mask_preds` (Tensor): Mask prediction.
  166. """
  167. mask_feats = self.mask_roi_extractor(
  168. x[:self.mask_roi_extractor.num_inputs], rois)
  169. if self.with_semantic and semantic_feat is not None:
  170. mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  171. rois)
  172. if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
  173. mask_semantic_feat = F.adaptive_avg_pool2d(
  174. mask_semantic_feat, mask_feats.shape[-2:])
  175. mask_feats += mask_semantic_feat
  176. if self.with_glbctx and glbctx_feat is not None:
  177. mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois)
  178. if self.with_feat_relay and relayed_feat is not None:
  179. mask_feats = mask_feats + relayed_feat
  180. mask_preds = self.mask_head(mask_feats)
  181. mask_results = dict(mask_preds=mask_preds)
  182. return mask_results
  183. def bbox_loss(self,
  184. stage: int,
  185. x: Tuple[Tensor],
  186. sampling_results: List[SamplingResult],
  187. semantic_feat: Optional[Tensor] = None,
  188. glbctx_feat: Optional[Tensor] = None) -> dict:
  189. """Run forward function and calculate loss for box head in training.
  190. Args:
  191. stage (int): The current stage in Cascade RoI Head.
  192. x (tuple[Tensor]): List of multi-level img features.
  193. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  194. semantic_feat (Tensor): Semantic feature. Defaults to None.
  195. glbctx_feat (Tensor): Global context feature. Defaults to None.
  196. Returns:
  197. dict: Usually returns a dictionary with keys:
  198. - `cls_score` (Tensor): Classification scores.
  199. - `bbox_pred` (Tensor): Box energies / deltas.
  200. - `bbox_feats` (Tensor): Extract bbox RoI features.
  201. - `loss_bbox` (dict): A dictionary of bbox loss components.
  202. - `rois` (Tensor): RoIs with the shape (n, 5) where the first
  203. column indicates batch id of each RoI.
  204. - `bbox_targets` (tuple): Ground truth for proposals in a
  205. single image. Containing the following list of Tensors:
  206. (labels, label_weights, bbox_targets, bbox_weights)
  207. """
  208. bbox_head = self.bbox_head[stage]
  209. rois = bbox2roi([res.priors for res in sampling_results])
  210. bbox_results = self._bbox_forward(
  211. stage,
  212. x,
  213. rois,
  214. semantic_feat=semantic_feat,
  215. glbctx_feat=glbctx_feat)
  216. bbox_results.update(rois=rois)
  217. bbox_loss_and_target = bbox_head.loss_and_target(
  218. cls_score=bbox_results['cls_score'],
  219. bbox_pred=bbox_results['bbox_pred'],
  220. rois=rois,
  221. sampling_results=sampling_results,
  222. rcnn_train_cfg=self.train_cfg[stage])
  223. bbox_results.update(bbox_loss_and_target)
  224. return bbox_results
  225. def mask_loss(self,
  226. x: Tuple[Tensor],
  227. sampling_results: List[SamplingResult],
  228. batch_gt_instances: InstanceList,
  229. semantic_feat: Optional[Tensor] = None,
  230. glbctx_feat: Optional[Tensor] = None,
  231. relayed_feat: Optional[Tensor] = None) -> dict:
  232. """Run forward function and calculate loss for mask head in training.
  233. Args:
  234. x (tuple[Tensor]): Tuple of multi-level img features.
  235. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  236. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  237. gt_instance. It usually includes ``bboxes``, ``labels``, and
  238. ``masks`` attributes.
  239. semantic_feat (Tensor): Semantic feature. Defaults to None.
  240. glbctx_feat (Tensor): Global context feature. Defaults to None.
  241. relayed_feat (Tensor): Relayed feature. Defaults to None.
  242. Returns:
  243. dict: Usually returns a dictionary with keys:
  244. - `mask_preds` (Tensor): Mask prediction.
  245. - `loss_mask` (dict): A dictionary of mask loss components.
  246. """
  247. pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
  248. mask_results = self._mask_forward(
  249. x,
  250. pos_rois,
  251. semantic_feat=semantic_feat,
  252. glbctx_feat=glbctx_feat,
  253. relayed_feat=relayed_feat)
  254. mask_loss_and_target = self.mask_head.loss_and_target(
  255. mask_preds=mask_results['mask_preds'],
  256. sampling_results=sampling_results,
  257. batch_gt_instances=batch_gt_instances,
  258. rcnn_train_cfg=self.train_cfg[-1])
  259. mask_results.update(mask_loss_and_target)
  260. return mask_results
  261. def semantic_loss(self, x: Tuple[Tensor],
  262. batch_data_samples: SampleList) -> dict:
  263. """Semantic segmentation loss.
  264. Args:
  265. x (Tuple[Tensor]): Tuple of multi-level img features.
  266. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  267. data samples. It usually includes information such
  268. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  269. Returns:
  270. dict: Usually returns a dictionary with keys:
  271. - `semantic_feat` (Tensor): Semantic feature.
  272. - `loss_seg` (dict): Semantic segmentation loss.
  273. """
  274. gt_semantic_segs = [
  275. data_sample.gt_sem_seg.sem_seg
  276. for data_sample in batch_data_samples
  277. ]
  278. gt_semantic_segs = torch.stack(gt_semantic_segs)
  279. semantic_pred, semantic_feat = self.semantic_head(x)
  280. loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_segs)
  281. semantic_results = dict(loss_seg=loss_seg, semantic_feat=semantic_feat)
  282. return semantic_results
  283. def global_context_loss(self, x: Tuple[Tensor],
  284. batch_gt_instances: InstanceList) -> dict:
  285. """Global context loss.
  286. Args:
  287. x (Tuple[Tensor]): Tuple of multi-level img features.
  288. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  289. gt_instance. It usually includes ``bboxes``, ``labels``, and
  290. ``masks`` attributes.
  291. Returns:
  292. dict: Usually returns a dictionary with keys:
  293. - `glbctx_feat` (Tensor): Global context feature.
  294. - `loss_glbctx` (dict): Global context loss.
  295. """
  296. gt_labels = [
  297. gt_instances.labels for gt_instances in batch_gt_instances
  298. ]
  299. mc_pred, glbctx_feat = self.glbctx_head(x)
  300. loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels)
  301. global_context_results = dict(
  302. loss_glbctx=loss_glbctx, glbctx_feat=glbctx_feat)
  303. return global_context_results
  304. def loss(self, x: Tensor, rpn_results_list: InstanceList,
  305. batch_data_samples: SampleList) -> dict:
  306. """Perform forward propagation and loss calculation of the detection
  307. roi on the features of the upstream network.
  308. Args:
  309. x (tuple[Tensor]): List of multi-level img features.
  310. rpn_results_list (list[:obj:`InstanceData`]): List of region
  311. proposals.
  312. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  313. data samples. It usually includes information such
  314. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  315. Returns:
  316. dict[str, Tensor]: A dictionary of loss components
  317. """
  318. assert len(rpn_results_list) == len(batch_data_samples)
  319. outputs = unpack_gt_instances(batch_data_samples)
  320. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  321. = outputs
  322. losses = dict()
  323. # semantic segmentation branch
  324. if self.with_semantic:
  325. semantic_results = self.semantic_loss(
  326. x=x, batch_data_samples=batch_data_samples)
  327. losses['loss_semantic_seg'] = semantic_results['loss_seg']
  328. semantic_feat = semantic_results['semantic_feat']
  329. else:
  330. semantic_feat = None
  331. # global context branch
  332. if self.with_glbctx:
  333. global_context_results = self.global_context_loss(
  334. x=x, batch_gt_instances=batch_gt_instances)
  335. losses['loss_glbctx'] = global_context_results['loss_glbctx']
  336. glbctx_feat = global_context_results['glbctx_feat']
  337. else:
  338. glbctx_feat = None
  339. results_list = rpn_results_list
  340. num_imgs = len(batch_img_metas)
  341. for stage in range(self.num_stages):
  342. stage_loss_weight = self.stage_loss_weights[stage]
  343. # assign gts and sample proposals
  344. sampling_results = []
  345. bbox_assigner = self.bbox_assigner[stage]
  346. bbox_sampler = self.bbox_sampler[stage]
  347. for i in range(num_imgs):
  348. results = results_list[i]
  349. # rename rpn_results.bboxes to rpn_results.priors
  350. results.priors = results.pop('bboxes')
  351. assign_result = bbox_assigner.assign(
  352. results, batch_gt_instances[i],
  353. batch_gt_instances_ignore[i])
  354. sampling_result = bbox_sampler.sample(
  355. assign_result,
  356. results,
  357. batch_gt_instances[i],
  358. feats=[lvl_feat[i][None] for lvl_feat in x])
  359. sampling_results.append(sampling_result)
  360. # bbox head forward and loss
  361. bbox_results = self.bbox_loss(
  362. stage=stage,
  363. x=x,
  364. sampling_results=sampling_results,
  365. semantic_feat=semantic_feat,
  366. glbctx_feat=glbctx_feat)
  367. for name, value in bbox_results['loss_bbox'].items():
  368. losses[f's{stage}.{name}'] = (
  369. value * stage_loss_weight if 'loss' in name else value)
  370. # refine bboxes
  371. if stage < self.num_stages - 1:
  372. bbox_head = self.bbox_head[stage]
  373. with torch.no_grad():
  374. results_list = bbox_head.refine_bboxes(
  375. sampling_results=sampling_results,
  376. bbox_results=bbox_results,
  377. batch_img_metas=batch_img_metas)
  378. if self.with_feat_relay:
  379. relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'],
  380. sampling_results)
  381. relayed_feat = self.feat_relay_head(relayed_feat)
  382. else:
  383. relayed_feat = None
  384. # mask head forward and loss
  385. mask_results = self.mask_loss(
  386. x=x,
  387. sampling_results=sampling_results,
  388. batch_gt_instances=batch_gt_instances,
  389. semantic_feat=semantic_feat,
  390. glbctx_feat=glbctx_feat,
  391. relayed_feat=relayed_feat)
  392. mask_stage_loss_weight = sum(self.stage_loss_weights)
  393. losses['loss_mask'] = mask_stage_loss_weight * mask_results[
  394. 'loss_mask']['loss_mask']
  395. return losses
  396. def predict(self,
  397. x: Tuple[Tensor],
  398. rpn_results_list: InstanceList,
  399. batch_data_samples: SampleList,
  400. rescale: bool = False) -> InstanceList:
  401. """Perform forward propagation of the roi head and predict detection
  402. results on the features of the upstream network.
  403. Args:
  404. x (tuple[Tensor]): Features from upstream network. Each
  405. has shape (N, C, H, W).
  406. rpn_results_list (list[:obj:`InstanceData`]): list of region
  407. proposals.
  408. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  409. Samples. It usually includes information such as
  410. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  411. rescale (bool): Whether to rescale the results to
  412. the original image. Defaults to False.
  413. Returns:
  414. list[obj:`InstanceData`]: Detection results of each image.
  415. Each item usually contains following keys.
  416. - scores (Tensor): Classification scores, has a shape
  417. (num_instance, )
  418. - labels (Tensor): Labels of bboxes, has a shape
  419. (num_instances, ).
  420. - bboxes (Tensor): Has a shape (num_instances, 4),
  421. the last dimension 4 arrange as (x1, y1, x2, y2).
  422. - masks (Tensor): Has a shape (num_instances, H, W).
  423. """
  424. assert self.with_bbox, 'Bbox head must be implemented.'
  425. batch_img_metas = [
  426. data_samples.metainfo for data_samples in batch_data_samples
  427. ]
  428. if self.with_semantic:
  429. _, semantic_feat = self.semantic_head(x)
  430. else:
  431. semantic_feat = None
  432. if self.with_glbctx:
  433. _, glbctx_feat = self.glbctx_head(x)
  434. else:
  435. glbctx_feat = None
  436. # TODO: nms_op in mmcv need be enhanced, the bbox result may get
  437. # difference when not rescale in bbox_head
  438. # If it has the mask branch, the bbox branch does not need
  439. # to be scaled to the original image scale, because the mask
  440. # branch will scale both bbox and mask at the same time.
  441. bbox_rescale = rescale if not self.with_mask else False
  442. results_list = self.predict_bbox(
  443. x=x,
  444. semantic_feat=semantic_feat,
  445. glbctx_feat=glbctx_feat,
  446. batch_img_metas=batch_img_metas,
  447. rpn_results_list=rpn_results_list,
  448. rcnn_test_cfg=self.test_cfg,
  449. rescale=bbox_rescale)
  450. if self.with_mask:
  451. results_list = self.predict_mask(
  452. x=x,
  453. semantic_heat=semantic_feat,
  454. glbctx_feat=glbctx_feat,
  455. batch_img_metas=batch_img_metas,
  456. results_list=results_list,
  457. rescale=rescale)
  458. return results_list
  459. def predict_mask(self,
  460. x: Tuple[Tensor],
  461. semantic_heat: Tensor,
  462. glbctx_feat: Tensor,
  463. batch_img_metas: List[dict],
  464. results_list: List[InstanceData],
  465. rescale: bool = False) -> List[InstanceData]:
  466. """Perform forward propagation of the mask head and predict detection
  467. results on the features of the upstream network.
  468. Args:
  469. x (tuple[Tensor]): Feature maps of all scale level.
  470. semantic_feat (Tensor): Semantic feature.
  471. glbctx_feat (Tensor): Global context feature.
  472. batch_img_metas (list[dict]): List of image information.
  473. results_list (list[:obj:`InstanceData`]): Detection results of
  474. each image.
  475. rescale (bool): If True, return boxes in original image space.
  476. Defaults to False.
  477. Returns:
  478. list[:obj:`InstanceData`]: Detection results of each image
  479. after the post process.
  480. Each item usually contains following keys.
  481. - scores (Tensor): Classification scores, has a shape
  482. (num_instance, )
  483. - labels (Tensor): Labels of bboxes, has a shape
  484. (num_instances, ).
  485. - bboxes (Tensor): Has a shape (num_instances, 4),
  486. the last dimension 4 arrange as (x1, y1, x2, y2).
  487. - masks (Tensor): Has a shape (num_instances, H, W).
  488. """
  489. bboxes = [res.bboxes for res in results_list]
  490. mask_rois = bbox2roi(bboxes)
  491. if mask_rois.shape[0] == 0:
  492. results_list = empty_instances(
  493. batch_img_metas=batch_img_metas,
  494. device=mask_rois.device,
  495. task_type='mask',
  496. instance_results=results_list,
  497. mask_thr_binary=self.test_cfg.mask_thr_binary)
  498. return results_list
  499. bboxes_results = self._bbox_forward(
  500. stage=-1,
  501. x=x,
  502. rois=mask_rois,
  503. semantic_feat=semantic_heat,
  504. glbctx_feat=glbctx_feat)
  505. relayed_feat = bboxes_results['relayed_feat']
  506. relayed_feat = self.feat_relay_head(relayed_feat)
  507. mask_results = self._mask_forward(
  508. x=x,
  509. rois=mask_rois,
  510. semantic_feat=semantic_heat,
  511. glbctx_feat=glbctx_feat,
  512. relayed_feat=relayed_feat)
  513. mask_preds = mask_results['mask_preds']
  514. # split batch mask prediction back to each image
  515. num_bbox_per_img = tuple(len(_bbox) for _bbox in bboxes)
  516. mask_preds = mask_preds.split(num_bbox_per_img, 0)
  517. results_list = self.mask_head.predict_by_feat(
  518. mask_preds=mask_preds,
  519. results_list=results_list,
  520. batch_img_metas=batch_img_metas,
  521. rcnn_test_cfg=self.test_cfg,
  522. rescale=rescale)
  523. return results_list
  524. def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  525. batch_data_samples: SampleList) -> tuple:
  526. """Network forward process. Usually includes backbone, neck and head
  527. forward without any post-processing.
  528. Args:
  529. x (List[Tensor]): Multi-level features that may have different
  530. resolutions.
  531. rpn_results_list (list[:obj:`InstanceData`]): List of region
  532. proposals.
  533. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
  534. the meta information of each image and corresponding
  535. annotations.
  536. Returns
  537. tuple: A tuple of features from ``bbox_head`` and ``mask_head``
  538. forward.
  539. """
  540. results = ()
  541. batch_img_metas = [
  542. data_samples.metainfo for data_samples in batch_data_samples
  543. ]
  544. if self.with_semantic:
  545. _, semantic_feat = self.semantic_head(x)
  546. else:
  547. semantic_feat = None
  548. if self.with_glbctx:
  549. _, glbctx_feat = self.glbctx_head(x)
  550. else:
  551. glbctx_feat = None
  552. proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
  553. num_proposals_per_img = tuple(len(p) for p in proposals)
  554. rois = bbox2roi(proposals)
  555. # bbox head
  556. if self.with_bbox:
  557. rois, cls_scores, bbox_preds = self._refine_roi(
  558. x=x,
  559. rois=rois,
  560. semantic_feat=semantic_feat,
  561. glbctx_feat=glbctx_feat,
  562. batch_img_metas=batch_img_metas,
  563. num_proposals_per_img=num_proposals_per_img)
  564. results = results + (cls_scores, bbox_preds)
  565. # mask head
  566. if self.with_mask:
  567. rois = torch.cat(rois)
  568. bboxes_results = self._bbox_forward(
  569. stage=-1,
  570. x=x,
  571. rois=rois,
  572. semantic_feat=semantic_feat,
  573. glbctx_feat=glbctx_feat)
  574. relayed_feat = bboxes_results['relayed_feat']
  575. relayed_feat = self.feat_relay_head(relayed_feat)
  576. mask_results = self._mask_forward(
  577. x=x,
  578. rois=rois,
  579. semantic_feat=semantic_feat,
  580. glbctx_feat=glbctx_feat,
  581. relayed_feat=relayed_feat)
  582. mask_preds = mask_results['mask_preds']
  583. mask_preds = mask_preds.split(num_proposals_per_img, 0)
  584. results = results + (mask_preds, )
  585. return results