mask_target.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from torch.nn.modules.utils import _pair
  5. def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
  6. cfg):
  7. """Compute mask target for positive proposals in multiple images.
  8. Args:
  9. pos_proposals_list (list[Tensor]): Positive proposals in multiple
  10. images, each has shape (num_pos, 4).
  11. pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each
  12. positive proposals, each has shape (num_pos,).
  13. gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
  14. each image.
  15. cfg (dict): Config dict that specifies the mask size.
  16. Returns:
  17. Tensor: Mask target of each image, has shape (num_pos, w, h).
  18. Example:
  19. >>> from mmengine.config import Config
  20. >>> import mmdet
  21. >>> from mmdet.data_elements.mask import BitmapMasks
  22. >>> from mmdet.data_elements.mask.mask_target import *
  23. >>> H, W = 17, 18
  24. >>> cfg = Config({'mask_size': (13, 14)})
  25. >>> rng = np.random.RandomState(0)
  26. >>> # Positive proposals (tl_x, tl_y, br_x, br_y) for each image
  27. >>> pos_proposals_list = [
  28. >>> torch.Tensor([
  29. >>> [ 7.2425, 5.5929, 13.9414, 14.9541],
  30. >>> [ 7.3241, 3.6170, 16.3850, 15.3102],
  31. >>> ]),
  32. >>> torch.Tensor([
  33. >>> [ 4.8448, 6.4010, 7.0314, 9.7681],
  34. >>> [ 5.9790, 2.6989, 7.4416, 4.8580],
  35. >>> [ 0.0000, 0.0000, 0.1398, 9.8232],
  36. >>> ]),
  37. >>> ]
  38. >>> # Corresponding class index for each proposal for each image
  39. >>> pos_assigned_gt_inds_list = [
  40. >>> torch.LongTensor([7, 0]),
  41. >>> torch.LongTensor([5, 4, 1]),
  42. >>> ]
  43. >>> # Ground truth mask for each true object for each image
  44. >>> gt_masks_list = [
  45. >>> BitmapMasks(rng.rand(8, H, W), height=H, width=W),
  46. >>> BitmapMasks(rng.rand(6, H, W), height=H, width=W),
  47. >>> ]
  48. >>> mask_targets = mask_target(
  49. >>> pos_proposals_list, pos_assigned_gt_inds_list,
  50. >>> gt_masks_list, cfg)
  51. >>> assert mask_targets.shape == (5,) + cfg['mask_size']
  52. """
  53. cfg_list = [cfg for _ in range(len(pos_proposals_list))]
  54. mask_targets = map(mask_target_single, pos_proposals_list,
  55. pos_assigned_gt_inds_list, gt_masks_list, cfg_list)
  56. mask_targets = list(mask_targets)
  57. if len(mask_targets) > 0:
  58. mask_targets = torch.cat(mask_targets)
  59. return mask_targets
  60. def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
  61. """Compute mask target for each positive proposal in the image.
  62. Args:
  63. pos_proposals (Tensor): Positive proposals.
  64. pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals.
  65. gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap
  66. or Polygon.
  67. cfg (dict): Config dict that indicate the mask size.
  68. Returns:
  69. Tensor: Mask target of each positive proposals in the image.
  70. Example:
  71. >>> from mmengine.config import Config
  72. >>> import mmdet
  73. >>> from mmdet.data_elements.mask import BitmapMasks
  74. >>> from mmdet.data_elements.mask.mask_target import * # NOQA
  75. >>> H, W = 32, 32
  76. >>> cfg = Config({'mask_size': (7, 11)})
  77. >>> rng = np.random.RandomState(0)
  78. >>> # Masks for each ground truth box (relative to the image)
  79. >>> gt_masks_data = rng.rand(3, H, W)
  80. >>> gt_masks = BitmapMasks(gt_masks_data, height=H, width=W)
  81. >>> # Predicted positive boxes in one image
  82. >>> pos_proposals = torch.FloatTensor([
  83. >>> [ 16.2, 5.5, 19.9, 20.9],
  84. >>> [ 17.3, 13.6, 19.3, 19.3],
  85. >>> [ 14.8, 16.4, 17.0, 23.7],
  86. >>> [ 0.0, 0.0, 16.0, 16.0],
  87. >>> [ 4.0, 0.0, 20.0, 16.0],
  88. >>> ])
  89. >>> # For each predicted proposal, its assignment to a gt mask
  90. >>> pos_assigned_gt_inds = torch.LongTensor([0, 1, 2, 1, 1])
  91. >>> mask_targets = mask_target_single(
  92. >>> pos_proposals, pos_assigned_gt_inds, gt_masks, cfg)
  93. >>> assert mask_targets.shape == (5,) + cfg['mask_size']
  94. """
  95. device = pos_proposals.device
  96. mask_size = _pair(cfg.mask_size)
  97. binarize = not cfg.get('soft_mask_target', False)
  98. num_pos = pos_proposals.size(0)
  99. if num_pos > 0:
  100. proposals_np = pos_proposals.cpu().numpy()
  101. maxh, maxw = gt_masks.height, gt_masks.width
  102. proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw)
  103. proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh)
  104. pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
  105. mask_targets = gt_masks.crop_and_resize(
  106. proposals_np,
  107. mask_size,
  108. device=device,
  109. inds=pos_assigned_gt_inds,
  110. binarize=binarize).to_ndarray()
  111. mask_targets = torch.from_numpy(mask_targets).float().to(device)
  112. else:
  113. mask_targets = pos_proposals.new_zeros((0, ) + mask_size)
  114. return mask_targets