test_vfnet_head.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import VFNetHead
  8. class TestVFNetHead(TestCase):
  9. def test_vfnet_head_loss(self):
  10. """Tests vfnet head loss when truth is empty and non-empty."""
  11. s = 256
  12. img_metas = [{
  13. 'img_shape': (s, s, 3),
  14. 'scale_factor': 1,
  15. 'pad_shape': (s, s, 3)
  16. }]
  17. train_cfg = Config(
  18. dict(
  19. assigner=dict(type='ATSSAssigner', topk=9),
  20. allowed_border=-1,
  21. pos_weight=-1,
  22. debug=False))
  23. # since VarFocal Loss is not supported on CPU
  24. vfnet_head = VFNetHead(
  25. num_classes=4,
  26. in_channels=1,
  27. train_cfg=train_cfg,
  28. loss_cls=dict(
  29. type='VarifocalLoss', use_sigmoid=True, loss_weight=1.0))
  30. feat = [
  31. torch.rand(1, 1, s // feat_size, s // feat_size)
  32. for feat_size in [4, 8, 16, 32, 64]
  33. ]
  34. cls_scores, bbox_preds, bbox_preds_refine = vfnet_head.forward(feat)
  35. # Test that empty ground truth encourages the network to predict
  36. # background
  37. gt_instances = InstanceData()
  38. gt_instances.bboxes = torch.empty((0, 4))
  39. gt_instances.labels = torch.LongTensor([])
  40. empty_gt_losses = vfnet_head.loss_by_feat(cls_scores, bbox_preds,
  41. bbox_preds_refine,
  42. [gt_instances], img_metas)
  43. # When there is no truth, the cls loss should be nonzero but there
  44. # should be no box loss.
  45. empty_cls_loss = empty_gt_losses['loss_cls']
  46. empty_box_loss = empty_gt_losses['loss_bbox']
  47. self.assertGreater(empty_cls_loss.item(), 0,
  48. 'cls loss should be non-zero')
  49. self.assertEqual(
  50. empty_box_loss.item(), 0,
  51. 'there should be no box loss when there are no true boxes')
  52. # When truth is non-empty then both cls and box loss should be nonzero
  53. # for random inputs
  54. gt_instances = InstanceData()
  55. gt_instances.bboxes = torch.Tensor(
  56. [[23.6667, 23.8757, 238.6326, 151.8874]])
  57. gt_instances.labels = torch.LongTensor([2])
  58. one_gt_losses = vfnet_head.loss_by_feat(cls_scores, bbox_preds,
  59. bbox_preds_refine,
  60. [gt_instances], img_metas)
  61. onegt_cls_loss = one_gt_losses['loss_cls']
  62. onegt_box_loss = one_gt_losses['loss_bbox']
  63. self.assertGreater(onegt_cls_loss.item(), 0,
  64. 'cls loss should be non-zero')
  65. self.assertGreater(onegt_box_loss.item(), 0,
  66. 'box loss should be non-zero')
  67. def test_vfnet_head_loss_without_atss(self):
  68. """Tests vfnet head loss when truth is empty and non-empty."""
  69. s = 256
  70. img_metas = [{
  71. 'img_shape': (s, s, 3),
  72. 'scale_factor': 1,
  73. 'pad_shape': (s, s, 3)
  74. }]
  75. train_cfg = Config(
  76. dict(
  77. assigner=dict(type='ATSSAssigner', topk=9),
  78. allowed_border=-1,
  79. pos_weight=-1,
  80. debug=False))
  81. # since VarFocal Loss is not supported on CPU
  82. vfnet_head = VFNetHead(
  83. num_classes=4,
  84. in_channels=1,
  85. train_cfg=train_cfg,
  86. use_atss=False,
  87. loss_cls=dict(
  88. type='VarifocalLoss', use_sigmoid=True, loss_weight=1.0))
  89. feat = [
  90. torch.rand(1, 1, s // feat_size, s // feat_size)
  91. for feat_size in [4, 8, 16, 32, 64]
  92. ]
  93. cls_scores, bbox_preds, bbox_preds_refine = vfnet_head.forward(feat)
  94. # Test that empty ground truth encourages the network to predict
  95. # background
  96. gt_instances = InstanceData()
  97. gt_instances.bboxes = torch.empty((0, 4))
  98. gt_instances.labels = torch.LongTensor([])
  99. empty_gt_losses = vfnet_head.loss_by_feat(cls_scores, bbox_preds,
  100. bbox_preds_refine,
  101. [gt_instances], img_metas)
  102. # When there is no truth, the cls loss should be nonzero but there
  103. # should be no box loss.
  104. empty_cls_loss = empty_gt_losses['loss_cls']
  105. empty_box_loss = empty_gt_losses['loss_bbox']
  106. self.assertGreater(empty_cls_loss.item(), 0,
  107. 'cls loss should be non-zero')
  108. self.assertEqual(
  109. empty_box_loss.item(), 0,
  110. 'there should be no box loss when there are no true boxes')
  111. # When truth is non-empty then both cls and box loss should be nonzero
  112. # for random inputs
  113. gt_instances = InstanceData()
  114. gt_instances.bboxes = torch.Tensor(
  115. [[23.6667, 23.8757, 238.6326, 151.8874]])
  116. gt_instances.labels = torch.LongTensor([2])
  117. one_gt_losses = vfnet_head.loss_by_feat(cls_scores, bbox_preds,
  118. bbox_preds_refine,
  119. [gt_instances], img_metas)
  120. onegt_cls_loss = one_gt_losses['loss_cls']
  121. onegt_box_loss = one_gt_losses['loss_bbox']
  122. self.assertGreater(onegt_cls_loss.item(), 0,
  123. 'cls loss should be non-zero')
  124. self.assertGreater(onegt_box_loss.item(), 0,
  125. 'box loss should be non-zero')