htc_roi_head.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import Tensor
  6. from mmdet.models.test_time_augs import merge_aug_masks
  7. from mmdet.registry import MODELS
  8. from mmdet.structures import SampleList
  9. from mmdet.structures.bbox import bbox2roi
  10. from mmdet.utils import InstanceList, OptConfigType
  11. from ..layers import adaptive_avg_pool2d
  12. from ..task_modules.samplers import SamplingResult
  13. from ..utils import empty_instances, unpack_gt_instances
  14. from .cascade_roi_head import CascadeRoIHead
  15. @MODELS.register_module()
  16. class HybridTaskCascadeRoIHead(CascadeRoIHead):
  17. """Hybrid task cascade roi head including one bbox head and one mask head.
  18. https://arxiv.org/abs/1901.07518
  19. Args:
  20. num_stages (int): Number of cascade stages.
  21. stage_loss_weights (list[float]): Loss weight for every stage.
  22. semantic_roi_extractor (:obj:`ConfigDict` or dict, optional):
  23. Config of semantic roi extractor. Defaults to None.
  24. Semantic_head (:obj:`ConfigDict` or dict, optional):
  25. Config of semantic head. Defaults to None.
  26. interleaved (bool): Whether to interleaves the box branch and mask
  27. branch. If True, the mask branch can take the refined bounding
  28. box predictions. Defaults to True.
  29. mask_info_flow (bool): Whether to turn on the mask information flow,
  30. which means that feeding the mask features of the preceding stage
  31. to the current stage. Defaults to True.
  32. """
  33. def __init__(self,
  34. num_stages: int,
  35. stage_loss_weights: List[float],
  36. semantic_roi_extractor: OptConfigType = None,
  37. semantic_head: OptConfigType = None,
  38. semantic_fusion: Tuple[str] = ('bbox', 'mask'),
  39. interleaved: bool = True,
  40. mask_info_flow: bool = True,
  41. **kwargs) -> None:
  42. super().__init__(
  43. num_stages=num_stages,
  44. stage_loss_weights=stage_loss_weights,
  45. **kwargs)
  46. assert self.with_bbox
  47. assert not self.with_shared_head # shared head is not supported
  48. if semantic_head is not None:
  49. self.semantic_roi_extractor = MODELS.build(semantic_roi_extractor)
  50. self.semantic_head = MODELS.build(semantic_head)
  51. self.semantic_fusion = semantic_fusion
  52. self.interleaved = interleaved
  53. self.mask_info_flow = mask_info_flow
  54. # TODO move to base_roi_head later
  55. @property
  56. def with_semantic(self) -> bool:
  57. """bool: whether the head has semantic head"""
  58. return hasattr(self,
  59. 'semantic_head') and self.semantic_head is not None
  60. def _bbox_forward(
  61. self,
  62. stage: int,
  63. x: Tuple[Tensor],
  64. rois: Tensor,
  65. semantic_feat: Optional[Tensor] = None) -> Dict[str, Tensor]:
  66. """Box head forward function used in both training and testing.
  67. Args:
  68. stage (int): The current stage in Cascade RoI Head.
  69. x (tuple[Tensor]): List of multi-level img features.
  70. rois (Tensor): RoIs with the shape (n, 5) where the first
  71. column indicates batch id of each RoI.
  72. semantic_feat (Tensor, optional): Semantic feature. Defaults to
  73. None.
  74. Returns:
  75. dict[str, Tensor]: Usually returns a dictionary with keys:
  76. - `cls_score` (Tensor): Classification scores.
  77. - `bbox_pred` (Tensor): Box energies / deltas.
  78. - `bbox_feats` (Tensor): Extract bbox RoI features.
  79. """
  80. bbox_roi_extractor = self.bbox_roi_extractor[stage]
  81. bbox_head = self.bbox_head[stage]
  82. bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
  83. rois)
  84. if self.with_semantic and 'bbox' in self.semantic_fusion:
  85. bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  86. rois)
  87. if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
  88. bbox_semantic_feat = adaptive_avg_pool2d(
  89. bbox_semantic_feat, bbox_feats.shape[-2:])
  90. bbox_feats += bbox_semantic_feat
  91. cls_score, bbox_pred = bbox_head(bbox_feats)
  92. bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
  93. return bbox_results
  94. def bbox_loss(self,
  95. stage: int,
  96. x: Tuple[Tensor],
  97. sampling_results: List[SamplingResult],
  98. semantic_feat: Optional[Tensor] = None) -> dict:
  99. """Run forward function and calculate loss for box head in training.
  100. Args:
  101. stage (int): The current stage in Cascade RoI Head.
  102. x (tuple[Tensor]): List of multi-level img features.
  103. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  104. semantic_feat (Tensor, optional): Semantic feature. Defaults to
  105. None.
  106. Returns:
  107. dict: Usually returns a dictionary with keys:
  108. - `cls_score` (Tensor): Classification scores.
  109. - `bbox_pred` (Tensor): Box energies / deltas.
  110. - `bbox_feats` (Tensor): Extract bbox RoI features.
  111. - `loss_bbox` (dict): A dictionary of bbox loss components.
  112. - `rois` (Tensor): RoIs with the shape (n, 5) where the first
  113. column indicates batch id of each RoI.
  114. - `bbox_targets` (tuple): Ground truth for proposals in a
  115. single image. Containing the following list of Tensors:
  116. (labels, label_weights, bbox_targets, bbox_weights)
  117. """
  118. bbox_head = self.bbox_head[stage]
  119. rois = bbox2roi([res.priors for res in sampling_results])
  120. bbox_results = self._bbox_forward(
  121. stage, x, rois, semantic_feat=semantic_feat)
  122. bbox_results.update(rois=rois)
  123. bbox_loss_and_target = bbox_head.loss_and_target(
  124. cls_score=bbox_results['cls_score'],
  125. bbox_pred=bbox_results['bbox_pred'],
  126. rois=rois,
  127. sampling_results=sampling_results,
  128. rcnn_train_cfg=self.train_cfg[stage])
  129. bbox_results.update(bbox_loss_and_target)
  130. return bbox_results
  131. def _mask_forward(self,
  132. stage: int,
  133. x: Tuple[Tensor],
  134. rois: Tensor,
  135. semantic_feat: Optional[Tensor] = None,
  136. training: bool = True) -> Dict[str, Tensor]:
  137. """Mask head forward function used only in training.
  138. Args:
  139. stage (int): The current stage in Cascade RoI Head.
  140. x (tuple[Tensor]): Tuple of multi-level img features.
  141. rois (Tensor): RoIs with the shape (n, 5) where the first
  142. column indicates batch id of each RoI.
  143. semantic_feat (Tensor, optional): Semantic feature. Defaults to
  144. None.
  145. training (bool): Mask Forward is different between training and
  146. testing. If True, use the mask forward in training.
  147. Defaults to True.
  148. Returns:
  149. dict: Usually returns a dictionary with keys:
  150. - `mask_preds` (Tensor): Mask prediction.
  151. """
  152. mask_roi_extractor = self.mask_roi_extractor[stage]
  153. mask_head = self.mask_head[stage]
  154. mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
  155. rois)
  156. # semantic feature fusion
  157. # element-wise sum for original features and pooled semantic features
  158. if self.with_semantic and 'mask' in self.semantic_fusion:
  159. mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  160. rois)
  161. if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
  162. mask_semantic_feat = F.adaptive_avg_pool2d(
  163. mask_semantic_feat, mask_feats.shape[-2:])
  164. mask_feats = mask_feats + mask_semantic_feat
  165. # mask information flow
  166. # forward all previous mask heads to obtain last_feat, and fuse it
  167. # with the normal mask feature
  168. if training:
  169. if self.mask_info_flow:
  170. last_feat = None
  171. for i in range(stage):
  172. last_feat = self.mask_head[i](
  173. mask_feats, last_feat, return_logits=False)
  174. mask_preds = mask_head(
  175. mask_feats, last_feat, return_feat=False)
  176. else:
  177. mask_preds = mask_head(mask_feats, return_feat=False)
  178. mask_results = dict(mask_preds=mask_preds)
  179. else:
  180. aug_masks = []
  181. last_feat = None
  182. for i in range(self.num_stages):
  183. mask_head = self.mask_head[i]
  184. if self.mask_info_flow:
  185. mask_preds, last_feat = mask_head(mask_feats, last_feat)
  186. else:
  187. mask_preds = mask_head(mask_feats)
  188. aug_masks.append(mask_preds)
  189. mask_results = dict(mask_preds=aug_masks)
  190. return mask_results
  191. def mask_loss(self,
  192. stage: int,
  193. x: Tuple[Tensor],
  194. sampling_results: List[SamplingResult],
  195. batch_gt_instances: InstanceList,
  196. semantic_feat: Optional[Tensor] = None) -> dict:
  197. """Run forward function and calculate loss for mask head in training.
  198. Args:
  199. stage (int): The current stage in Cascade RoI Head.
  200. x (tuple[Tensor]): Tuple of multi-level img features.
  201. sampling_results (list["obj:`SamplingResult`]): Sampling results.
  202. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  203. gt_instance. It usually includes ``bboxes``, ``labels``, and
  204. ``masks`` attributes.
  205. semantic_feat (Tensor, optional): Semantic feature. Defaults to
  206. None.
  207. Returns:
  208. dict: Usually returns a dictionary with keys:
  209. - `mask_preds` (Tensor): Mask prediction.
  210. - `loss_mask` (dict): A dictionary of mask loss components.
  211. """
  212. pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
  213. mask_results = self._mask_forward(
  214. stage=stage,
  215. x=x,
  216. rois=pos_rois,
  217. semantic_feat=semantic_feat,
  218. training=True)
  219. mask_head = self.mask_head[stage]
  220. mask_loss_and_target = mask_head.loss_and_target(
  221. mask_preds=mask_results['mask_preds'],
  222. sampling_results=sampling_results,
  223. batch_gt_instances=batch_gt_instances,
  224. rcnn_train_cfg=self.train_cfg[stage])
  225. mask_results.update(mask_loss_and_target)
  226. return mask_results
  227. def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  228. batch_data_samples: SampleList) -> dict:
  229. """Perform forward propagation and loss calculation of the detection
  230. roi on the features of the upstream network.
  231. Args:
  232. x (tuple[Tensor]): List of multi-level img features.
  233. rpn_results_list (list[:obj:`InstanceData`]): List of region
  234. proposals.
  235. batch_data_samples (list[:obj:`DetDataSample`]): The batch
  236. data samples. It usually includes information such
  237. as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
  238. Returns:
  239. dict[str, Tensor]: A dictionary of loss components
  240. """
  241. assert len(rpn_results_list) == len(batch_data_samples)
  242. outputs = unpack_gt_instances(batch_data_samples)
  243. batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
  244. = outputs
  245. # semantic segmentation part
  246. # 2 outputs: segmentation prediction and embedded features
  247. losses = dict()
  248. if self.with_semantic:
  249. gt_semantic_segs = [
  250. data_sample.gt_sem_seg.sem_seg
  251. for data_sample in batch_data_samples
  252. ]
  253. gt_semantic_segs = torch.stack(gt_semantic_segs)
  254. semantic_pred, semantic_feat = self.semantic_head(x)
  255. loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_segs)
  256. losses['loss_semantic_seg'] = loss_seg
  257. else:
  258. semantic_feat = None
  259. results_list = rpn_results_list
  260. num_imgs = len(batch_img_metas)
  261. for stage in range(self.num_stages):
  262. self.current_stage = stage
  263. stage_loss_weight = self.stage_loss_weights[stage]
  264. # assign gts and sample proposals
  265. sampling_results = []
  266. bbox_assigner = self.bbox_assigner[stage]
  267. bbox_sampler = self.bbox_sampler[stage]
  268. for i in range(num_imgs):
  269. results = results_list[i]
  270. # rename rpn_results.bboxes to rpn_results.priors
  271. if 'bboxes' in results:
  272. results.priors = results.pop('bboxes')
  273. assign_result = bbox_assigner.assign(
  274. results, batch_gt_instances[i],
  275. batch_gt_instances_ignore[i])
  276. sampling_result = bbox_sampler.sample(
  277. assign_result,
  278. results,
  279. batch_gt_instances[i],
  280. feats=[lvl_feat[i][None] for lvl_feat in x])
  281. sampling_results.append(sampling_result)
  282. # bbox head forward and loss
  283. bbox_results = self.bbox_loss(
  284. stage=stage,
  285. x=x,
  286. sampling_results=sampling_results,
  287. semantic_feat=semantic_feat)
  288. for name, value in bbox_results['loss_bbox'].items():
  289. losses[f's{stage}.{name}'] = (
  290. value * stage_loss_weight if 'loss' in name else value)
  291. # mask head forward and loss
  292. if self.with_mask:
  293. # interleaved execution: use regressed bboxes by the box branch
  294. # to train the mask branch
  295. if self.interleaved:
  296. bbox_head = self.bbox_head[stage]
  297. with torch.no_grad():
  298. results_list = bbox_head.refine_bboxes(
  299. sampling_results, bbox_results, batch_img_metas)
  300. # re-assign and sample 512 RoIs from 512 RoIs
  301. sampling_results = []
  302. for i in range(num_imgs):
  303. results = results_list[i]
  304. # rename rpn_results.bboxes to rpn_results.priors
  305. results.priors = results.pop('bboxes')
  306. assign_result = bbox_assigner.assign(
  307. results, batch_gt_instances[i],
  308. batch_gt_instances_ignore[i])
  309. sampling_result = bbox_sampler.sample(
  310. assign_result,
  311. results,
  312. batch_gt_instances[i],
  313. feats=[lvl_feat[i][None] for lvl_feat in x])
  314. sampling_results.append(sampling_result)
  315. mask_results = self.mask_loss(
  316. stage=stage,
  317. x=x,
  318. sampling_results=sampling_results,
  319. batch_gt_instances=batch_gt_instances,
  320. semantic_feat=semantic_feat)
  321. for name, value in mask_results['loss_mask'].items():
  322. losses[f's{stage}.{name}'] = (
  323. value * stage_loss_weight if 'loss' in name else value)
  324. # refine bboxes (same as Cascade R-CNN)
  325. if stage < self.num_stages - 1 and not self.interleaved:
  326. bbox_head = self.bbox_head[stage]
  327. with torch.no_grad():
  328. results_list = bbox_head.refine_bboxes(
  329. sampling_results=sampling_results,
  330. bbox_results=bbox_results,
  331. batch_img_metas=batch_img_metas)
  332. return losses
  333. def predict(self,
  334. x: Tuple[Tensor],
  335. rpn_results_list: InstanceList,
  336. batch_data_samples: SampleList,
  337. rescale: bool = False) -> InstanceList:
  338. """Perform forward propagation of the roi head and predict detection
  339. results on the features of the upstream network.
  340. Args:
  341. x (tuple[Tensor]): Features from upstream network. Each
  342. has shape (N, C, H, W).
  343. rpn_results_list (list[:obj:`InstanceData`]): list of region
  344. proposals.
  345. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  346. Samples. It usually includes information such as
  347. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  348. rescale (bool): Whether to rescale the results to
  349. the original image. Defaults to False.
  350. Returns:
  351. list[obj:`InstanceData`]: Detection results of each image.
  352. Each item usually contains following keys.
  353. - scores (Tensor): Classification scores, has a shape
  354. (num_instance, )
  355. - labels (Tensor): Labels of bboxes, has a shape
  356. (num_instances, ).
  357. - bboxes (Tensor): Has a shape (num_instances, 4),
  358. the last dimension 4 arrange as (x1, y1, x2, y2).
  359. - masks (Tensor): Has a shape (num_instances, H, W).
  360. """
  361. assert self.with_bbox, 'Bbox head must be implemented.'
  362. batch_img_metas = [
  363. data_samples.metainfo for data_samples in batch_data_samples
  364. ]
  365. if self.with_semantic:
  366. _, semantic_feat = self.semantic_head(x)
  367. else:
  368. semantic_feat = None
  369. # TODO: nms_op in mmcv need be enhanced, the bbox result may get
  370. # difference when not rescale in bbox_head
  371. # If it has the mask branch, the bbox branch does not need
  372. # to be scaled to the original image scale, because the mask
  373. # branch will scale both bbox and mask at the same time.
  374. bbox_rescale = rescale if not self.with_mask else False
  375. results_list = self.predict_bbox(
  376. x=x,
  377. semantic_feat=semantic_feat,
  378. batch_img_metas=batch_img_metas,
  379. rpn_results_list=rpn_results_list,
  380. rcnn_test_cfg=self.test_cfg,
  381. rescale=bbox_rescale)
  382. if self.with_mask:
  383. results_list = self.predict_mask(
  384. x=x,
  385. semantic_heat=semantic_feat,
  386. batch_img_metas=batch_img_metas,
  387. results_list=results_list,
  388. rescale=rescale)
  389. return results_list
  390. def predict_mask(self,
  391. x: Tuple[Tensor],
  392. semantic_heat: Tensor,
  393. batch_img_metas: List[dict],
  394. results_list: InstanceList,
  395. rescale: bool = False) -> InstanceList:
  396. """Perform forward propagation of the mask head and predict detection
  397. results on the features of the upstream network.
  398. Args:
  399. x (tuple[Tensor]): Feature maps of all scale level.
  400. semantic_feat (Tensor): Semantic feature.
  401. batch_img_metas (list[dict]): List of image information.
  402. results_list (list[:obj:`InstanceData`]): Detection results of
  403. each image.
  404. rescale (bool): If True, return boxes in original image space.
  405. Defaults to False.
  406. Returns:
  407. list[:obj:`InstanceData`]: Detection results of each image
  408. after the post process.
  409. Each item usually contains following keys.
  410. - scores (Tensor): Classification scores, has a shape
  411. (num_instance, )
  412. - labels (Tensor): Labels of bboxes, has a shape
  413. (num_instances, ).
  414. - bboxes (Tensor): Has a shape (num_instances, 4),
  415. the last dimension 4 arrange as (x1, y1, x2, y2).
  416. - masks (Tensor): Has a shape (num_instances, H, W).
  417. """
  418. num_imgs = len(batch_img_metas)
  419. bboxes = [res.bboxes for res in results_list]
  420. mask_rois = bbox2roi(bboxes)
  421. if mask_rois.shape[0] == 0:
  422. results_list = empty_instances(
  423. batch_img_metas=batch_img_metas,
  424. device=mask_rois.device,
  425. task_type='mask',
  426. instance_results=results_list,
  427. mask_thr_binary=self.test_cfg.mask_thr_binary)
  428. return results_list
  429. num_mask_rois_per_img = [len(res) for res in results_list]
  430. mask_results = self._mask_forward(
  431. stage=-1,
  432. x=x,
  433. rois=mask_rois,
  434. semantic_feat=semantic_heat,
  435. training=False)
  436. # split batch mask prediction back to each image
  437. aug_masks = [[
  438. mask.sigmoid().detach()
  439. for mask in mask_preds.split(num_mask_rois_per_img, 0)
  440. ] for mask_preds in mask_results['mask_preds']]
  441. merged_masks = []
  442. for i in range(num_imgs):
  443. aug_mask = [mask[i] for mask in aug_masks]
  444. merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i])
  445. merged_masks.append(merged_mask)
  446. results_list = self.mask_head[-1].predict_by_feat(
  447. mask_preds=merged_masks,
  448. results_list=results_list,
  449. batch_img_metas=batch_img_metas,
  450. rcnn_test_cfg=self.test_cfg,
  451. rescale=rescale,
  452. activate_map=True)
  453. return results_list
  454. def forward(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
  455. batch_data_samples: SampleList) -> tuple:
  456. """Network forward process. Usually includes backbone, neck and head
  457. forward without any post-processing.
  458. Args:
  459. x (List[Tensor]): Multi-level features that may have different
  460. resolutions.
  461. rpn_results_list (list[:obj:`InstanceData`]): List of region
  462. proposals.
  463. batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
  464. the meta information of each image and corresponding
  465. annotations.
  466. Returns
  467. tuple: A tuple of features from ``bbox_head`` and ``mask_head``
  468. forward.
  469. """
  470. results = ()
  471. batch_img_metas = [
  472. data_samples.metainfo for data_samples in batch_data_samples
  473. ]
  474. num_imgs = len(batch_img_metas)
  475. if self.with_semantic:
  476. _, semantic_feat = self.semantic_head(x)
  477. else:
  478. semantic_feat = None
  479. proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
  480. num_proposals_per_img = tuple(len(p) for p in proposals)
  481. rois = bbox2roi(proposals)
  482. # bbox head
  483. if self.with_bbox:
  484. rois, cls_scores, bbox_preds = self._refine_roi(
  485. x=x,
  486. rois=rois,
  487. semantic_feat=semantic_feat,
  488. batch_img_metas=batch_img_metas,
  489. num_proposals_per_img=num_proposals_per_img)
  490. results = results + (cls_scores, bbox_preds)
  491. # mask head
  492. if self.with_mask:
  493. rois = torch.cat(rois)
  494. mask_results = self._mask_forward(
  495. stage=-1,
  496. x=x,
  497. rois=rois,
  498. semantic_feat=semantic_feat,
  499. training=False)
  500. aug_masks = [[
  501. mask.sigmoid().detach()
  502. for mask in mask_preds.split(num_proposals_per_img, 0)
  503. ] for mask_preds in mask_results['mask_preds']]
  504. merged_masks = []
  505. for i in range(num_imgs):
  506. aug_mask = [mask[i] for mask in aug_masks]
  507. merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i])
  508. merged_masks.append(merged_mask)
  509. results = results + (merged_masks, )
  510. return results