test_centripetal_head.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.models.dense_heads import CentripetalHead
  6. class TestCentripetalHead(TestCase):
  7. def test_centripetal_head_loss(self):
  8. """Tests corner head loss when truth is empty and non-empty."""
  9. s = 256
  10. img_metas = [{
  11. 'img_shape': (s, s, 3),
  12. 'scale_factor': 1,
  13. 'batch_input_shape': (s, s, 3)
  14. }]
  15. centripetal_head = CentripetalHead(
  16. num_classes=4, in_channels=1, corner_emb_channels=0)
  17. # Corner head expects a multiple levels of features per image
  18. feat = [
  19. torch.rand(1, 1, s // 4, s // 4)
  20. for _ in range(centripetal_head.num_feat_levels)
  21. ]
  22. forward_outputs = centripetal_head.forward(feat)
  23. # Test that empty ground truth encourages the network
  24. # to predict background
  25. gt_instances = InstanceData()
  26. gt_instances.bboxes = torch.empty((0, 4))
  27. gt_instances.labels = torch.LongTensor([])
  28. gt_bboxes_ignore = None
  29. empty_gt_losses = centripetal_head.loss_by_feat(
  30. *forward_outputs, [gt_instances], img_metas, gt_bboxes_ignore)
  31. empty_det_loss = sum(empty_gt_losses['det_loss'])
  32. empty_guiding_loss = sum(empty_gt_losses['guiding_loss'])
  33. empty_centripetal_loss = sum(empty_gt_losses['centripetal_loss'])
  34. empty_off_loss = sum(empty_gt_losses['off_loss'])
  35. self.assertTrue(empty_det_loss.item() > 0,
  36. 'det loss should be non-zero')
  37. self.assertTrue(
  38. empty_guiding_loss.item() == 0,
  39. 'there should be no guiding loss when there are no true boxes')
  40. self.assertTrue(
  41. empty_centripetal_loss.item() == 0,
  42. 'there should be no centripetal loss when there are no true boxes')
  43. self.assertTrue(
  44. empty_off_loss.item() == 0,
  45. 'there should be no box loss when there are no true boxes')
  46. gt_instances = InstanceData()
  47. gt_instances.bboxes = torch.Tensor(
  48. [[23.6667, 23.8757, 238.6326, 151.8874],
  49. [123.6667, 123.8757, 138.6326, 251.8874]])
  50. gt_instances.labels = torch.LongTensor([2, 3])
  51. two_gt_losses = centripetal_head.loss_by_feat(*forward_outputs,
  52. [gt_instances],
  53. img_metas,
  54. gt_bboxes_ignore)
  55. twogt_det_loss = sum(two_gt_losses['det_loss'])
  56. twogt_guiding_loss = sum(two_gt_losses['guiding_loss'])
  57. twogt_centripetal_loss = sum(two_gt_losses['centripetal_loss'])
  58. twogt_off_loss = sum(two_gt_losses['off_loss'])
  59. assert twogt_det_loss.item() > 0, 'det loss should be non-zero'
  60. assert twogt_guiding_loss.item() > 0, 'push loss should be non-zero'
  61. assert twogt_centripetal_loss.item(
  62. ) > 0, 'pull loss should be non-zero'
  63. assert twogt_off_loss.item() > 0, 'off loss should be non-zero'