test_rpn_head.py (4.7 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import pytest
  4. import torch
  5. from mmengine import Config
  6. from mmengine.structures import InstanceData
  7. from mmdet import * # noqa
  8. from mmdet.models.dense_heads import RPNHead
  9. class TestRPNHead(TestCase):
  10. def test_init(self):
  11. """Test init rpn head."""
  12. rpn_head = RPNHead(num_classes=1, in_channels=1)
  13. self.assertTrue(rpn_head.rpn_conv)
  14. self.assertTrue(rpn_head.rpn_cls)
  15. self.assertTrue(rpn_head.rpn_reg)
  16. # rpn_head.num_convs > 1
  17. rpn_head = RPNHead(num_classes=1, in_channels=1, num_convs=2)
  18. self.assertTrue(rpn_head.rpn_conv)
  19. self.assertTrue(rpn_head.rpn_cls)
  20. self.assertTrue(rpn_head.rpn_reg)
  21. def test_rpn_head_loss(self):
  22. """Tests rpn head loss when truth is empty and non-empty."""
  23. s = 256
  24. img_metas = [{
  25. 'img_shape': (s, s, 3),
  26. 'pad_shape': (s, s, 3),
  27. 'scale_factor': 1,
  28. }]
  29. cfg = Config(
  30. dict(
  31. assigner=dict(
  32. type='MaxIoUAssigner',
  33. pos_iou_thr=0.7,
  34. neg_iou_thr=0.3,
  35. min_pos_iou=0.3,
  36. ignore_iof_thr=-1),
  37. sampler=dict(
  38. type='RandomSampler',
  39. num=256,
  40. pos_fraction=0.5,
  41. neg_pos_ub=-1,
  42. add_gt_as_proposals=False),
  43. allowed_border=0,
  44. pos_weight=-1,
  45. debug=False))
  46. rpn_head = RPNHead(num_classes=1, in_channels=1, train_cfg=cfg)
  47. # Anchor head expects a multiple levels of features per image
  48. feats = (
  49. torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2)))
  50. for i in range(len(rpn_head.prior_generator.strides)))
  51. cls_scores, bbox_preds = rpn_head.forward(feats)
  52. # Test that empty ground truth encourages the network to
  53. # predict background
  54. gt_instances = InstanceData()
  55. gt_instances.bboxes = torch.empty((0, 4))
  56. gt_instances.labels = torch.LongTensor([])
  57. empty_gt_losses = rpn_head.loss_by_feat(cls_scores, bbox_preds,
  58. [gt_instances], img_metas)
  59. # When there is no truth, the cls loss should be nonzero but
  60. # there should be no box loss.
  61. empty_cls_loss = sum(empty_gt_losses['loss_rpn_cls'])
  62. empty_box_loss = sum(empty_gt_losses['loss_rpn_bbox'])
  63. self.assertGreater(empty_cls_loss.item(), 0,
  64. 'rpn cls loss should be non-zero')
  65. self.assertEqual(
  66. empty_box_loss.item(), 0,
  67. 'there should be no box loss when there are no true boxes')
  68. # When truth is non-empty then both cls and box loss
  69. # should be nonzero for random inputs
  70. gt_instances = InstanceData()
  71. gt_instances.bboxes = torch.Tensor(
  72. [[23.6667, 23.8757, 238.6326, 151.8874]])
  73. gt_instances.labels = torch.LongTensor([0])
  74. one_gt_losses = rpn_head.loss_by_feat(cls_scores, bbox_preds,
  75. [gt_instances], img_metas)
  76. onegt_cls_loss = sum(one_gt_losses['loss_rpn_cls'])
  77. onegt_box_loss = sum(one_gt_losses['loss_rpn_bbox'])
  78. self.assertGreater(onegt_cls_loss.item(), 0,
  79. 'rpn cls loss should be non-zero')
  80. self.assertGreater(onegt_box_loss.item(), 0,
  81. 'rpn box loss should be non-zero')
  82. # When there is no valid anchor, the loss will be None,
  83. # and this will raise a ValueError.
  84. img_metas = [{
  85. 'img_shape': (8, 8, 3),
  86. 'pad_shape': (8, 8, 3),
  87. 'scale_factor': 1,
  88. }]
  89. with pytest.raises(ValueError):
  90. rpn_head.loss_by_feat(cls_scores, bbox_preds, [gt_instances],
  91. img_metas)
  92. def test_bbox_post_process(self):
  93. """Test the length of detection instance results is 0."""
  94. from mmengine.config import ConfigDict
  95. cfg = ConfigDict(
  96. nms_pre=1000,
  97. max_per_img=1000,
  98. nms=dict(type='nms', iou_threshold=0.7),
  99. min_bbox_size=0)
  100. rpn_head = RPNHead(num_classes=1, in_channels=1)
  101. results = InstanceData(metainfo=dict())
  102. results.bboxes = torch.zeros((0, 4))
  103. results.scores = torch.zeros(0)
  104. results = rpn_head._bbox_post_process(results, cfg, img_meta=dict())
  105. self.assertEqual(len(results), 0)
  106. self.assertEqual(results.bboxes.size(), (0, 4))
  107. self.assertEqual(results.scores.size(), (0, ))
  108. self.assertEqual(results.labels.size(), (0, ))