mask_point_head.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
  3. from typing import List, Tuple
  4. import torch
  5. import torch.nn as nn
  6. from mmcv.cnn import ConvModule
  7. from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
  8. from mmengine.model import BaseModule
  9. from mmengine.structures import InstanceData
  10. from torch import Tensor
  11. from mmdet.models.task_modules.samplers import SamplingResult
  12. from mmdet.models.utils import (get_uncertain_point_coords_with_randomness,
  13. get_uncertainty)
  14. from mmdet.registry import MODELS
  15. from mmdet.structures.bbox import bbox2roi
  16. from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType
  @MODELS.register_module()
  class MaskPointHead(BaseModule):
      """A mask point head used in PointRend.

      ``MaskPointHead`` uses a shared multi-layer perceptron (equivalent to
      ``nn.Conv1d`` with ``kernel_size=1``) to predict the logit of each input
      point. The fine-grained feature and the coarse feature are concatenated
      along the channel dimension for prediction.

      Args:
          num_classes (int): Number of classes for logits. Required (no
              default).
          num_fcs (int): Number of fc layers in the head. Defaults to 3.
          in_channels (int): Number of input channels. Defaults to 256.
          fc_channels (int): Number of fc channels. Defaults to 256.
          class_agnostic (bool): Whether to use class-agnostic
              classification. If so, the output channels of logits will be 1.
              Defaults to False.
          coarse_pred_each_layer (bool): Whether to concatenate the coarse
              feature with the output of each fc layer. Defaults to True.
          conv_cfg (:obj:`ConfigDict` or dict): Dictionary to construct
              and config conv layer. Defaults to dict(type='Conv1d').
          norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to
              construct and config norm layer. Defaults to None.
          act_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
              config the activation layer of each fc. Defaults to
              dict(type='ReLU').
          loss_point (:obj:`ConfigDict` or dict): Dictionary to construct and
              config loss layer of point head. Defaults to
              dict(type='CrossEntropyLoss', use_mask=True, loss_weight=1.0).
          init_cfg (:obj:`ConfigDict` or dict or list, optional):
              Initialization config dict. Defaults to a Normal init with
              std=0.001 that overrides ``fc_logits``.
      """

      def __init__(
          self,
          num_classes: int,
          num_fcs: int = 3,
          in_channels: int = 256,
          fc_channels: int = 256,
          class_agnostic: bool = False,
          coarse_pred_each_layer: bool = True,
          conv_cfg: ConfigType = dict(type='Conv1d'),
          norm_cfg: OptConfigType = None,
          act_cfg: ConfigType = dict(type='ReLU'),
          loss_point: ConfigType = dict(
              type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
          init_cfg: MultiConfig = dict(
              type='Normal', std=0.001, override=dict(name='fc_logits'))
      ) -> None:
          super().__init__(init_cfg=init_cfg)
          self.num_fcs = num_fcs
          self.in_channels = in_channels
          self.fc_channels = fc_channels
          self.num_classes = num_classes
          self.class_agnostic = class_agnostic
          self.coarse_pred_each_layer = coarse_pred_each_layer
          self.conv_cfg = conv_cfg
          self.norm_cfg = norm_cfg
          self.act_cfg = act_cfg
          self.loss_point = MODELS.build(loss_point)

          # The first fc consumes [fine-grained, coarse] features, matching
          # the concatenation done at the start of ``forward``.
          fc_in_channels = in_channels + num_classes
          self.fcs = nn.ModuleList()
          for _ in range(num_fcs):
              self.fcs.append(
                  ConvModule(
                      fc_in_channels,
                      fc_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      conv_cfg=conv_cfg,
                      norm_cfg=norm_cfg,
                      act_cfg=act_cfg))
              fc_in_channels = fc_channels
              # ``forward`` re-appends the coarse feature after every fc when
              # this flag is set, so the next layer sees extra channels.
              if self.coarse_pred_each_layer:
                  fc_in_channels += num_classes

          out_channels = 1 if self.class_agnostic else self.num_classes
          self.fc_logits = nn.Conv1d(
              fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0)

      def forward(self, fine_grained_feats: Tensor,
                  coarse_feats: Tensor) -> Tensor:
          """Classify each point based on fine-grained and coarse features.

          Args:
              fine_grained_feats (Tensor): Fine-grained feature sampled from
                  FPN, shape (num_rois, in_channels, num_points).
              coarse_feats (Tensor): Coarse feature sampled from
                  CoarseMaskHead, shape (num_rois, num_classes, num_points).

          Returns:
              Tensor: Point classification results, shape
              (num_rois, num_classes, num_points), or
              (num_rois, 1, num_points) when ``class_agnostic`` is True.
          """
          x = torch.cat([fine_grained_feats, coarse_feats], dim=1)
          for fc in self.fcs:
              x = fc(x)
              if self.coarse_pred_each_layer:
                  x = torch.cat((x, coarse_feats), dim=1)
          return self.fc_logits(x)

      def get_targets(self, rois: Tensor, rel_roi_points: Tensor,
                      sampling_results: List[SamplingResult],
                      batch_gt_instances: InstanceList,
                      cfg: ConfigType) -> Tensor:
          """Get training targets of MaskPointHead for all images.

          Args:
              rois (Tensor): Region of Interest, shape (num_rois, 5). Column 0
                  is the index of the image the RoI belongs to.
              rel_roi_points (Tensor): Points coordinates relative to RoI,
                  shape (num_rois, num_points, 2).
              sampling_results (list[:obj:`SamplingResult`]): Sampling
                  results after sampling and assignment, one per image.
              batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                  gt_instance. It usually includes ``bboxes``, ``labels``,
                  and ``masks`` attributes.
              cfg (:obj:`ConfigDict` or dict): Training cfg. Must provide
                  ``num_points``.

          Returns:
              Tensor: Point target, shape (num_rois, num_points). When no
              per-image targets are produced, an empty tensor of shape
              (0, num_points) is returned instead of an empty list.
          """
          point_targets = []
          for batch_ind, (res, gt_instances) in enumerate(
                  zip(sampling_results, batch_gt_instances)):
              # Select the RoIs (and their points) that belong to this image.
              inds = rois[:, 0] == batch_ind
              point_targets.append(
                  self._get_targets_single(rois[inds], rel_roi_points[inds],
                                           res.pos_assigned_gt_inds,
                                           gt_instances, cfg))

          if not point_targets:
              # ``torch.cat`` rejects an empty list; keep the Tensor contract.
              return rois.new_zeros((0, cfg.num_points))
          return torch.cat(point_targets)

      def _get_targets_single(self, rois: Tensor, rel_roi_points: Tensor,
                              pos_assigned_gt_inds: Tensor,
                              gt_instances: InstanceData,
                              cfg: ConfigType) -> Tensor:
          """Get training target of MaskPointHead for a single image.

          Args:
              rois (Tensor): Positive RoIs of this image, shape (num_pos, 5).
              rel_roi_points (Tensor): Points coordinates relative to each
                  RoI, shape (num_pos, num_points, 2).
              pos_assigned_gt_inds (Tensor): Index of the GT instance
                  assigned to each positive RoI.
              gt_instances (:obj:`InstanceData`): GT instances of this image;
                  its ``masks`` are sampled to produce the targets.
              cfg (:obj:`ConfigDict` or dict): Training cfg. Must provide
                  ``num_points``.

          Returns:
              Tensor: Point targets, shape (num_pos, num_points).
          """
          if rois.size(0) == 0:
              return rois.new_zeros((0, cfg.num_points))

          gt_masks_th = gt_instances.masks.to_tensor(
              rois.dtype, rois.device).index_select(0, pos_assigned_gt_inds)
          # Insert a singleton channel dim at axis 1 for ``point_sample``; it
          # is squeezed out again after sampling.
          gt_masks_th = gt_masks_th.unsqueeze(1)
          # Convert RoI-relative point coordinates to image-relative ones so
          # they can index the full-image GT masks.
          rel_img_points = rel_roi_point_to_rel_img_point(
              rois, rel_roi_points, gt_masks_th)
          return point_sample(gt_masks_th, rel_img_points).squeeze(1)

      def loss_and_target(self, point_pred: Tensor, rel_roi_points: Tensor,
                          sampling_results: List[SamplingResult],
                          batch_gt_instances: InstanceList,
                          cfg: ConfigType) -> dict:
          """Calculate loss for MaskPointHead.

          Args:
              point_pred (Tensor): Point prediction result, shape
                  (num_rois, num_classes, num_points).
              rel_roi_points (Tensor): Points coordinates relative to RoI,
                  shape (num_rois, num_points, 2).
              sampling_results (list[:obj:`SamplingResult`]): Sampling
                  results after sampling and assignment.
              batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                  gt_instance. It usually includes ``bboxes``, ``labels``,
                  and ``masks`` attributes.
              cfg (:obj:`ConfigDict` or dict): Training cfg.

          Returns:
              dict: A dictionary with ``loss_point`` and ``point_target``.
          """
          rois = bbox2roi([res.pos_bboxes for res in sampling_results])
          pos_labels = torch.cat(
              [res.pos_gt_labels for res in sampling_results])
          point_target = self.get_targets(rois, rel_roi_points,
                                          sampling_results,
                                          batch_gt_instances, cfg)
          # A class-agnostic head has a single output channel, so every RoI
          # is scored against label 0.
          labels = (torch.zeros_like(pos_labels)
                    if self.class_agnostic else pos_labels)
          loss_point = self.loss_point(point_pred, point_target, labels)
          return dict(loss_point=loss_point, point_target=point_target)

      def get_roi_rel_points_train(self, mask_preds: Tensor, labels: Tensor,
                                   cfg: ConfigType) -> Tensor:
          """Get ``num_points`` most uncertain points with random points
          during training.

          Sample points in [0, 1] x [0, 1] coordinate space based on their
          uncertainty. Sampling is delegated to
          ``get_uncertain_point_coords_with_randomness``, which rates each
          point's uncertainty from its logit prediction.

          Args:
              mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
                  mask_height, mask_width) for class-specific or
                  class-agnostic prediction.
              labels (Tensor): The ground truth class for each instance.
              cfg (:obj:`ConfigDict` or dict): Training config of point head.
                  Must provide ``num_points``, ``oversample_ratio`` and
                  ``importance_sample_ratio``.

          Returns:
              point_coords (Tensor): A tensor of shape
              (num_rois, num_points, 2) that contains the coordinates of the
              sampled points.
          """
          return get_uncertain_point_coords_with_randomness(
              mask_preds, labels, cfg.num_points, cfg.oversample_ratio,
              cfg.importance_sample_ratio)

      def get_roi_rel_points_test(self, mask_preds: Tensor, label_preds: Tensor,
                                  cfg: ConfigType) -> Tuple[Tensor, Tensor]:
          """Get ``num_points`` most uncertain points during test.

          Args:
              mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
                  mask_height, mask_width) for class-specific or
                  class-agnostic prediction.
              label_preds (Tensor): The predicted class for each instance.
              cfg (:obj:`ConfigDict` or dict): Testing config of point head.
                  Must provide ``subdivision_num_points``.

          Returns:
              tuple:
              - point_indices (Tensor): A tensor of shape
                (num_rois, num_points) that contains indices from
                [0, mask_height x mask_width) of the most uncertain points.
              - point_coords (Tensor): A tensor of shape
                (num_rois, num_points, 2) that contains [0, 1] x [0, 1]
                normalized coordinates of the most uncertain points from the
                [mask_height, mask_width] grid.
          """
          num_points = cfg.subdivision_num_points
          uncertainty_map = get_uncertainty(mask_preds, label_preds)
          num_rois, _, mask_height, mask_width = uncertainty_map.shape

          # During ONNX exporting, the type of each element of 'shape' is
          # `Tensor(float)`, while it is `float` during PyTorch inference.
          if isinstance(mask_height, torch.Tensor):
              h_step = 1.0 / mask_height.float()
              w_step = 1.0 / mask_width.float()
          else:
              h_step = 1.0 / mask_height
              w_step = 1.0 / mask_width

          # Cast to int to avoid a dynamic K for the TopK op in ONNX.
          mask_size = int(mask_height * mask_width)
          uncertainty_map = uncertainty_map.view(num_rois, mask_size)
          # Never request more points than the grid contains.
          num_points = min(mask_size, num_points)
          point_indices = uncertainty_map.topk(num_points, dim=1)[1]

          # Map flat indices back to the centres of their grid cells.
          xs = w_step / 2.0 + (point_indices % mask_width).float() * w_step
          ys = h_step / 2.0 + (point_indices // mask_width).float() * h_step
          point_coords = torch.stack([xs, ys], dim=2)
          return point_indices, point_coords