test_fovea_head.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.models.dense_heads import FoveaHead
  6. class TestFOVEAHead(TestCase):
  7. def test_fovea_head_loss(self):
  8. """Tests anchor head loss when truth is empty and non-empty."""
  9. s = 256
  10. img_metas = [{
  11. 'img_shape': (s, s, 3),
  12. 'pad_shape': (s, s, 3),
  13. 'scale_factor': 1,
  14. }]
  15. fovea_head = FoveaHead(num_classes=4, in_channels=1)
  16. # Anchor head expects a multiple levels of features per image
  17. feats = (
  18. torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2)))
  19. for i in range(len(fovea_head.prior_generator.strides)))
  20. cls_scores, bbox_preds = fovea_head.forward(feats)
  21. # Test that empty ground truth encourages the network to
  22. # predict background
  23. gt_instances = InstanceData()
  24. gt_instances.bboxes = torch.empty((0, 4))
  25. gt_instances.labels = torch.LongTensor([])
  26. empty_gt_losses = fovea_head.loss_by_feat(cls_scores, bbox_preds,
  27. [gt_instances], img_metas)
  28. # When there is no truth, the cls loss should be nonzero but
  29. # there should be no box loss.
  30. empty_cls_loss = empty_gt_losses['loss_cls']
  31. empty_box_loss = empty_gt_losses['loss_bbox']
  32. self.assertGreater(empty_cls_loss.item(), 0,
  33. 'cls loss should be non-zero')
  34. self.assertEqual(
  35. empty_box_loss.item(), 0,
  36. 'there should be no box loss when there are no true boxes')
  37. # When truth is non-empty then both cls and box loss
  38. # should be nonzero for random inputs
  39. gt_instances = InstanceData()
  40. gt_instances.bboxes = torch.Tensor(
  41. [[23.6667, 23.8757, 238.6326, 151.8874]])
  42. gt_instances.labels = torch.LongTensor([2])
  43. one_gt_losses = fovea_head.loss_by_feat(cls_scores, bbox_preds,
  44. [gt_instances], img_metas)
  45. onegt_cls_loss = one_gt_losses['loss_cls']
  46. onegt_box_loss = one_gt_losses['loss_bbox']
  47. self.assertGreater(onegt_cls_loss.item(), 0,
  48. 'cls loss should be non-zero')
  49. self.assertGreater(onegt_box_loss.item(), 0,
  50. 'box loss should be non-zero')