test_gfl_head.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import GFLHead
  8. class TestGFLHead(TestCase):
  9. def test_gfl_head_loss(self):
  10. """Tests gfl head loss when truth is empty and non-empty."""
  11. s = 256
  12. img_metas = [{
  13. 'img_shape': (s, s, 3),
  14. 'pad_shape': (s, s, 3),
  15. 'scale_factor': 1
  16. }]
  17. train_cfg = Config(
  18. dict(
  19. assigner=dict(type='ATSSAssigner', topk=9),
  20. allowed_border=-1,
  21. pos_weight=-1,
  22. debug=False))
  23. gfl_head = GFLHead(
  24. num_classes=4,
  25. in_channels=1,
  26. stacked_convs=1,
  27. train_cfg=train_cfg,
  28. anchor_generator=dict(
  29. type='AnchorGenerator',
  30. ratios=[1.0],
  31. octave_base_scale=8,
  32. scales_per_octave=1,
  33. strides=[8, 16, 32, 64, 128]),
  34. loss_cls=dict(
  35. type='QualityFocalLoss',
  36. use_sigmoid=True,
  37. beta=2.0,
  38. loss_weight=1.0),
  39. loss_bbox=dict(type='GIoULoss', loss_weight=2.0))
  40. feat = [
  41. torch.rand(1, 1, s // feat_size, s // feat_size)
  42. for feat_size in [4, 8, 16, 32, 64]
  43. ]
  44. cls_scores, bbox_preds = gfl_head.forward(feat)
  45. # Test that empty ground truth encourages the network to predict
  46. # background
  47. gt_instances = InstanceData()
  48. gt_instances.bboxes = torch.empty((0, 4))
  49. gt_instances.labels = torch.LongTensor([])
  50. empty_gt_losses = gfl_head.loss_by_feat(cls_scores, bbox_preds,
  51. [gt_instances], img_metas)
  52. # When there is no truth, the cls loss should be nonzero but there
  53. # should be no box loss.
  54. empty_cls_loss = sum(empty_gt_losses['loss_cls'])
  55. empty_box_loss = sum(empty_gt_losses['loss_bbox'])
  56. empty_dfl_loss = sum(empty_gt_losses['loss_dfl'])
  57. self.assertGreater(empty_cls_loss.item(), 0,
  58. 'cls loss should be non-zero')
  59. self.assertEqual(
  60. empty_box_loss.item(), 0,
  61. 'there should be no box loss when there are no true boxes')
  62. self.assertEqual(
  63. empty_dfl_loss.item(), 0,
  64. 'there should be no dfl loss when there are no true boxes')
  65. # When truth is non-empty then both cls and box loss should be nonzero
  66. # for random inputs
  67. gt_instances = InstanceData()
  68. gt_instances.bboxes = torch.Tensor(
  69. [[23.6667, 23.8757, 238.6326, 151.8874]])
  70. gt_instances.labels = torch.LongTensor([2])
  71. one_gt_losses = gfl_head.loss_by_feat(cls_scores, bbox_preds,
  72. [gt_instances], img_metas)
  73. onegt_cls_loss = sum(one_gt_losses['loss_cls'])
  74. onegt_box_loss = sum(one_gt_losses['loss_bbox'])
  75. onegt_dfl_loss = sum(one_gt_losses['loss_dfl'])
  76. self.assertGreater(onegt_cls_loss.item(), 0,
  77. 'cls loss should be non-zero')
  78. self.assertGreater(onegt_box_loss.item(), 0,
  79. 'box loss should be non-zero')
  80. self.assertGreater(onegt_dfl_loss.item(), 0,
  81. 'dfl loss should be non-zero')