test_dino.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from unittest import TestCase
  3. import torch
  4. from mmengine.structures import InstanceData
  5. from mmdet.registry import MODELS
  6. from mmdet.structures import DetDataSample
  7. from mmdet.testing import get_detector_cfg
  8. from mmdet.utils import register_all_modules
  9. class TestDINO(TestCase):
  10. def setUp(self):
  11. register_all_modules()
  12. def test_dino_head_loss(self):
  13. """Tests transformer head loss when truth is empty and non-empty."""
  14. s = 256
  15. metainfo = {
  16. 'img_shape': (s, s),
  17. 'scale_factor': (1, 1),
  18. 'pad_shape': (s, s),
  19. 'batch_input_shape': (s, s)
  20. }
  21. data_sample = DetDataSample()
  22. data_sample.set_metainfo(metainfo)
  23. configs = [get_detector_cfg('dino/dino-4scale_r50_8xb2-12e_coco.py')]
  24. for config in configs:
  25. model = MODELS.build(config)
  26. model.init_weights()
  27. random_image = torch.rand(1, 3, s, s)
  28. # Test that empty ground truth encourages the network to
  29. # predict background
  30. gt_instances = InstanceData()
  31. gt_instances.bboxes = torch.empty((0, 4))
  32. gt_instances.labels = torch.LongTensor([])
  33. data_sample.gt_instances = gt_instances
  34. batch_data_samples_1 = [data_sample]
  35. empty_gt_losses = model.loss(
  36. random_image, batch_data_samples=batch_data_samples_1)
  37. # When there is no truth, the cls loss should be nonzero but there
  38. # should be no box loss.
  39. for key, loss in empty_gt_losses.items():
  40. _loss = loss.item()
  41. if 'bbox' in key or 'iou' in key or 'dn' in key:
  42. self.assertEqual(
  43. _loss, 0, f'there should be no {key}({_loss}) '
  44. f'when no ground true boxes')
  45. elif 'cls' in key:
  46. self.assertGreater(_loss, 0,
  47. f'{key}({_loss}) should be non-zero')
  48. # When truth is non-empty then both cls and box loss should
  49. # be nonzero for random inputs
  50. gt_instances = InstanceData()
  51. gt_instances.bboxes = torch.Tensor(
  52. [[23.6667, 23.8757, 238.6326, 151.8874]])
  53. gt_instances.labels = torch.LongTensor([2])
  54. data_sample.gt_instances = gt_instances
  55. batch_data_samples_2 = [data_sample]
  56. one_gt_losses = model.loss(
  57. random_image, batch_data_samples=batch_data_samples_2)
  58. for loss in one_gt_losses.values():
  59. self.assertGreater(
  60. loss.item(), 0,
  61. 'cls loss, or box loss, or iou loss should be non-zero')
  62. model.eval()
  63. # test _forward
  64. model._forward(
  65. random_image, batch_data_samples=batch_data_samples_2)
  66. # test only predict
  67. model.predict(
  68. random_image,
  69. batch_data_samples=batch_data_samples_2,
  70. rescale=True)