# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmengine import Config
from mmengine.structures import InstanceData

from mmdet import *  # noqa
from mmdet.models.dense_heads import GFLHead, LDHead


class TestLDHead(TestCase):
    """Unit tests for ``LDHead`` (Localization Distillation head)."""

    def test_ld_head_loss(self):
        """Tests ld head loss when truth is empty and non-empty."""
        s = 256
        img_metas = [{
            'img_shape': (s, s, 3),
            'pad_shape': (s, s, 3),
            'scale_factor': 1
        }]
        train_cfg = Config(
            dict(
                assigner=dict(type='ATSSAssigner', topk=9, ignore_iof_thr=0.1),
                allowed_border=-1,
                pos_weight=-1,
                debug=False))
        # Student head: GFL-style head with an extra localization
        # distillation (KL-divergence) loss term.
        ld_head = LDHead(
            num_classes=4,
            in_channels=1,
            train_cfg=train_cfg,
            loss_ld=dict(
                type='KnowledgeDistillationKLDivLoss', loss_weight=1.0),
            loss_cls=dict(
                type='QualityFocalLoss',
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0),
            loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
            anchor_generator=dict(
                type='AnchorGenerator',
                ratios=[1.0],
                octave_base_scale=8,
                scales_per_octave=1,
                strides=[8, 16, 32, 64, 128]))
        # Teacher head whose bbox predictions act as the (random) soft
        # targets for the distillation loss.
        teacher_model = GFLHead(
            num_classes=4,
            in_channels=1,
            train_cfg=train_cfg,
            loss_cls=dict(
                type='QualityFocalLoss',
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0),
            loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
            anchor_generator=dict(
                type='AnchorGenerator',
                ratios=[1.0],
                octave_base_scale=8,
                scales_per_octave=1,
                strides=[8, 16, 32, 64, 128]))
        feat = [
            torch.rand(1, 1, s // feat_size, s // feat_size)
            for feat_size in [4, 8, 16, 32, 64]
        ]
        cls_scores, bbox_preds = ld_head.forward(feat)
        rand_soft_target = teacher_model.forward(feat)[1]

        # Test that empty ground truth encourages the network to predict
        # background
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.empty((0, 4))
        gt_instances.labels = torch.LongTensor([])
        batch_gt_instances_ignore = None

        empty_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
                                               [gt_instances], img_metas,
                                               rand_soft_target,
                                               batch_gt_instances_ignore)
        # When there is no truth, the cls loss should be nonzero, ld loss
        # should be non-negative but there should be no box loss.
        empty_cls_loss = sum(empty_gt_losses['loss_cls'])
        empty_box_loss = sum(empty_gt_losses['loss_bbox'])
        empty_ld_loss = sum(empty_gt_losses['loss_ld'])
        self.assertGreater(empty_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertEqual(
            empty_box_loss.item(), 0,
            'there should be no box loss when there are no true boxes')
        self.assertGreaterEqual(empty_ld_loss.item(), 0,
                                'ld loss should be non-negative')

        # When truth is non-empty then both cls and box loss should be nonzero
        # for random inputs
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        gt_instances.labels = torch.LongTensor([2])
        batch_gt_instances_ignore = None

        one_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
                                             [gt_instances], img_metas,
                                             rand_soft_target,
                                             batch_gt_instances_ignore)
        onegt_cls_loss = sum(one_gt_losses['loss_cls'])
        onegt_box_loss = sum(one_gt_losses['loss_bbox'])
        self.assertGreater(onegt_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertGreater(onegt_box_loss.item(), 0,
                           'box loss should be non-zero')

        batch_gt_instances_ignore = gt_instances
        # When truth is non-empty but ignored then the cls loss should be
        # nonzero, but there should be no box loss.
        ignore_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
                                                [gt_instances], img_metas,
                                                rand_soft_target,
                                                batch_gt_instances_ignore)
        ignore_cls_loss = sum(ignore_gt_losses['loss_cls'])
        ignore_box_loss = sum(ignore_gt_losses['loss_bbox'])
        self.assertGreater(ignore_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertEqual(ignore_box_loss.item(), 0,
                         'gt bbox ignored loss should be zero')

        # When truth is non-empty and not ignored then both cls and box loss
        # should be nonzero for random inputs
        batch_gt_instances_ignore = InstanceData()
        batch_gt_instances_ignore.bboxes = torch.randn(1, 4)

        not_ignore_gt_losses = ld_head.loss_by_feat(cls_scores, bbox_preds,
                                                    [gt_instances], img_metas,
                                                    rand_soft_target,
                                                    batch_gt_instances_ignore)
        not_ignore_cls_loss = sum(not_ignore_gt_losses['loss_cls'])
        not_ignore_box_loss = sum(not_ignore_gt_losses['loss_bbox'])
        self.assertGreater(not_ignore_cls_loss.item(), 0,
                           'cls loss should be non-zero')
        self.assertGreaterEqual(not_ignore_box_loss.item(), 0,
                                'gt bbox not ignored loss should be non-zero')