# test_fsaf_head.py (3.5 KB)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from math import ceil
  3. from unittest import TestCase
  4. import torch
  5. from mmengine import Config
  6. from mmengine.structures import InstanceData
  7. from mmdet.models.dense_heads import FSAFHead
  8. class TestFSAFHead(TestCase):
  9. def test_fsaf_head_loss(self):
  10. """Tests fsaf head loss when truth is empty and non-empty."""
  11. s = 300
  12. img_metas = [{
  13. 'img_shape': (s, s),
  14. 'pad_shape': (s, s),
  15. 'scale_factor': 1,
  16. }]
  17. cfg = Config(
  18. dict(
  19. assigner=dict(
  20. type='CenterRegionAssigner',
  21. pos_scale=0.2,
  22. neg_scale=0.2,
  23. min_pos_iof=0.01),
  24. allowed_border=-1,
  25. pos_weight=-1,
  26. debug=False))
  27. fsaf_head = FSAFHead(
  28. num_classes=4,
  29. in_channels=1,
  30. stacked_convs=1,
  31. feat_channels=1,
  32. reg_decoded_bbox=True,
  33. anchor_generator=dict(
  34. type='AnchorGenerator',
  35. octave_base_scale=1,
  36. scales_per_octave=1,
  37. ratios=[1.0],
  38. strides=[8, 16, 32, 64, 128]),
  39. bbox_coder=dict(type='TBLRBBoxCoder', normalizer=4.0),
  40. loss_cls=dict(
  41. type='FocalLoss',
  42. use_sigmoid=True,
  43. gamma=2.0,
  44. alpha=0.25,
  45. loss_weight=1.0,
  46. reduction='none'),
  47. loss_bbox=dict(
  48. type='IoULoss', eps=1e-6, loss_weight=1.0, reduction='none'),
  49. train_cfg=cfg)
  50. # FSAF head expects a multiple levels of features per image
  51. feats = (
  52. torch.rand(1, 1, ceil(s / stride[0]), ceil(s / stride[0]))
  53. for stride in fsaf_head.prior_generator.strides)
  54. cls_scores, bbox_preds = fsaf_head.forward(feats)
  55. # Test that empty ground truth encourages the network to
  56. # predict background
  57. gt_instances = InstanceData()
  58. gt_instances.bboxes = torch.empty((0, 4))
  59. gt_instances.labels = torch.LongTensor([])
  60. empty_gt_losses = fsaf_head.loss_by_feat(cls_scores, bbox_preds,
  61. [gt_instances], img_metas)
  62. # When there is no truth, the cls loss should be nonzero but
  63. # box loss should be zero
  64. empty_cls_loss = sum(empty_gt_losses['loss_cls'])
  65. empty_box_loss = sum(empty_gt_losses['loss_bbox'])
  66. self.assertGreater(empty_cls_loss, 0, 'cls loss should be non-zero')
  67. self.assertEqual(
  68. empty_box_loss.item(), 0,
  69. 'there should be no box loss when there are no true boxes')
  70. # When truth is non-empty then both cls and box loss
  71. # should be nonzero for random inputs
  72. gt_instances = InstanceData()
  73. gt_instances.bboxes = torch.Tensor(
  74. [[23.6667, 23.8757, 238.6326, 151.8874]])
  75. gt_instances.labels = torch.LongTensor([2])
  76. one_gt_losses = fsaf_head.loss_by_feat(cls_scores, bbox_preds,
  77. [gt_instances], img_metas)
  78. onegt_cls_loss = sum(one_gt_losses['loss_cls'])
  79. onegt_box_loss = sum(one_gt_losses['loss_bbox'])
  80. self.assertGreater(onegt_cls_loss.item(), 0,
  81. 'cls loss should be non-zero')
  82. self.assertGreater(onegt_box_loss.item(), 0,
  83. 'box loss should be non-zero')