test_autoassign_head.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.models.dense_heads import AutoAssignHead
  6. class TestAutoAssignHead(TestCase):
  7. def test_autoassign_head_loss(self):
  8. """Tests autoassign head loss when truth is empty and non-empty."""
  9. s = 300
  10. img_metas = [{
  11. 'img_shape': (s, s, 3),
  12. 'pad_shape': (s, s, 3),
  13. 'scale_factor': 1,
  14. }]
  15. autoassign_head = AutoAssignHead(
  16. num_classes=4,
  17. in_channels=1,
  18. stacked_convs=1,
  19. feat_channels=1,
  20. strides=[8, 16, 32, 64, 128],
  21. loss_bbox=dict(type='GIoULoss', loss_weight=5.0),
  22. norm_cfg=None)
  23. # Fcos head expects a multiple levels of features per image
  24. feats = (
  25. torch.rand(1, 1, s // stride[1], s // stride[0])
  26. for stride in autoassign_head.prior_generator.strides)
  27. cls_scores, bbox_preds, centernesses = autoassign_head.forward(feats)
  28. # Test that empty ground truth encourages the network to
  29. # predict background
  30. gt_instances = InstanceData()
  31. gt_instances.bboxes = torch.empty((0, 4))
  32. gt_instances.labels = torch.LongTensor([])
  33. empty_gt_losses = autoassign_head.loss_by_feat(cls_scores, bbox_preds,
  34. centernesses,
  35. [gt_instances],
  36. img_metas)
  37. # When there is no truth, the neg loss should be nonzero but
  38. # pos loss and center loss should be zero
  39. empty_pos_loss = empty_gt_losses['loss_pos'].item()
  40. empty_neg_loss = empty_gt_losses['loss_neg'].item()
  41. empty_ctr_loss = empty_gt_losses['loss_center'].item()
  42. self.assertGreater(empty_neg_loss, 0, 'neg loss should be non-zero')
  43. self.assertEqual(
  44. empty_pos_loss, 0,
  45. 'there should be no pos loss when there are no true boxes')
  46. self.assertEqual(
  47. empty_ctr_loss, 0,
  48. 'there should be no centerness loss when there are no true boxes')
  49. # When truth is non-empty then all pos, neg loss and center loss
  50. # should be nonzero for random inputs
  51. gt_instances = InstanceData()
  52. gt_instances.bboxes = torch.Tensor(
  53. [[23.6667, 23.8757, 238.6326, 151.8874]])
  54. gt_instances.labels = torch.LongTensor([2])
  55. one_gt_losses = autoassign_head.loss_by_feat(cls_scores, bbox_preds,
  56. centernesses,
  57. [gt_instances], img_metas)
  58. onegt_pos_loss = one_gt_losses['loss_pos'].item()
  59. onegt_neg_loss = one_gt_losses['loss_neg'].item()
  60. onegt_ctr_loss = one_gt_losses['loss_center'].item()
  61. self.assertGreater(onegt_pos_loss, 0, 'pos loss should be non-zero')
  62. self.assertGreater(onegt_neg_loss, 0, 'neg loss should be non-zero')
  63. self.assertGreater(onegt_ctr_loss, 0, 'center loss should be non-zero')