test_atss_head.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import ATSSHead
  8. class TestATSSHead(TestCase):
  9. def test_atss_head_loss(self):
  10. """Tests atss head loss when truth is empty and non-empty."""
  11. s = 256
  12. img_metas = [{
  13. 'img_shape': (s, s, 3),
  14. 'pad_shape': (s, s, 3),
  15. 'scale_factor': 1
  16. }]
  17. cfg = Config(
  18. dict(
  19. assigner=dict(type='ATSSAssigner', topk=9),
  20. allowed_border=-1,
  21. pos_weight=-1,
  22. debug=False))
  23. atss_head = ATSSHead(
  24. num_classes=4,
  25. in_channels=1,
  26. stacked_convs=1,
  27. feat_channels=1,
  28. norm_cfg=None,
  29. train_cfg=cfg,
  30. anchor_generator=dict(
  31. type='AnchorGenerator',
  32. ratios=[1.0],
  33. octave_base_scale=8,
  34. scales_per_octave=1,
  35. strides=[8, 16, 32, 64, 128]),
  36. loss_cls=dict(
  37. type='FocalLoss',
  38. use_sigmoid=True,
  39. gamma=2.0,
  40. alpha=0.25,
  41. loss_weight=1.0),
  42. loss_bbox=dict(type='GIoULoss', loss_weight=2.0))
  43. feat = [
  44. torch.rand(1, 1, s // feat_size, s // feat_size)
  45. for feat_size in [8, 16, 32, 64, 128]
  46. ]
  47. cls_scores, bbox_preds, centernesses = atss_head.forward(feat)
  48. # Test that empty ground truth encourages the network to predict
  49. # background
  50. gt_instances = InstanceData()
  51. gt_instances.bboxes = torch.empty((0, 4))
  52. gt_instances.labels = torch.LongTensor([])
  53. empty_gt_losses = atss_head.loss_by_feat(cls_scores, bbox_preds,
  54. centernesses, [gt_instances],
  55. img_metas)
  56. # When there is no truth, the cls loss should be nonzero but there
  57. # should be no box loss.
  58. empty_cls_loss = sum(empty_gt_losses['loss_cls'])
  59. empty_box_loss = sum(empty_gt_losses['loss_bbox'])
  60. empty_centerness_loss = sum(empty_gt_losses['loss_centerness'])
  61. self.assertGreater(empty_cls_loss.item(), 0,
  62. 'cls loss should be non-zero')
  63. self.assertEqual(
  64. empty_box_loss.item(), 0,
  65. 'there should be no box loss when there are no true boxes')
  66. self.assertEqual(
  67. empty_centerness_loss.item(), 0,
  68. 'there should be no centerness loss when there are no true boxes')
  69. # When truth is non-empty then both cls and box loss should be nonzero
  70. # for random inputs
  71. gt_instances = InstanceData()
  72. gt_instances.bboxes = torch.Tensor(
  73. [[23.6667, 23.8757, 238.6326, 151.8874]])
  74. gt_instances.labels = torch.LongTensor([2])
  75. one_gt_losses = atss_head.loss_by_feat(cls_scores, bbox_preds,
  76. centernesses, [gt_instances],
  77. img_metas)
  78. onegt_cls_loss = sum(one_gt_losses['loss_cls'])
  79. onegt_box_loss = sum(one_gt_losses['loss_bbox'])
  80. onegt_centerness_loss = sum(one_gt_losses['loss_centerness'])
  81. self.assertGreater(onegt_cls_loss.item(), 0,
  82. 'cls loss should be non-zero')
  83. self.assertGreater(onegt_box_loss.item(), 0,
  84. 'box loss should be non-zero')
  85. self.assertGreater(onegt_centerness_loss.item(), 0,
  86. 'centerness loss should be non-zero')