test_centernet_head.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.config import ConfigDict
  5. from mmengine.structures import InstanceData
  6. from mmdet.models.dense_heads import CenterNetHead
  7. class TestCenterNetHead(TestCase):
  8. def test_center_head_loss(self):
  9. """Tests center head loss when truth is empty and non-empty."""
  10. s = 256
  11. img_metas = [{'batch_input_shape': (s, s, 3)}]
  12. test_cfg = dict(topK=100, max_per_img=100)
  13. centernet_head = CenterNetHead(
  14. num_classes=4, in_channels=1, feat_channels=4, test_cfg=test_cfg)
  15. feat = [torch.rand(1, 1, s, s)]
  16. center_out, wh_out, offset_out = centernet_head.forward(feat)
  17. # Test that empty ground truth encourages the network to
  18. # predict background
  19. gt_instances = InstanceData()
  20. gt_instances.bboxes = torch.empty((0, 4))
  21. gt_instances.labels = torch.LongTensor([])
  22. empty_gt_losses = centernet_head.loss_by_feat(center_out, wh_out,
  23. offset_out,
  24. [gt_instances],
  25. img_metas)
  26. loss_center = empty_gt_losses['loss_center_heatmap']
  27. loss_wh = empty_gt_losses['loss_wh']
  28. loss_offset = empty_gt_losses['loss_offset']
  29. assert loss_center.item() > 0, 'loss_center should be non-zero'
  30. assert loss_wh.item() == 0, (
  31. 'there should be no loss_wh when there are no true boxes')
  32. assert loss_offset.item() == 0, (
  33. 'there should be no loss_offset when there are no true boxes')
  34. # When truth is non-empty then both cls and box loss
  35. # should be nonzero for random inputs
  36. gt_instances = InstanceData()
  37. gt_instances.bboxes = torch.Tensor(
  38. [[23.6667, 23.8757, 238.6326, 151.8874]])
  39. gt_instances.labels = torch.LongTensor([2])
  40. one_gt_losses = centernet_head.loss_by_feat(center_out, wh_out,
  41. offset_out, [gt_instances],
  42. img_metas)
  43. loss_center = one_gt_losses['loss_center_heatmap']
  44. loss_wh = one_gt_losses['loss_wh']
  45. loss_offset = one_gt_losses['loss_offset']
  46. assert loss_center.item() > 0, 'loss_center should be non-zero'
  47. assert loss_wh.item() > 0, 'loss_wh should be non-zero'
  48. assert loss_offset.item() > 0, 'loss_offset should be non-zero'
  49. def test_centernet_head_get_targets(self):
  50. """Tests center head generating and decoding the heatmap."""
  51. s = 256
  52. img_metas = [{
  53. 'img_shape': (s, s, 3),
  54. 'batch_input_shape': (s, s),
  55. }]
  56. test_cfg = ConfigDict(
  57. dict(topk=100, local_maximum_kernel=3, max_per_img=100))
  58. gt_bboxes = [
  59. torch.Tensor([[10, 20, 200, 240], [40, 50, 100, 200],
  60. [10, 20, 100, 240]])
  61. ]
  62. gt_labels = [torch.LongTensor([1, 1, 2])]
  63. centernet_head = CenterNetHead(
  64. num_classes=4, in_channels=1, feat_channels=4, test_cfg=test_cfg)
  65. self.feat_shape = (1, 1, s // 4, s // 4)
  66. targets, _ = centernet_head.get_targets(gt_bboxes, gt_labels,
  67. self.feat_shape,
  68. img_metas[0]['img_shape'])
  69. center_target = targets['center_heatmap_target']
  70. wh_target = targets['wh_target']
  71. offset_target = targets['offset_target']
  72. # make sure assign target right
  73. for i in range(len(gt_bboxes[0])):
  74. bbox, label = gt_bboxes[0][i] / 4, gt_labels[0][i]
  75. ctx, cty = sum(bbox[0::2]) / 2, sum(bbox[1::2]) / 2
  76. int_ctx, int_cty = int(sum(bbox[0::2]) / 2), int(
  77. sum(bbox[1::2]) / 2)
  78. w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
  79. x_off = ctx - int(ctx)
  80. y_off = cty - int(cty)
  81. assert center_target[0, label, int_cty, int_ctx] == 1
  82. assert wh_target[0, 0, int_cty, int_ctx] == w
  83. assert wh_target[0, 1, int_cty, int_ctx] == h
  84. assert offset_target[0, 0, int_cty, int_ctx] == x_off
  85. assert offset_target[0, 1, int_cty, int_ctx] == y_off
  86. def test_centernet_head_get_results(self):
  87. """Tests center head generating and decoding the heatmap."""
  88. s = 256
  89. img_metas = [{
  90. 'img_shape': (s, s, 3),
  91. 'batch_input_shape': (s, s),
  92. 'border': (0, 0, 0, 0),
  93. }]
  94. test_cfg = ConfigDict(
  95. dict(
  96. topk=100,
  97. local_maximum_kernel=3,
  98. max_per_img=100,
  99. nms=dict(type='nms', iou_threshold=0.5)))
  100. gt_bboxes = [
  101. torch.Tensor([[10, 20, 200, 240], [40, 50, 100, 200],
  102. [10, 20, 100, 240]])
  103. ]
  104. gt_labels = [torch.LongTensor([1, 1, 2])]
  105. centernet_head = CenterNetHead(
  106. num_classes=4, in_channels=1, feat_channels=4, test_cfg=test_cfg)
  107. self.feat_shape = (1, 1, s // 4, s // 4)
  108. targets, _ = centernet_head.get_targets(gt_bboxes, gt_labels,
  109. self.feat_shape,
  110. img_metas[0]['img_shape'])
  111. center_target = targets['center_heatmap_target']
  112. wh_target = targets['wh_target']
  113. offset_target = targets['offset_target']
  114. # make sure get_bboxes is right
  115. detections = centernet_head.predict_by_feat([center_target],
  116. [wh_target],
  117. [offset_target],
  118. img_metas,
  119. rescale=True,
  120. with_nms=False)
  121. pred_instances = detections[0]
  122. out_bboxes = pred_instances.bboxes[:3]
  123. out_clses = pred_instances.labels[:3]
  124. for bbox, cls in zip(out_bboxes, out_clses):
  125. flag = False
  126. for gt_bbox, gt_cls in zip(gt_bboxes[0], gt_labels[0]):
  127. if (bbox[:4] == gt_bbox[:4]).all():
  128. flag = True
  129. assert flag, 'get_bboxes is wrong'
  130. detections = centernet_head.predict_by_feat([center_target],
  131. [wh_target],
  132. [offset_target],
  133. img_metas,
  134. rescale=True,
  135. with_nms=True)
  136. pred_instances = detections[0]
  137. out_bboxes = pred_instances.bboxes[:3]
  138. out_clses = pred_instances.labels[:3]
  139. for bbox, cls in zip(out_bboxes, out_clses):
  140. flag = False
  141. for gt_bbox, gt_cls in zip(gt_bboxes[0], gt_labels[0]):
  142. if (bbox[:4] == gt_bbox[:4]).all():
  143. flag = True
  144. assert flag, 'get_bboxes is wrong'