test_solo_head.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine.config import ConfigDict
  6. from mmengine.structures import InstanceData
  7. from parameterized import parameterized
  8. from mmdet import * # noqa
  9. from mmdet.models.dense_heads import (DecoupledSOLOHead,
  10. DecoupledSOLOLightHead, SOLOHead)
  11. from mmdet.structures.mask import BitmapMasks
  12. def _rand_masks(num_items, bboxes, img_w, img_h):
  13. rng = np.random.RandomState(0)
  14. masks = np.zeros((num_items, img_h, img_w))
  15. for i, bbox in enumerate(bboxes):
  16. bbox = bbox.astype(np.int32)
  17. mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
  18. 0.3).astype(np.int64)
  19. masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
  20. return BitmapMasks(masks, height=img_h, width=img_w)
  21. class TestSOLOHead(TestCase):
  22. @parameterized.expand([(SOLOHead, ), (DecoupledSOLOHead, ),
  23. (DecoupledSOLOLightHead, )])
  24. def test_mask_head_loss(self, MaskHead):
  25. """Tests mask head loss when truth is empty and non-empty."""
  26. s = 256
  27. img_metas = [{
  28. 'img_shape': (s, s, 3),
  29. 'ori_shape': (s, s, 3),
  30. 'scale_factor': 1,
  31. 'batch_input_shape': (s, s, 3)
  32. }]
  33. mask_head = MaskHead(num_classes=4, in_channels=1)
  34. # SOLO head expects a multiple levels of features per image
  35. feats = []
  36. for i in range(len(mask_head.strides)):
  37. feats.append(
  38. torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2))))
  39. feats = tuple(feats)
  40. mask_outs = mask_head.forward(feats)
  41. # Test that empty ground truth encourages the network to
  42. # predict background
  43. gt_instances = InstanceData()
  44. gt_instances.bboxes = torch.empty(0, 4)
  45. gt_instances.labels = torch.LongTensor([])
  46. gt_instances.masks = _rand_masks(0, gt_instances.bboxes.numpy(), s, s)
  47. empty_gt_losses = mask_head.loss_by_feat(
  48. *mask_outs,
  49. batch_gt_instances=[gt_instances],
  50. batch_img_metas=img_metas)
  51. # When there is no truth, the cls loss should be nonzero but
  52. # there should be no box loss.
  53. empty_cls_loss = empty_gt_losses['loss_cls']
  54. empty_mask_loss = empty_gt_losses['loss_mask']
  55. self.assertGreater(empty_cls_loss.item(), 0,
  56. 'cls loss should be non-zero')
  57. self.assertEqual(
  58. empty_mask_loss.item(), 0,
  59. 'there should be no mask loss when there are no true mask')
  60. # When truth is non-empty then both cls and box loss
  61. # should be nonzero for random inputs
  62. gt_instances = InstanceData()
  63. gt_instances.bboxes = torch.Tensor(
  64. [[23.6667, 23.8757, 238.6326, 151.8874]])
  65. gt_instances.labels = torch.LongTensor([2])
  66. gt_instances.masks = _rand_masks(1, gt_instances.bboxes.numpy(), s, s)
  67. one_gt_losses = mask_head.loss_by_feat(
  68. *mask_outs,
  69. batch_gt_instances=[gt_instances],
  70. batch_img_metas=img_metas)
  71. onegt_cls_loss = one_gt_losses['loss_cls']
  72. onegt_mask_loss = one_gt_losses['loss_mask']
  73. self.assertGreater(onegt_cls_loss.item(), 0,
  74. 'cls loss should be non-zero')
  75. self.assertGreater(onegt_mask_loss.item(), 0,
  76. 'mask loss should be non-zero')
  77. def test_solo_head_empty_result(self):
  78. s = 256
  79. img_metas = {
  80. 'img_shape': (s, s, 3),
  81. 'ori_shape': (s, s, 3),
  82. 'scale_factor': 1,
  83. 'batch_input_shape': (s, s, 3)
  84. }
  85. mask_head = SOLOHead(num_classes=4, in_channels=1)
  86. cls_scores = torch.empty(0, 80)
  87. mask_preds = torch.empty(0, 16, 16)
  88. test_cfg = ConfigDict(
  89. score_thr=0.1,
  90. mask_thr=0.5,
  91. )
  92. results = mask_head._predict_by_feat_single(
  93. cls_scores=cls_scores,
  94. mask_preds=mask_preds,
  95. img_meta=img_metas,
  96. cfg=test_cfg)
  97. self.assertIsInstance(results, InstanceData)
  98. self.assertEqual(len(results), 0)
  99. def test_decoupled_solo_head_empty_result(self):
  100. s = 256
  101. img_metas = {
  102. 'img_shape': (s, s, 3),
  103. 'ori_shape': (s, s, 3),
  104. 'scale_factor': 1,
  105. 'batch_input_shape': (s, s, 3)
  106. }
  107. mask_head = DecoupledSOLOHead(num_classes=4, in_channels=1)
  108. cls_scores = torch.empty(0, 80)
  109. mask_preds_x = torch.empty(0, 16, 16)
  110. mask_preds_y = torch.empty(0, 16, 16)
  111. test_cfg = ConfigDict(
  112. score_thr=0.1,
  113. mask_thr=0.5,
  114. )
  115. results = mask_head._predict_by_feat_single(
  116. cls_scores=cls_scores,
  117. mask_preds_x=mask_preds_x,
  118. mask_preds_y=mask_preds_y,
  119. img_meta=img_metas,
  120. cfg=test_cfg)
  121. self.assertIsInstance(results, InstanceData)
  122. self.assertEqual(len(results), 0)