test_reppoints_head.py (5.5 KB, 143 lines)
  1. import unittest
  2. import torch
  3. from mmengine.config import ConfigDict
  4. from mmengine.structures import InstanceData
  5. from parameterized import parameterized
  6. from mmdet.models.dense_heads import RepPointsHead
  7. from mmdet.structures import DetDataSample
  8. class TestRepPointsHead(unittest.TestCase):
  9. @parameterized.expand(['moment', 'minmax', 'partial_minmax'])
  10. def test_head_loss(self, transform_method='moment'):
  11. cfg = ConfigDict(
  12. dict(
  13. num_classes=2,
  14. in_channels=32,
  15. point_feat_channels=10,
  16. num_points=9,
  17. gradient_mul=0.1,
  18. point_strides=[8, 16, 32, 64, 128],
  19. point_base_scale=4,
  20. loss_cls=dict(
  21. type='FocalLoss',
  22. use_sigmoid=True,
  23. gamma=2.0,
  24. alpha=0.25,
  25. loss_weight=1.0),
  26. loss_bbox_init=dict(
  27. type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
  28. loss_bbox_refine=dict(
  29. type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
  30. use_grid_points=False,
  31. center_init=True,
  32. transform_method=transform_method,
  33. moment_mul=0.01,
  34. init_cfg=dict(
  35. type='Normal',
  36. layer='Conv2d',
  37. std=0.01,
  38. override=dict(
  39. type='Normal',
  40. name='reppoints_cls_out',
  41. std=0.01,
  42. bias_prob=0.01)),
  43. train_cfg=dict(
  44. init=dict(
  45. assigner=dict(
  46. type='PointAssigner', scale=4, pos_num=1),
  47. allowed_border=-1,
  48. pos_weight=-1,
  49. debug=False),
  50. refine=dict(
  51. assigner=dict(
  52. type='MaxIoUAssigner',
  53. pos_iou_thr=0.5,
  54. neg_iou_thr=0.4,
  55. min_pos_iou=0,
  56. ignore_iof_thr=-1),
  57. allowed_border=-1,
  58. pos_weight=-1,
  59. debug=False)),
  60. test_cfg=dict(
  61. nms_pre=1000,
  62. min_bbox_size=0,
  63. score_thr=0.05,
  64. nms=dict(type='nms', iou_threshold=0.5),
  65. max_per_img=100)))
  66. reppoints_head = RepPointsHead(**cfg)
  67. s = 256
  68. img_metas = [{
  69. 'img_shape': (s, s),
  70. 'scale_factor': (1, 1),
  71. 'pad_shape': (s, s),
  72. 'batch_input_shape': (s, s)
  73. }]
  74. x = [
  75. torch.rand(1, 32, s // 2**(i + 2), s // 2**(i + 2))
  76. for i in range(5)
  77. ]
  78. # Test that empty ground truth encourages the network to
  79. # predict background
  80. gt_instances = InstanceData()
  81. gt_instances.bboxes = torch.empty((0, 4))
  82. gt_instances.labels = torch.LongTensor([])
  83. gt_bboxes_ignore = None
  84. reppoints_head.train()
  85. forward_outputs = reppoints_head.forward(x)
  86. empty_gt_losses = reppoints_head.loss_by_feat(*forward_outputs,
  87. [gt_instances],
  88. img_metas,
  89. gt_bboxes_ignore)
  90. # When there is no truth, the cls loss should be nonzero but there
  91. # should be no pts loss.
  92. for key, losses in empty_gt_losses.items():
  93. for loss in losses:
  94. if 'cls' in key:
  95. self.assertGreater(loss.item(), 0,
  96. 'cls loss should be non-zero')
  97. elif 'pts' in key:
  98. self.assertEqual(
  99. loss.item(), 0,
  100. 'there should be no reg loss when no ground true boxes'
  101. )
  102. # When truth is non-empty then both cls and pts loss should be nonzero
  103. # for random inputs
  104. gt_instances = InstanceData()
  105. gt_instances.bboxes = torch.Tensor(
  106. [[23.6667, 23.8757, 238.6326, 151.8874]])
  107. gt_instances.labels = torch.LongTensor([2])
  108. one_gt_losses = reppoints_head.loss_by_feat(*forward_outputs,
  109. [gt_instances], img_metas,
  110. gt_bboxes_ignore)
  111. # loss_cls should all be non-zero
  112. self.assertTrue(
  113. all([loss.item() > 0 for loss in one_gt_losses['loss_cls']]))
  114. # only one level loss_pts_init is non-zero
  115. cnt_non_zero = 0
  116. for loss in one_gt_losses['loss_pts_init']:
  117. if loss.item() != 0:
  118. cnt_non_zero += 1
  119. self.assertEqual(cnt_non_zero, 1)
  120. # only one level loss_pts_refine is non-zero
  121. cnt_non_zero = 0
  122. for loss in one_gt_losses['loss_pts_init']:
  123. if loss.item() != 0:
  124. cnt_non_zero += 1
  125. self.assertEqual(cnt_non_zero, 1)
  126. # test loss
  127. samples = DetDataSample()
  128. samples.set_metainfo(img_metas[0])
  129. samples.gt_instances = gt_instances
  130. reppoints_head.loss(x, [samples])
  131. # test only predict
  132. reppoints_head.eval()
  133. reppoints_head.predict(x, [samples], rescale=True)