test_formatting.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import os.path as osp
  4. import unittest
  5. import numpy as np
  6. import torch
  7. from mmengine.structures import InstanceData, PixelData
  8. from mmdet.datasets.transforms import PackDetInputs
  9. from mmdet.structures import DetDataSample
  10. from mmdet.structures.mask import BitmapMasks
  11. class TestPackDetInputs(unittest.TestCase):
  12. def setUp(self):
  13. """Setup the model and optimizer which are used in every test method.
  14. TestCase calls functions in this order: setUp() -> testMethod() ->
  15. tearDown() -> cleanUp()
  16. """
  17. data_prefix = osp.join(osp.dirname(__file__), '../../data')
  18. img_path = osp.join(data_prefix, 'color.jpg')
  19. rng = np.random.RandomState(0)
  20. self.results1 = {
  21. 'img_id': 1,
  22. 'img_path': img_path,
  23. 'ori_shape': (300, 400),
  24. 'img_shape': (600, 800),
  25. 'scale_factor': 2.0,
  26. 'flip': False,
  27. 'img': rng.rand(300, 400),
  28. 'gt_seg_map': rng.rand(300, 400),
  29. 'gt_masks':
  30. BitmapMasks(rng.rand(3, 300, 400), height=300, width=400),
  31. 'gt_bboxes_labels': rng.rand(3, ),
  32. 'gt_ignore_flags': np.array([0, 0, 1], dtype=bool),
  33. 'proposals': rng.rand(2, 4),
  34. 'proposals_scores': rng.rand(2, )
  35. }
  36. self.results2 = {
  37. 'img_id': 1,
  38. 'img_path': img_path,
  39. 'ori_shape': (300, 400),
  40. 'img_shape': (600, 800),
  41. 'scale_factor': 2.0,
  42. 'flip': False,
  43. 'img': rng.rand(300, 400),
  44. 'gt_seg_map': rng.rand(300, 400),
  45. 'gt_masks':
  46. BitmapMasks(rng.rand(3, 300, 400), height=300, width=400),
  47. 'gt_bboxes_labels': rng.rand(3, ),
  48. 'proposals': rng.rand(2, 4),
  49. 'proposals_scores': rng.rand(2, )
  50. }
  51. self.meta_keys = ('img_id', 'img_path', 'ori_shape', 'scale_factor',
  52. 'flip')
  53. def test_transform(self):
  54. transform = PackDetInputs(meta_keys=self.meta_keys)
  55. results = transform(copy.deepcopy(self.results1))
  56. self.assertIn('data_samples', results)
  57. self.assertIsInstance(results['data_samples'], DetDataSample)
  58. self.assertIsInstance(results['data_samples'].gt_instances,
  59. InstanceData)
  60. self.assertIsInstance(results['data_samples'].ignored_instances,
  61. InstanceData)
  62. self.assertEqual(len(results['data_samples'].gt_instances), 2)
  63. self.assertEqual(len(results['data_samples'].ignored_instances), 1)
  64. self.assertIsInstance(results['data_samples'].gt_sem_seg, PixelData)
  65. self.assertIsInstance(results['data_samples'].proposals, InstanceData)
  66. self.assertEqual(len(results['data_samples'].proposals), 2)
  67. self.assertIsInstance(results['data_samples'].proposals.bboxes,
  68. torch.Tensor)
  69. self.assertIsInstance(results['data_samples'].proposals.scores,
  70. torch.Tensor)
  71. def test_transform_without_ignore(self):
  72. transform = PackDetInputs(meta_keys=self.meta_keys)
  73. results = transform(copy.deepcopy(self.results2))
  74. self.assertIn('data_samples', results)
  75. self.assertIsInstance(results['data_samples'], DetDataSample)
  76. self.assertIsInstance(results['data_samples'].gt_instances,
  77. InstanceData)
  78. self.assertIsInstance(results['data_samples'].ignored_instances,
  79. InstanceData)
  80. self.assertEqual(len(results['data_samples'].gt_instances), 3)
  81. self.assertEqual(len(results['data_samples'].ignored_instances), 0)
  82. self.assertIsInstance(results['data_samples'].gt_sem_seg, PixelData)
  83. self.assertIsInstance(results['data_samples'].proposals, InstanceData)
  84. self.assertEqual(len(results['data_samples'].proposals), 2)
  85. self.assertIsInstance(results['data_samples'].proposals.bboxes,
  86. torch.Tensor)
  87. self.assertIsInstance(results['data_samples'].proposals.scores,
  88. torch.Tensor)
  89. def test_repr(self):
  90. transform = PackDetInputs(meta_keys=self.meta_keys)
  91. self.assertEqual(
  92. repr(transform), f'PackDetInputs(meta_keys={self.meta_keys})')