test_boxinst_head.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine import MessageHub
  6. from mmengine.config import ConfigDict
  7. from mmengine.structures import InstanceData
  8. from mmdet.models.dense_heads import BoxInstBboxHead, BoxInstMaskHead
  9. from mmdet.structures.mask import BitmapMasks
  10. def _rand_masks(num_items, bboxes, img_w, img_h):
  11. rng = np.random.RandomState(0)
  12. masks = np.zeros((num_items, img_h, img_w), dtype=np.float32)
  13. for i, bbox in enumerate(bboxes):
  14. bbox = bbox.astype(np.int32)
  15. mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
  16. 0.3).astype(np.int64)
  17. masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
  18. return BitmapMasks(masks, height=img_h, width=img_w)
  19. def _fake_mask_feature_head():
  20. mask_feature_head = ConfigDict(
  21. in_channels=1,
  22. feat_channels=1,
  23. start_level=0,
  24. end_level=2,
  25. out_channels=8,
  26. mask_stride=8,
  27. num_stacked_convs=4,
  28. norm_cfg=dict(type='BN', requires_grad=True))
  29. return mask_feature_head
  30. class TestBoxInstHead(TestCase):
  31. def test_boxinst_maskhead_loss(self):
  32. """Tests boxinst maskhead loss when truth is empty and non-empty."""
  33. s = 256
  34. img_metas = [{
  35. 'img_shape': (s, s, 3),
  36. 'pad_shape': (s, s, 3),
  37. 'scale_factor': 1,
  38. }]
  39. boxinst_bboxhead = BoxInstBboxHead(
  40. num_classes=4,
  41. in_channels=1,
  42. feat_channels=1,
  43. stacked_convs=1,
  44. norm_cfg=None)
  45. mask_feature_head = _fake_mask_feature_head()
  46. boxinst_maskhead = BoxInstMaskHead(
  47. mask_feature_head=mask_feature_head,
  48. loss_mask=dict(
  49. type='DiceLoss',
  50. use_sigmoid=True,
  51. activate=True,
  52. eps=5e-6,
  53. loss_weight=1.0))
  54. # Fcos head expects a multiple levels of features per image
  55. feats = []
  56. for i in range(len(boxinst_bboxhead.strides)):
  57. feats.append(
  58. torch.rand(1, 1, s // (2**(i + 3)), s // (2**(i + 3))))
  59. feats = tuple(feats)
  60. cls_scores, bbox_preds, centernesses, param_preds =\
  61. boxinst_bboxhead.forward(feats)
  62. # Test that empty ground truth encourages the network to
  63. # predict background
  64. gt_instances = InstanceData()
  65. gt_instances.bboxes = torch.empty((0, 4))
  66. gt_instances.labels = torch.LongTensor([])
  67. gt_instances.masks = _rand_masks(0, gt_instances.bboxes.numpy(), s, s)
  68. gt_instances.pairwise_masks = _rand_masks(
  69. 0, gt_instances.bboxes.numpy(), s // 4, s // 4).to_tensor(
  70. dtype=torch.float32,
  71. device='cpu').unsqueeze(1).repeat(1, 8, 1, 1)
  72. message_hub = MessageHub.get_instance('runtime_info')
  73. message_hub.update_info('iter', 1)
  74. _ = boxinst_bboxhead.loss_by_feat(cls_scores, bbox_preds, centernesses,
  75. param_preds, [gt_instances],
  76. img_metas)
  77. # When truth is empty then all mask loss
  78. # should be zero for random inputs
  79. positive_infos = boxinst_bboxhead.get_positive_infos()
  80. mask_outs = boxinst_maskhead.forward(feats, positive_infos)
  81. empty_gt_mask_losses = boxinst_maskhead.loss_by_feat(
  82. *mask_outs, [gt_instances], img_metas, positive_infos)
  83. loss_mask_project = empty_gt_mask_losses['loss_mask_project']
  84. loss_mask_pairwise = empty_gt_mask_losses['loss_mask_pairwise']
  85. self.assertEqual(loss_mask_project, 0,
  86. 'mask project loss should be zero')
  87. self.assertEqual(loss_mask_pairwise, 0,
  88. 'mask pairwise loss should be zero')
  89. # When truth is non-empty then all cls, box loss and centerness loss
  90. # should be nonzero for random inputs
  91. gt_instances = InstanceData()
  92. gt_instances.bboxes = torch.Tensor([[0.111, 0.222, 25.6667, 29.8757]])
  93. gt_instances.labels = torch.LongTensor([2])
  94. gt_instances.masks = _rand_masks(1, gt_instances.bboxes.numpy(), s, s)
  95. gt_instances.pairwise_masks = _rand_masks(
  96. 1, gt_instances.bboxes.numpy(), s // 4, s // 4).to_tensor(
  97. dtype=torch.float32,
  98. device='cpu').unsqueeze(1).repeat(1, 8, 1, 1)
  99. _ = boxinst_bboxhead.loss_by_feat(cls_scores, bbox_preds, centernesses,
  100. param_preds, [gt_instances],
  101. img_metas)
  102. positive_infos = boxinst_bboxhead.get_positive_infos()
  103. mask_outs = boxinst_maskhead.forward(feats, positive_infos)
  104. one_gt_mask_losses = boxinst_maskhead.loss_by_feat(
  105. *mask_outs, [gt_instances], img_metas, positive_infos)
  106. loss_mask_project = one_gt_mask_losses['loss_mask_project']
  107. loss_mask_pairwise = one_gt_mask_losses['loss_mask_pairwise']
  108. self.assertGreater(loss_mask_project, 0,
  109. 'mask project loss should be nonzero')
  110. self.assertGreater(loss_mask_pairwise, 0,
  111. 'mask pairwise loss should be nonzero')