test_ddod_head.py 3.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine import Config
  5. from mmengine.structures import InstanceData
  6. from mmdet import * # noqa
  7. from mmdet.models.dense_heads import DDODHead
  8. class TestDDODHead(TestCase):
  9. def test_ddod_head_loss(self):
  10. """Tests ddod head loss when truth is empty and non-empty."""
  11. s = 256
  12. img_metas = [{
  13. 'img_shape': (s, s, 3),
  14. 'pad_shape': (s, s, 3),
  15. 'scale_factor': 1
  16. }]
  17. cfg = Config(
  18. dict(
  19. assigner=dict(type='ATSSAssigner', topk=9, alpha=0.8),
  20. reg_assigner=dict(type='ATSSAssigner', topk=9, alpha=0.5),
  21. allowed_border=-1,
  22. pos_weight=-1,
  23. debug=False))
  24. atss_head = DDODHead(
  25. num_classes=4,
  26. in_channels=1,
  27. stacked_convs=1,
  28. feat_channels=1,
  29. use_dcn=False,
  30. norm_cfg=None,
  31. train_cfg=cfg,
  32. anchor_generator=dict(
  33. type='AnchorGenerator',
  34. ratios=[1.0],
  35. octave_base_scale=8,
  36. scales_per_octave=1,
  37. strides=[8, 16, 32, 64, 128]),
  38. loss_cls=dict(
  39. type='FocalLoss',
  40. use_sigmoid=True,
  41. gamma=2.0,
  42. alpha=0.25,
  43. loss_weight=1.0),
  44. loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
  45. loss_iou=dict(
  46. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
  47. feat = [
  48. torch.rand(1, 1, s // feat_size, s // feat_size)
  49. for feat_size in [8, 16, 32, 64, 128]
  50. ]
  51. cls_scores, bbox_preds, centernesses = atss_head.forward(feat)
  52. # Test that empty ground truth encourages the network to predict
  53. # background
  54. gt_instances = InstanceData()
  55. gt_instances.bboxes = torch.empty((0, 4))
  56. gt_instances.labels = torch.LongTensor([])
  57. empty_gt_losses = atss_head.loss_by_feat(cls_scores, bbox_preds,
  58. centernesses, [gt_instances],
  59. img_metas)
  60. # When there is no truth, the cls loss should be nonzero but there
  61. # should be no box loss.
  62. empty_cls_loss = sum(empty_gt_losses['loss_cls'])
  63. empty_box_loss = sum(empty_gt_losses['loss_bbox'])
  64. empty_centerness_loss = sum(empty_gt_losses['loss_iou'])
  65. self.assertGreater(empty_cls_loss.item(), 0,
  66. 'cls loss should be non-zero')
  67. self.assertEqual(
  68. empty_box_loss.item(), 0,
  69. 'there should be no box loss when there are no true boxes')
  70. self.assertEqual(
  71. empty_centerness_loss.item(), 0,
  72. 'there should be no centerness loss when there are no true boxes')
  73. # When truth is non-empty then both cls and box loss should be nonzero
  74. # for random inputs
  75. gt_instances = InstanceData()
  76. gt_instances.bboxes = torch.Tensor(
  77. [[23.6667, 23.8757, 238.6326, 151.8874]])
  78. gt_instances.labels = torch.LongTensor([2])
  79. one_gt_losses = atss_head.loss_by_feat(cls_scores, bbox_preds,
  80. centernesses, [gt_instances],
  81. img_metas)
  82. onegt_cls_loss = sum(one_gt_losses['loss_cls'])
  83. onegt_box_loss = sum(one_gt_losses['loss_bbox'])
  84. onegt_centerness_loss = sum(one_gt_losses['loss_iou'])
  85. self.assertGreater(onegt_cls_loss.item(), 0,
  86. 'cls loss should be non-zero')
  87. self.assertGreater(onegt_box_loss.item(), 0,
  88. 'box loss should be non-zero')
  89. self.assertGreater(onegt_centerness_loss.item(), 0,
  90. 'centerness loss should be non-zero')