point_rend_roi_head.py

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
from typing import List, Tuple

import torch
import torch.nn.functional as F
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList
from ..task_modules.samplers import SamplingResult
from ..utils import empty_instances
from .standard_roi_head import StandardRoIHead


@MODELS.register_module()
class PointRendRoIHead(StandardRoIHead):
    """`PointRend <https://arxiv.org/abs/1912.08193>`_."""

    def __init__(self, point_head: ConfigType, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert self.with_bbox and self.with_mask
        self.init_point_head(point_head)

    def init_point_head(self, point_head: ConfigType) -> None:
        """Initialize ``point_head``"""
        self.point_head = MODELS.build(point_head)
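    # PointRendRoIHead composes the standard two-stage RoI head with an extra
    # point head: the coarse FCN mask prediction is kept, and the point head
    # refines it only at a sparse set of uncertain point locations, both
    # during training (loss on sampled points) and testing (iterative
    # subdivision refinement).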

    def mask_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult], bbox_feats: Tensor,
                  batch_gt_instances: InstanceList) -> dict:
        """Run forward function and calculate loss for mask head and point
        head in training."""
        mask_results = super().mask_loss(
            x=x,
            sampling_results=sampling_results,
            bbox_feats=bbox_feats,
            batch_gt_instances=batch_gt_instances)

        mask_point_results = self._mask_point_loss(
            x=x,
            sampling_results=sampling_results,
            mask_preds=mask_results['mask_preds'],
            batch_gt_instances=batch_gt_instances)
        mask_results['loss_mask'].update(
            loss_point=mask_point_results['loss_point'])

        return mask_results
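    # The point-head branch below is supervised only at a sparse,
    # uncertainty-biased set of points inside each positive RoI; its loss is
    # merged into the ``loss_mask`` dict above under the key ``loss_point``,
    # alongside the standard FCN mask loss.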

    def _mask_point_loss(self, x: Tuple[Tensor],
                         sampling_results: List[SamplingResult],
                         mask_preds: Tensor,
                         batch_gt_instances: InstanceList) -> dict:
        """Run forward function and calculate loss for point head in
        training."""
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        rel_roi_points = self.point_head.get_roi_rel_points_train(
            mask_preds, pos_labels, cfg=self.train_cfg)
        rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, rois, rel_roi_points)
        coarse_point_feats = point_sample(mask_preds, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)

        loss_and_target = self.point_head.loss_and_target(
            point_pred=mask_point_pred,
            rel_roi_points=rel_roi_points,
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            cfg=self.train_cfg)

        return loss_and_target
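    # In both training and testing the point head consumes two inputs per
    # sampled point: fine-grained features bilinearly sampled from the
    # pyramid feature maps (see ``_get_fine_grained_point_feats``) and coarse
    # features sampled from the mask logits at the same location via
    # ``point_sample``.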

    def _mask_point_forward_test(self, x: Tuple[Tensor], rois: Tensor,
                                 label_preds: Tensor,
                                 mask_preds: Tensor) -> Tensor:
        """Mask refining process with point head in testing.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            rois (Tensor): shape (num_rois, 5).
            label_preds (Tensor): The predicted class for each RoI.
            mask_preds (Tensor): The predicted coarse masks of
                shape (num_rois, num_classes, small_size, small_size).

        Returns:
            Tensor: The refined masks of shape (num_rois, num_classes,
            large_size, large_size).
        """
        refined_mask_pred = mask_preds.clone()
        for subdivision_step in range(self.test_cfg.subdivision_steps):
            refined_mask_pred = F.interpolate(
                refined_mask_pred,
                scale_factor=self.test_cfg.scale_factor,
                mode='bilinear',
                align_corners=False)
            # If `subdivision_num_points` is larger than or equal to the
            # resolution of the next step, then we can skip this step
            num_rois, channels, mask_height, mask_width = \
                refined_mask_pred.shape
            if (self.test_cfg.subdivision_num_points >=
                    self.test_cfg.scale_factor**2 * mask_height * mask_width
                    and
                    subdivision_step < self.test_cfg.subdivision_steps - 1):
                continue
            point_indices, rel_roi_points = \
                self.point_head.get_roi_rel_points_test(
                    refined_mask_pred, label_preds, cfg=self.test_cfg)

            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x=x, rois=rois, rel_roi_points=rel_roi_points)
            coarse_point_feats = point_sample(mask_preds, rel_roi_points)
            mask_point_pred = self.point_head(fine_grained_point_feats,
                                              coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_mask_pred = refined_mask_pred.reshape(
                num_rois, channels, mask_height * mask_width)
            refined_mask_pred = refined_mask_pred.scatter_(
                2, point_indices, mask_point_pred)
            refined_mask_pred = refined_mask_pred.view(num_rois, channels,
                                                       mask_height, mask_width)

        return refined_mask_pred
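    # Subdivision inference sketch: with ``scale_factor=2``, a 7x7 coarse mask
    # would be refined on 14x14, 28x28, 56x56, ... grids over successive
    # steps. At each step only the ``subdivision_num_points`` most uncertain
    # grid cells are re-predicted by the point head, and their logits are
    # written back into the upsampled mask with ``scatter_`` on the flattened
    # (H * W) dimension. The concrete sizes above are illustrative; the actual
    # values come from ``test_cfg``, not from this module.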

    def _get_fine_grained_point_feats(self, x: Tuple[Tensor], rois: Tensor,
                                      rel_roi_points: Tensor) -> Tensor:
        """Sample fine grained feats from each level feature map and
        concatenate them together.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            rois (Tensor): shape (num_rois, 5).
            rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the [mask_height, mask_width] grid.

        Returns:
            Tensor: The fine grained features for each point,
            has shape (num_rois, feats_channels, num_points).
        """
        assert rois.shape[0] > 0, 'RoI is an empty tensor.'
        num_imgs = x[0].shape[0]
        fine_grained_feats = []
        for idx in range(self.mask_roi_extractor.num_inputs):
            feats = x[idx]
            spatial_scale = 1. / float(
                self.mask_roi_extractor.featmap_strides[idx])
            point_feats = []
            for batch_ind in range(num_imgs):
                # unravel batch dim
                feat = feats[batch_ind].unsqueeze(0)
                inds = (rois[:, 0].long() == batch_ind)
                if inds.any():
                    rel_img_points = rel_roi_point_to_rel_img_point(
                        rois=rois[inds],
                        rel_roi_points=rel_roi_points[inds],
                        img=feat.shape[2:],
                        spatial_scale=spatial_scale).unsqueeze(0)
                    point_feat = point_sample(feat, rel_img_points)
                    point_feat = point_feat.squeeze(0).transpose(0, 1)
                    point_feats.append(point_feat)
            fine_grained_feats.append(torch.cat(point_feats, dim=0))
        return torch.cat(fine_grained_feats, dim=1)
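    # Feature sampling: for every pyramid level handled by the mask RoI
    # extractor, the [0, 1] RoI-relative points are converted to
    # image-relative coordinates with ``rel_roi_point_to_rel_img_point`` and
    # bilinearly sampled with ``point_sample``; the per-level point features
    # are then concatenated along the channel dimension.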

    def predict_mask(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     results_list: InstanceList,
                     rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the mask head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            batch_img_metas (list[dict]): List of image information.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
            Each item usually contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        # don't need to consider aug_test.
        bboxes = [res.bboxes for res in results_list]
        mask_rois = bbox2roi(bboxes)
        if mask_rois.shape[0] == 0:
            results_list = empty_instances(
                batch_img_metas,
                mask_rois.device,
                task_type='mask',
                instance_results=results_list,
                mask_thr_binary=self.test_cfg.mask_thr_binary)
            return results_list

        mask_results = self._mask_forward(x, mask_rois)
        mask_preds = mask_results['mask_preds']
        # split batch mask prediction back to each image
        num_mask_rois_per_img = [len(res) for res in results_list]
        mask_preds = mask_preds.split(num_mask_rois_per_img, 0)

        # refine mask_preds
        mask_rois = mask_rois.split(num_mask_rois_per_img, 0)
        mask_preds_refined = []
        for i in range(len(batch_img_metas)):
            labels = results_list[i].labels
            x_i = [xx[[i]] for xx in x]
            mask_rois_i = mask_rois[i]
            mask_rois_i[:, 0] = 0
            mask_pred_i = self._mask_point_forward_test(
                x_i, mask_rois_i, labels, mask_preds[i])
            mask_preds_refined.append(mask_pred_i)

        # TODO: Handle the case where rescale is false
        results_list = self.mask_head.predict_by_feat(
            mask_preds=mask_preds_refined,
            results_list=results_list,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=self.test_cfg,
            rescale=rescale)
        return results_list
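

# Illustrative usage sketch (not part of the original module): a PointRend
# style ``roi_head`` / ``test_cfg`` as they might appear in an MMDetection
# config file. Field names and values below are assumptions drawn from
# typical PointRend configs and may not match the configs shipped with this
# repository; treat them as a starting point, not a definitive reference.
#
#   roi_head = dict(
#       type='PointRendRoIHead',
#       point_head=dict(
#           type='MaskPointHead',
#           num_fcs=3,
#           in_channels=256,
#           fc_channels=256,
#           num_classes=80,
#           coarse_pred_each_layer=True,
#           loss_point=dict(
#               type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
#       ...)
#   test_cfg = dict(
#       rcnn=dict(
#           ...,
#           subdivision_steps=5,
#           subdivision_num_points=28 * 28,
#           scale_factor=2))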