boxinst_head.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List
  3. import torch
  4. import torch.nn.functional as F
  5. from mmengine import MessageHub
  6. from mmengine.structures import InstanceData
  7. from torch import Tensor
  8. from mmdet.registry import MODELS
  9. from mmdet.utils import InstanceList
  10. from ..utils.misc import unfold_wo_center
  11. from .condinst_head import CondInstBboxHead, CondInstMaskHead
  12. @MODELS.register_module()
  13. class BoxInstBboxHead(CondInstBboxHead):
  14. """BoxInst box head used in https://arxiv.org/abs/2012.02310."""
  15. def __init__(self, *args, **kwargs) -> None:
  16. super().__init__(*args, **kwargs)
  17. @MODELS.register_module()
  18. class BoxInstMaskHead(CondInstMaskHead):
  19. """BoxInst mask head used in https://arxiv.org/abs/2012.02310.
  20. This head outputs the mask for BoxInst.
  21. Args:
  22. pairwise_size (dict): The size of neighborhood for each pixel.
  23. Defaults to 3.
  24. pairwise_dilation (int): The dilation of neighborhood for each pixel.
  25. Defaults to 2.
  26. warmup_iters (int): Warmup iterations for pair-wise loss.
  27. Defaults to 10000.
  28. """
  29. def __init__(self,
  30. *arg,
  31. pairwise_size: int = 3,
  32. pairwise_dilation: int = 2,
  33. warmup_iters: int = 10000,
  34. **kwargs) -> None:
  35. self.pairwise_size = pairwise_size
  36. self.pairwise_dilation = pairwise_dilation
  37. self.warmup_iters = warmup_iters
  38. super().__init__(*arg, **kwargs)
  39. def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor:
  40. """Compute the pairwise affinity for each pixel."""
  41. log_fg_prob = F.logsigmoid(mask_logits).unsqueeze(1)
  42. log_bg_prob = F.logsigmoid(-mask_logits).unsqueeze(1)
  43. log_fg_prob_unfold = unfold_wo_center(
  44. log_fg_prob,
  45. kernel_size=self.pairwise_size,
  46. dilation=self.pairwise_dilation)
  47. log_bg_prob_unfold = unfold_wo_center(
  48. log_bg_prob,
  49. kernel_size=self.pairwise_size,
  50. dilation=self.pairwise_dilation)
  51. # the probability of making the same prediction:
  52. # p_i * p_j + (1 - p_i) * (1 - p_j)
  53. # we compute the the probability in log space
  54. # to avoid numerical instability
  55. log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold
  56. log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold
  57. # TODO: Figure out the difference between it and directly sum
  58. max_ = torch.max(log_same_fg_prob, log_same_bg_prob)
  59. log_same_prob = torch.log(
  60. torch.exp(log_same_fg_prob - max_) +
  61. torch.exp(log_same_bg_prob - max_)) + max_
  62. return -log_same_prob[:, 0]
  63. def loss_by_feat(self, mask_preds: List[Tensor],
  64. batch_gt_instances: InstanceList,
  65. batch_img_metas: List[dict], positive_infos: InstanceList,
  66. **kwargs) -> dict:
  67. """Calculate the loss based on the features extracted by the mask head.
  68. Args:
  69. mask_preds (list[Tensor]): List of predicted masks, each has
  70. shape (num_classes, H, W).
  71. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  72. gt_instance. It usually includes ``bboxes``, ``masks``,
  73. and ``labels`` attributes.
  74. batch_img_metas (list[dict]): Meta information of multiple images.
  75. positive_infos (List[:obj:``InstanceData``]): Information of
  76. positive samples of each image that are assigned in detection
  77. head.
  78. Returns:
  79. dict[str, Tensor]: A dictionary of loss components.
  80. """
  81. assert positive_infos is not None, \
  82. 'positive_infos should not be None in `BoxInstMaskHead`'
  83. losses = dict()
  84. loss_mask_project = 0.
  85. loss_mask_pairwise = 0.
  86. num_imgs = len(mask_preds)
  87. total_pos = 0.
  88. avg_fatcor = 0.
  89. for idx in range(num_imgs):
  90. (mask_pred, pos_mask_targets, pos_pairwise_masks, num_pos) = \
  91. self._get_targets_single(
  92. mask_preds[idx], batch_gt_instances[idx],
  93. positive_infos[idx])
  94. # mask loss
  95. total_pos += num_pos
  96. if num_pos == 0 or pos_mask_targets is None:
  97. loss_project = mask_pred.new_zeros(1).mean()
  98. loss_pairwise = mask_pred.new_zeros(1).mean()
  99. avg_fatcor += 0.
  100. else:
  101. # compute the project term
  102. loss_project_x = self.loss_mask(
  103. mask_pred.max(dim=1, keepdim=True)[0],
  104. pos_mask_targets.max(dim=1, keepdim=True)[0],
  105. reduction_override='none').sum()
  106. loss_project_y = self.loss_mask(
  107. mask_pred.max(dim=2, keepdim=True)[0],
  108. pos_mask_targets.max(dim=2, keepdim=True)[0],
  109. reduction_override='none').sum()
  110. loss_project = loss_project_x + loss_project_y
  111. # compute the pairwise term
  112. pairwise_affinity = self.get_pairwise_affinity(mask_pred)
  113. avg_fatcor += pos_pairwise_masks.sum().clamp(min=1.0)
  114. loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum()
  115. loss_mask_project += loss_project
  116. loss_mask_pairwise += loss_pairwise
  117. if total_pos == 0:
  118. total_pos += 1 # avoid nan
  119. if avg_fatcor == 0:
  120. avg_fatcor += 1 # avoid nan
  121. loss_mask_project = loss_mask_project / total_pos
  122. loss_mask_pairwise = loss_mask_pairwise / avg_fatcor
  123. message_hub = MessageHub.get_current_instance()
  124. iter = message_hub.get_info('iter')
  125. warmup_factor = min(iter / float(self.warmup_iters), 1.0)
  126. loss_mask_pairwise *= warmup_factor
  127. losses.update(
  128. loss_mask_project=loss_mask_project,
  129. loss_mask_pairwise=loss_mask_pairwise)
  130. return losses
  131. def _get_targets_single(self, mask_preds: Tensor,
  132. gt_instances: InstanceData,
  133. positive_info: InstanceData):
  134. """Compute targets for predictions of single image.
  135. Args:
  136. mask_preds (Tensor): Predicted prototypes with shape
  137. (num_classes, H, W).
  138. gt_instances (:obj:`InstanceData`): Ground truth of instance
  139. annotations. It should includes ``bboxes``, ``labels``,
  140. and ``masks`` attributes.
  141. positive_info (:obj:`InstanceData`): Information of positive
  142. samples that are assigned in detection head. It usually
  143. contains following keys.
  144. - pos_assigned_gt_inds (Tensor): Assigner GT indexes of
  145. positive proposals, has shape (num_pos, )
  146. - pos_inds (Tensor): Positive index of image, has
  147. shape (num_pos, ).
  148. - param_pred (Tensor): Positive param preditions
  149. with shape (num_pos, num_params).
  150. Returns:
  151. tuple: Usually returns a tuple containing learning targets.
  152. - mask_preds (Tensor): Positive predicted mask with shape
  153. (num_pos, mask_h, mask_w).
  154. - pos_mask_targets (Tensor): Positive mask targets with shape
  155. (num_pos, mask_h, mask_w).
  156. - pos_pairwise_masks (Tensor): Positive pairwise masks with
  157. shape: (num_pos, num_neighborhood, mask_h, mask_w).
  158. - num_pos (int): Positive numbers.
  159. """
  160. gt_bboxes = gt_instances.bboxes
  161. device = gt_bboxes.device
  162. # Note that gt_masks are generated by full box
  163. # from BoxInstDataPreprocessor
  164. gt_masks = gt_instances.masks.to_tensor(
  165. dtype=torch.bool, device=device).float()
  166. # Note that pairwise_masks are generated by image color similarity
  167. # from BoxInstDataPreprocessor
  168. pairwise_masks = gt_instances.pairwise_masks
  169. pairwise_masks = pairwise_masks.to(device=device)
  170. # process with mask targets
  171. pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
  172. scores = positive_info.get('scores')
  173. centernesses = positive_info.get('centernesses')
  174. num_pos = pos_assigned_gt_inds.size(0)
  175. if gt_masks.size(0) == 0 or num_pos == 0:
  176. return mask_preds, None, None, 0
  177. # Since we're producing (near) full image masks,
  178. # it'd take too much vram to backprop on every single mask.
  179. # Thus we select only a subset.
  180. if (self.max_masks_to_train != -1) and \
  181. (num_pos > self.max_masks_to_train):
  182. perm = torch.randperm(num_pos)
  183. select = perm[:self.max_masks_to_train]
  184. mask_preds = mask_preds[select]
  185. pos_assigned_gt_inds = pos_assigned_gt_inds[select]
  186. num_pos = self.max_masks_to_train
  187. elif self.topk_masks_per_img != -1:
  188. unique_gt_inds = pos_assigned_gt_inds.unique()
  189. num_inst_per_gt = max(
  190. int(self.topk_masks_per_img / len(unique_gt_inds)), 1)
  191. keep_mask_preds = []
  192. keep_pos_assigned_gt_inds = []
  193. for gt_ind in unique_gt_inds:
  194. per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
  195. mask_preds_per_inst = mask_preds[per_inst_pos_inds]
  196. gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
  197. if sum(per_inst_pos_inds) > num_inst_per_gt:
  198. per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
  199. dim=1)[0]
  200. per_inst_centerness = centernesses[
  201. per_inst_pos_inds].sigmoid().reshape(-1, )
  202. select = (per_inst_scores * per_inst_centerness).topk(
  203. k=num_inst_per_gt, dim=0)[1]
  204. mask_preds_per_inst = mask_preds_per_inst[select]
  205. gt_inds_per_inst = gt_inds_per_inst[select]
  206. keep_mask_preds.append(mask_preds_per_inst)
  207. keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
  208. mask_preds = torch.cat(keep_mask_preds)
  209. pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
  210. num_pos = pos_assigned_gt_inds.size(0)
  211. # Follow the origin implement
  212. start = int(self.mask_out_stride // 2)
  213. gt_masks = gt_masks[:, start::self.mask_out_stride,
  214. start::self.mask_out_stride]
  215. gt_masks = gt_masks.gt(0.5).float()
  216. pos_mask_targets = gt_masks[pos_assigned_gt_inds]
  217. pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds]
  218. pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1)
  219. return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos)