test_tood_head.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config, MessageHub
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import TOODHead
  8. def _tood_head(anchor_type):
  9. """Set type of tood head."""
  10. train_cfg = Config(
  11. dict(
  12. initial_epoch=4,
  13. initial_assigner=dict(type='ATSSAssigner', topk=9),
  14. assigner=dict(type='TaskAlignedAssigner', topk=13),
  15. alpha=1,
  16. beta=6,
  17. allowed_border=-1,
  18. pos_weight=-1,
  19. debug=False))
  20. test_cfg = Config(
  21. dict(
  22. nms_pre=1000,
  23. min_bbox_size=0,
  24. score_thr=0.05,
  25. nms=dict(type='nms', iou_threshold=0.6),
  26. max_per_img=100))
  27. tood_head = TOODHead(
  28. num_classes=80,
  29. in_channels=1,
  30. stacked_convs=1,
  31. feat_channels=8, # the same as `la_down_rate` in TaskDecomposition
  32. norm_cfg=None,
  33. anchor_type=anchor_type,
  34. anchor_generator=dict(
  35. type='AnchorGenerator',
  36. ratios=[1.0],
  37. octave_base_scale=8,
  38. scales_per_octave=1,
  39. strides=[8, 16, 32, 64, 128]),
  40. bbox_coder=dict(
  41. type='DeltaXYWHBBoxCoder',
  42. target_means=[.0, .0, .0, .0],
  43. target_stds=[0.1, 0.1, 0.2, 0.2]),
  44. initial_loss_cls=dict(
  45. type='FocalLoss',
  46. use_sigmoid=True,
  47. activated=True, # use probability instead of logit as input
  48. gamma=2.0,
  49. alpha=0.25,
  50. loss_weight=1.0),
  51. loss_cls=dict(
  52. type='QualityFocalLoss',
  53. use_sigmoid=True,
  54. activated=True, # use probability instead of logit as input
  55. beta=2.0,
  56. loss_weight=1.0),
  57. loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
  58. train_cfg=train_cfg,
  59. test_cfg=test_cfg)
  60. return tood_head
  61. class TestTOODHead(TestCase):
  62. def test_tood_head_anchor_free_loss(self):
  63. """Tests tood head loss when truth is empty and non-empty."""
  64. s = 256
  65. img_metas = [{
  66. 'img_shape': (s, s, 3),
  67. 'pad_shape': (s, s, 3),
  68. 'scale_factor': 1
  69. }]
  70. tood_head = _tood_head('anchor_free')
  71. tood_head.init_weights()
  72. feat = [
  73. torch.rand(1, 1, s // feat_size, s // feat_size)
  74. for feat_size in [8, 16, 32, 64, 128]
  75. ]
  76. cls_scores, bbox_preds = tood_head(feat)
  77. message_hub = MessageHub.get_instance('runtime_info')
  78. message_hub.update_info('epoch', 0)
  79. # Test that empty ground truth encourages the network to predict
  80. # background
  81. gt_instances = InstanceData()
  82. gt_instances.bboxes = torch.empty((0, 4))
  83. gt_instances.labels = torch.LongTensor([])
  84. gt_bboxes_ignore = None
  85. empty_gt_losses = tood_head.loss_by_feat(cls_scores, bbox_preds,
  86. [gt_instances], img_metas,
  87. gt_bboxes_ignore)
  88. # When there is no truth, the cls loss should be nonzero but there
  89. # should be no box loss.
  90. empty_cls_loss = empty_gt_losses['loss_cls']
  91. empty_box_loss = empty_gt_losses['loss_bbox']
  92. self.assertGreater(
  93. sum(empty_cls_loss).item(), 0, 'cls loss should be non-zero')
  94. self.assertEqual(
  95. sum(empty_box_loss).item(), 0,
  96. 'there should be no box loss when there are no true boxes')
  97. # When truth is non-empty then both cls and box loss should be nonzero
  98. # for random inputs
  99. gt_instances = InstanceData()
  100. gt_instances.bboxes = torch.Tensor(
  101. [[23.6667, 23.8757, 238.6326, 151.8874]])
  102. gt_instances.labels = torch.LongTensor([2])
  103. gt_bboxes_ignore = None
  104. one_gt_losses = tood_head.loss_by_feat(cls_scores, bbox_preds,
  105. [gt_instances], img_metas,
  106. gt_bboxes_ignore)
  107. onegt_cls_loss = one_gt_losses['loss_cls']
  108. onegt_box_loss = one_gt_losses['loss_bbox']
  109. self.assertGreater(
  110. sum(onegt_cls_loss).item(), 0, 'cls loss should be non-zero')
  111. self.assertGreater(
  112. sum(onegt_box_loss).item(), 0, 'box loss should be non-zero')
  113. # Test that empty ground truth encourages the network to predict
  114. # background
  115. gt_instances = InstanceData()
  116. gt_instances.bboxes = torch.empty((0, 4))
  117. gt_instances.labels = torch.LongTensor([])
  118. gt_bboxes_ignore = None
  119. empty_gt_losses = tood_head.loss_by_feat(cls_scores, bbox_preds,
  120. [gt_instances], img_metas,
  121. gt_bboxes_ignore)
  122. # When there is no truth, the cls loss should be nonzero but there
  123. # should be no box loss.
  124. empty_cls_loss = empty_gt_losses['loss_cls']
  125. empty_box_loss = empty_gt_losses['loss_bbox']
  126. self.assertGreater(
  127. sum(empty_cls_loss).item(), 0, 'cls loss should be non-zero')
  128. self.assertEqual(
  129. sum(empty_box_loss).item(), 0,
  130. 'there should be no box loss when there are no true boxes')
  131. # When truth is non-empty then both cls and box loss should be nonzero
  132. # for random inputs
  133. gt_instances = InstanceData()
  134. gt_instances.bboxes = torch.Tensor(
  135. [[23.6667, 23.8757, 238.6326, 151.8874]])
  136. gt_instances.labels = torch.LongTensor([2])
  137. gt_bboxes_ignore = None
  138. one_gt_losses = tood_head.loss_by_feat(cls_scores, bbox_preds,
  139. [gt_instances], img_metas,
  140. gt_bboxes_ignore)
  141. onegt_cls_loss = one_gt_losses['loss_cls']
  142. onegt_box_loss = one_gt_losses['loss_bbox']
  143. self.assertGreater(
  144. sum(onegt_cls_loss).item(), 0, 'cls loss should be non-zero')
  145. self.assertGreater(
  146. sum(onegt_box_loss).item(), 0, 'box loss should be non-zero')
  147. def test_tood_head_anchor_based_loss(self):
  148. """Tests tood head loss when truth is empty and non-empty."""
  149. s = 256
  150. img_metas = [{
  151. 'img_shape': (s, s, 3),
  152. 'pad_shape': (s, s, 3),
  153. 'scale_factor': 1
  154. }]
  155. tood_head = _tood_head('anchor_based')
  156. tood_head.init_weights()
  157. feat = [
  158. torch.rand(1, 1, s // feat_size, s // feat_size)
  159. for feat_size in [8, 16, 32, 64, 128]
  160. ]
  161. cls_scores, bbox_preds = tood_head(feat)
  162. message_hub = MessageHub.get_instance('runtime_info')
  163. message_hub.update_info('epoch', 0)
  164. # Test that empty ground truth encourages the network to predict
  165. # background
  166. gt_instances = InstanceData()
  167. gt_instances.bboxes = torch.empty((0, 4))
  168. gt_instances.labels = torch.LongTensor([])
  169. gt_bboxes_ignore = None
  170. empty_gt_losses = tood_head.loss_by_feat(cls_scores, bbox_preds,
  171. [gt_instances], img_metas,
  172. gt_bboxes_ignore)
  173. # When there is no truth, the cls loss should be nonzero but there
  174. # should be no box loss.
  175. empty_cls_loss = empty_gt_losses['loss_cls']
  176. empty_box_loss = empty_gt_losses['loss_bbox']
  177. self.assertGreater(
  178. sum(empty_cls_loss).item(), 0, 'cls loss should be non-zero')
  179. self.assertEqual(
  180. sum(empty_box_loss).item(), 0,
  181. 'there should be no box loss when there are no true boxes')