test_conditional_detr.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import DetDataSample
  7. from mmdet.testing import get_detector_cfg
  8. from mmdet.utils import register_all_modules
  class TestConditionalDETR(TestCase):
      """Loss / forward / predict smoke test for the Conditional DETR detector.

      The detector is built from the
      ``conditional_detr/conditional-detr_r50_8xb2-50e_coco.py`` config and fed
      a single random 256x256 image.
      """

      def setUp(self) -> None:
          """Call mmdet's ``register_all_modules`` before each test."""
          register_all_modules()

      def test_conditional_detr_head_loss(self):
          """Tests transformer head loss when truth is empty and non-empty."""
          s = 256
          metainfo = {
              'img_shape': (s, s),
              'scale_factor': (1, 1),
              'pad_shape': (s, s),
              'batch_input_shape': (s, s)
          }
          # One data sample is reused for both scenarios below; only its
          # ``gt_instances`` attribute is swapped between them.
          img_metas = DetDataSample()
          img_metas.set_metainfo(metainfo)

          config = get_detector_cfg(
              'conditional_detr/conditional-detr_r50_8xb2-50e_coco.py')
          model = MODELS.build(config)
          model.init_weights()
          random_image = torch.rand(1, 3, s, s)

          # Test that empty ground truth encourages the network to
          # predict background
          gt_instances = InstanceData()
          gt_instances.bboxes = torch.empty((0, 4))
          gt_instances.labels = torch.LongTensor([])
          img_metas.gt_instances = gt_instances
          batch_data_samples1 = [img_metas]
          empty_gt_losses = model.loss(
              random_image, batch_data_samples=batch_data_samples1)

          # When there is no truth, the cls loss should be nonzero but there
          # should be no box loss.
          for key, loss in empty_gt_losses.items():
              if 'cls' in key:
                  self.assertGreater(loss.item(), 0,
                                     'cls loss should be non-zero')
              elif 'bbox' in key:
                  self.assertEqual(
                      loss.item(), 0,
                      'there should be no box loss when no ground true boxes')
              elif 'iou' in key:
                  self.assertEqual(
                      loss.item(), 0,
                      'there should be no iou loss when there are no true boxes')

          # When truth is non-empty then both cls and box loss should be nonzero
          # for random inputs
          gt_instances = InstanceData()
          gt_instances.bboxes = torch.Tensor(
              [[23.6667, 23.8757, 238.6326, 151.8874]])
          gt_instances.labels = torch.LongTensor([2])
          img_metas.gt_instances = gt_instances
          batch_data_samples2 = [img_metas]
          one_gt_losses = model.loss(
              random_image, batch_data_samples=batch_data_samples2)
          for loss in one_gt_losses.values():
              self.assertGreater(
                  loss.item(), 0,
                  'cls loss, or box loss, or iou loss should be non-zero')

          model.eval()
          # The remaining calls only check that the inference paths run, so
          # autograd bookkeeping is switched off for them.
          with torch.no_grad():
              # test _forward
              model._forward(
                  random_image, batch_data_samples=batch_data_samples2)
              # test only predict
              model.predict(
                  random_image,
                  batch_data_samples=batch_data_samples2,
                  rescale=True)