test_dab_detr.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import DetDataSample
  7. from mmdet.testing import get_detector_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestDABDETR(TestCase):
  10. def setUp(self) -> None:
  11. register_all_modules()
  12. def test_dab_detr_head_loss(self):
  13. """Tests transformer head loss when truth is empty and non-empty."""
  14. s = 256
  15. metainfo = {
  16. 'img_shape': (s, s),
  17. 'scale_factor': (1, 1),
  18. 'pad_shape': (s, s),
  19. 'batch_input_shape': (s, s)
  20. }
  21. img_metas = DetDataSample()
  22. img_metas.set_metainfo(metainfo)
  23. batch_data_samples = []
  24. batch_data_samples.append(img_metas)
  25. config = get_detector_cfg('dab_detr/dab-detr_r50_8xb2-50e_coco.py')
  26. model = MODELS.build(config)
  27. model.init_weights()
  28. random_image = torch.rand(1, 3, s, s)
  29. # Test that empty ground truth encourages the network to
  30. # predict background
  31. gt_instances = InstanceData()
  32. gt_instances.bboxes = torch.empty((0, 4))
  33. gt_instances.labels = torch.LongTensor([])
  34. img_metas.gt_instances = gt_instances
  35. batch_data_samples1 = []
  36. batch_data_samples1.append(img_metas)
  37. empty_gt_losses = model.loss(
  38. random_image, batch_data_samples=batch_data_samples1)
  39. # When there is no truth, the cls loss should be nonzero but there
  40. # should be no box loss.
  41. for key, loss in empty_gt_losses.items():
  42. if 'cls' in key:
  43. self.assertGreater(loss.item(), 0,
  44. 'cls loss should be non-zero')
  45. elif 'bbox' in key:
  46. self.assertEqual(
  47. loss.item(), 0,
  48. 'there should be no box loss when no ground true boxes')
  49. elif 'iou' in key:
  50. self.assertEqual(
  51. loss.item(), 0,
  52. 'there should be no iou loss when there are no true boxes')
  53. # When truth is non-empty then both cls and box loss should be nonzero
  54. # for random inputs
  55. gt_instances = InstanceData()
  56. gt_instances.bboxes = torch.Tensor(
  57. [[23.6667, 23.8757, 238.6326, 151.8874]])
  58. gt_instances.labels = torch.LongTensor([2])
  59. img_metas.gt_instances = gt_instances
  60. batch_data_samples2 = []
  61. batch_data_samples2.append(img_metas)
  62. one_gt_losses = model.loss(
  63. random_image, batch_data_samples=batch_data_samples2)
  64. for loss in one_gt_losses.values():
  65. self.assertGreater(
  66. loss.item(), 0,
  67. 'cls loss, or box loss, or iou loss should be non-zero')
  68. model.eval()
  69. # test _forward
  70. model._forward(random_image, batch_data_samples=batch_data_samples2)
  71. # test only predict
  72. model.predict(
  73. random_image, batch_data_samples=batch_data_samples2, rescale=True)