test_paa_head.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import numpy as np
  4. import torch
  5. from mmengine import Config
  6. from mmengine.structures import InstanceData
  7. from mmdet import * # noqa
  8. from mmdet.models.dense_heads import PAAHead, paa_head
  9. from mmdet.models.utils import levels_to_images
  10. class TestPAAHead(TestCase):
  11. def test_paa_head_loss(self):
  12. """Tests paa head loss when truth is empty and non-empty."""
  13. class mock_skm:
  14. def GaussianMixture(self, *args, **kwargs):
  15. return self
  16. def fit(self, loss):
  17. pass
  18. def predict(self, loss):
  19. components = np.zeros_like(loss, dtype=np.long)
  20. return components.reshape(-1)
  21. def score_samples(self, loss):
  22. scores = np.random.random(len(loss))
  23. return scores
  24. paa_head.skm = mock_skm()
  25. s = 256
  26. img_metas = [{
  27. 'img_shape': (s, s, 3),
  28. 'pad_shape': (s, s, 3),
  29. 'scale_factor': 1,
  30. }]
  31. train_cfg = Config(
  32. dict(
  33. assigner=dict(
  34. type='MaxIoUAssigner',
  35. pos_iou_thr=0.1,
  36. neg_iou_thr=0.1,
  37. min_pos_iou=0,
  38. ignore_iof_thr=-1),
  39. allowed_border=-1,
  40. pos_weight=-1,
  41. debug=False))
  42. # since Focal Loss is not supported on CPU
  43. paa = PAAHead(
  44. num_classes=4,
  45. in_channels=1,
  46. train_cfg=train_cfg,
  47. anchor_generator=dict(
  48. type='AnchorGenerator',
  49. ratios=[1.0],
  50. octave_base_scale=8,
  51. scales_per_octave=1,
  52. strides=[8, 16, 32, 64, 128]),
  53. loss_cls=dict(
  54. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
  55. loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
  56. loss_centerness=dict(
  57. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
  58. feat = [
  59. torch.rand(1, 1, s // feat_size, s // feat_size)
  60. for feat_size in [4, 8, 16, 32, 64]
  61. ]
  62. paa.init_weights()
  63. cls_scores, bbox_preds, iou_preds = paa(feat)
  64. # Test that empty ground truth encourages the network to predict
  65. # background
  66. gt_instances = InstanceData()
  67. gt_instances.bboxes = torch.empty((0, 4))
  68. gt_instances.labels = torch.LongTensor([])
  69. empty_gt_losses = paa.loss_by_feat(cls_scores, bbox_preds, iou_preds,
  70. [gt_instances], img_metas)
  71. # When there is no truth, the cls loss should be nonzero but there
  72. # should be no box loss.
  73. empty_cls_loss = empty_gt_losses['loss_cls']
  74. empty_box_loss = empty_gt_losses['loss_bbox']
  75. empty_iou_loss = empty_gt_losses['loss_iou']
  76. self.assertGreater(empty_cls_loss.item(), 0,
  77. 'cls loss should be non-zero')
  78. self.assertEqual(
  79. empty_box_loss.item(), 0,
  80. 'there should be no box loss when there are no true boxes')
  81. self.assertEqual(
  82. empty_iou_loss.item(), 0,
  83. 'there should be no box loss when there are no true boxes')
  84. # When truth is non-empty then both cls and box loss should be nonzero
  85. # for random inputs
  86. gt_instances = InstanceData()
  87. gt_instances.bboxes = torch.Tensor(
  88. [[23.6667, 23.8757, 238.6326, 151.8874]])
  89. gt_instances.labels = torch.LongTensor([2])
  90. one_gt_losses = paa.loss_by_feat(cls_scores, bbox_preds, iou_preds,
  91. [gt_instances], img_metas)
  92. onegt_cls_loss = one_gt_losses['loss_cls']
  93. onegt_box_loss = one_gt_losses['loss_bbox']
  94. onegt_iou_loss = one_gt_losses['loss_iou']
  95. self.assertGreater(onegt_cls_loss.item(), 0,
  96. 'cls loss should be non-zero')
  97. self.assertGreater(onegt_box_loss.item(), 0,
  98. 'box loss should be non-zero')
  99. self.assertGreater(onegt_iou_loss.item(), 0,
  100. 'box loss should be non-zero')
  101. n, c, h, w = 10, 4, 20, 20
  102. mlvl_tensor = [torch.ones(n, c, h, w) for i in range(5)]
  103. results = levels_to_images(mlvl_tensor)
  104. self.assertEqual(len(results), n)
  105. self.assertEqual(results[0].size(), (h * w * 5, c))
  106. self.assertTrue(paa.with_score_voting)
  107. paa = PAAHead(
  108. num_classes=4,
  109. in_channels=1,
  110. train_cfg=train_cfg,
  111. anchor_generator=dict(
  112. type='AnchorGenerator',
  113. ratios=[1.0],
  114. octave_base_scale=8,
  115. scales_per_octave=1,
  116. strides=[8]),
  117. loss_cls=dict(
  118. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
  119. loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
  120. loss_centerness=dict(
  121. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
  122. cls_scores = [torch.ones(2, 4, 5, 5)]
  123. bbox_preds = [torch.ones(2, 4, 5, 5)]
  124. iou_preds = [torch.ones(2, 1, 5, 5)]
  125. cfg = Config(
  126. dict(
  127. nms_pre=1000,
  128. min_bbox_size=0,
  129. score_thr=0.05,
  130. nms=dict(type='nms', iou_threshold=0.6),
  131. max_per_img=100))
  132. rescale = False
  133. paa.predict_by_feat(
  134. cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale)