test_loading.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import os
  4. import os.path as osp
  5. import sys
  6. import unittest
  7. from unittest.mock import MagicMock, Mock, patch
  8. import mmcv
  9. import numpy as np
  10. from mmdet.datasets.transforms import (FilterAnnotations, LoadAnnotations,
  11. LoadEmptyAnnotations,
  12. LoadImageFromNDArray,
  13. LoadMultiChannelImageFromFiles,
  14. LoadProposals)
  15. from mmdet.evaluation import INSTANCE_OFFSET
  16. from mmdet.structures.mask import BitmapMasks, PolygonMasks
  17. try:
  18. import panopticapi
  19. except ImportError:
  20. panopticapi = None
  class TestLoadAnnotations(unittest.TestCase):
      """Tests for :class:`LoadAnnotations` on a hand-built results dict."""

      def setUp(self):
          """Build a raw ``results`` dict with three annotated instances.

          Each instance's polygon traces the rectangle given by its
          ``bbox``. The third instance has ``ignore_flag=1`` so the loaded
          ``gt_ignore_flags`` can be checked. ``seg_map_path`` is provided,
          but no test in this class enables ``with_seg``.
          """
          data_prefix = osp.join(osp.dirname(__file__), '../../data')
          seg_map = osp.join(data_prefix, 'gray.jpg')
          self.results = {
              'ori_shape': (300, 400),
              'seg_map_path': seg_map,
              'instances': [{
                  'bbox': [0, 0, 10, 20],
                  'bbox_label': 1,
                  'mask': [[0, 0, 0, 20, 10, 20, 10, 0]],
                  'ignore_flag': 0
              }, {
                  'bbox': [10, 10, 110, 120],
                  'bbox_label': 2,
                  # Flattened (x, y) corners of the bbox rectangle. The last
                  # corner must be (10, 120); repeating (110, 10) there
                  # would give a zero-area polygon.
                  'mask': [[10, 10, 110, 10, 110, 120, 10, 120]],
                  'ignore_flag': 0
              }, {
                  'bbox': [50, 50, 60, 80],
                  'bbox_label': 2,
                  'mask': [[50, 50, 60, 50, 60, 80, 50, 80]],
                  'ignore_flag': 1
              }]
          }

      def test_load_bboxes(self):
          """Boxes load as float32 and ignore flags as bool."""
          transform = LoadAnnotations(
              with_bbox=True,
              with_label=False,
              with_seg=False,
              with_mask=False,
              box_type=None)
          results = transform(copy.deepcopy(self.results))
          self.assertIn('gt_bboxes', results)
          self.assertTrue((results['gt_bboxes'] == np.array(
              [[0, 0, 10, 20], [10, 10, 110, 120], [50, 50, 60, 80]])).all())
          self.assertEqual(results['gt_bboxes'].dtype, np.float32)
          self.assertTrue(
              (results['gt_ignore_flags'] == np.array([0, 0, 1])).all())
          self.assertEqual(results['gt_ignore_flags'].dtype, bool)

      def test_load_labels(self):
          """Labels load as int64 in instance order."""
          transform = LoadAnnotations(
              with_bbox=False,
              with_label=True,
              with_seg=False,
              with_mask=False,
          )
          results = transform(copy.deepcopy(self.results))
          self.assertIn('gt_bboxes_labels', results)
          self.assertTrue(
              (results['gt_bboxes_labels'] == np.array([1, 2, 2])).all())
          self.assertEqual(results['gt_bboxes_labels'].dtype, np.int64)

      def test_load_mask(self):
          """With ``poly2mask=False`` masks stay as ``PolygonMasks``."""
          transform = LoadAnnotations(
              with_bbox=False,
              with_label=False,
              with_seg=False,
              with_mask=True,
              poly2mask=False)
          results = transform(copy.deepcopy(self.results))
          self.assertIn('gt_masks', results)
          self.assertEqual(len(results['gt_masks']), 3)
          self.assertIsInstance(results['gt_masks'], PolygonMasks)

      def test_load_mask_poly2mask(self):
          """With ``poly2mask=True`` masks are rasterized to ``BitmapMasks``."""
          transform = LoadAnnotations(
              with_bbox=False,
              with_label=False,
              with_seg=False,
              with_mask=True,
              poly2mask=True)
          results = transform(copy.deepcopy(self.results))
          self.assertIn('gt_masks', results)
          self.assertEqual(len(results['gt_masks']), 3)
          self.assertIsInstance(results['gt_masks'], BitmapMasks)

      def test_repr(self):
          """``repr`` lists every constructor option."""
          transform = LoadAnnotations(
              with_bbox=True,
              with_label=False,
              with_seg=False,
              with_mask=False,
          )
          self.assertEqual(
              repr(transform), ('LoadAnnotations(with_bbox=True, '
                                'with_label=False, with_mask=False, '
                                'with_seg=False, poly2mask=True, '
                                "imdecode_backend='cv2', "
                                'backend_args=None)'))
  class TestFilterAnnotations(unittest.TestCase):
      """Tests for :class:`FilterAnnotations`."""

      def setUp(self):
          """Build a 224x224 sample holding boxes of side 10, 20 and 40.

          All random data comes from one seeded ``RandomState`` so the
          fixture is reproducible across runs.
          """
          rng = np.random.RandomState(0)
          self.results = {
              'img': rng.random_sample((224, 224, 3)),
              'img_shape': (224, 224),
              'gt_bboxes_labels': np.array([1, 2, 3], dtype=np.int64),
              'gt_bboxes':
              np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]]),
              # ``np.bool8`` was removed in NumPy 2.0; the builtin ``bool``
              # selects the same boolean dtype on every NumPy version.
              'gt_ignore_flags': np.array([0, 0, 1], dtype=bool),
              'gt_masks':
              BitmapMasks(rng.rand(3, 224, 224), height=224, width=224),
          }

      def test_transform(self):
          """Instances below ``min_gt_bbox_wh`` are removed from all fields."""
          # Every box is smaller than 50x50. With keep_empty=True the
          # transform returns None ...
          transform = FilterAnnotations(
              min_gt_bbox_wh=(50, 50),
              keep_empty=True,
          )
          self.assertIsNone(transform(copy.deepcopy(self.results)))

          # ... while keep_empty=False still returns the results dict.
          transform = FilterAnnotations(
              min_gt_bbox_wh=(50, 50),
              keep_empty=False,
          )
          self.assertIsInstance(transform(copy.deepcopy(self.results)), dict)

          # A 15x15 threshold drops only the 10x10 box; labels, boxes, masks
          # and ignore flags must all be filtered consistently.
          transform = FilterAnnotations(min_gt_bbox_wh=(15, 15))
          results = transform(copy.deepcopy(self.results))
          self.assertIsInstance(results, dict)
          np.testing.assert_array_equal(results['gt_bboxes_labels'], [2, 3])
          np.testing.assert_array_equal(
              results['gt_bboxes'], [[20, 20, 40, 40], [40, 40, 80, 80]])
          self.assertEqual(len(results['gt_masks']), 2)
          self.assertEqual(len(results['gt_ignore_flags']), 2)

      def test_repr(self):
          """``repr`` shows the size threshold and ``keep_empty``."""
          transform = FilterAnnotations(
              min_gt_bbox_wh=(1, 1),
              keep_empty=False,
          )
          self.assertEqual(
              repr(transform), ('FilterAnnotations(min_gt_bbox_wh=(1, 1), '
                                'keep_empty=False)'))
  class TestLoadPanopticAnnotations(unittest.TestCase):
      """Tests for ``LoadPanopticAnnotations`` with ``panopticapi`` stubbed."""

      def setUp(self):
          """Write a 10x10 panoptic PNG to a temp dir and build expectations.

          The map has three segments: two "thing" instances (top half and
          bottom-left quarter) and one "stuff" segment (bottom-right
          quarter). The PNG is written into a per-test temporary directory,
          removed via ``addCleanup``, so nothing is left in the CWD even if
          ``setUp`` fails part-way.
          """
          import tempfile

          seg_map = np.zeros((10, 10), dtype=np.int32)
          seg_map[:5, :10] = 1 + 10 * INSTANCE_OFFSET
          seg_map[5:10, :5] = 4 + 11 * INSTANCE_OFFSET
          seg_map[5:10, 5:10] = 6 + 0 * INSTANCE_OFFSET
          # Split each id into base-256 digits, most significant in channel
          # 0. The decoding step is mocked in ``test_transform`` (``rgb2id``
          # returns ``seg_map`` directly), so the file only has to exist.
          rgb_seg_map = np.zeros((10, 10, 3), dtype=np.uint8)
          rgb_seg_map[:, :, 0] = seg_map // (256 * 256)
          rgb_seg_map[:, :, 1] = (seg_map // 256) % 256
          rgb_seg_map[:, :, 2] = seg_map % 256

          self._tmp_dir = tempfile.TemporaryDirectory()
          self.addCleanup(self._tmp_dir.cleanup)
          self.seg_map_path = osp.join(self._tmp_dir.name, '1.png')
          mmcv.imwrite(rgb_seg_map, self.seg_map_path)
          self.seg_map = seg_map
          self.rgb_seg_map = rgb_seg_map
          self.results = {
              'ori_shape': (10, 10),
              'instances': [{
                  'bbox': [0, 0, 10, 5],
                  'bbox_label': 0,
                  'ignore_flag': 0,
              }, {
                  'bbox': [0, 5, 5, 10],
                  'bbox_label': 1,
                  'ignore_flag': 1,
              }],
              'segments_info': [
                  {
                      'id': 1 + 10 * INSTANCE_OFFSET,
                      'category': 0,
                      'is_thing': True,
                  },
                  {
                      'id': 4 + 11 * INSTANCE_OFFSET,
                      'category': 1,
                      'is_thing': True,
                  },
                  {
                      'id': 6 + 0 * INSTANCE_OFFSET,
                      'category': 2,
                      'is_thing': False,
                  },
              ],
              'seg_map_path': self.seg_map_path
          }
          # Expected outputs: one binary mask per "thing" segment, the
          # instance boxes/labels/flags, and a semantic map holding each
          # segment's category.
          self.gt_mask = BitmapMasks([
              (seg_map == 1 + 10 * INSTANCE_OFFSET).astype(np.uint8),
              (seg_map == 4 + 11 * INSTANCE_OFFSET).astype(np.uint8),
          ], 10, 10)
          self.gt_bboxes = np.array([[0, 0, 10, 5], [0, 5, 5, 10]],
                                    dtype=np.float32)
          self.gt_bboxes_labels = np.array([0, 1], dtype=np.int64)
          self.gt_ignore_flags = np.array([0, 1], dtype=bool)
          self.gt_seg_map = np.zeros((10, 10), dtype=np.int32)
          self.gt_seg_map[:5, :10] = 0
          self.gt_seg_map[5:10, :5] = 1
          self.gt_seg_map[5:10, 5:10] = 2

      @unittest.skipIf(panopticapi is not None, 'panopticapi is installed')
      def test_init_without_panopticapi(self):
          """Construction fails with a helpful ImportError without the API."""
          from mmdet.datasets.transforms import LoadPanopticAnnotations
          with self.assertRaisesRegex(
                  ImportError,
                  'panopticapi is not installed, please install it by'):
              LoadPanopticAnnotations()

      def test_transform(self):
          """Check each combination of ``with_*`` flags on the stubbed map."""
          from mmdet.datasets.transforms import LoadPanopticAnnotations

          # Stub ``panopticapi`` only for the duration of this test:
          # ``patch.dict`` restores ``sys.modules`` on exit, so the stub
          # cannot leak into other tests (e.g. the real-import check above).
          # The package mock's ``utils`` attribute is bound to the submodule
          # mock so that ``panopticapi.utils`` resolves to the same object
          # (carrying the stubbed ``rgb2id``) whether it is reached by
          # attribute access or by importing the submodule.
          mock_utils = MagicMock()
          mock_utils.rgb2id = Mock(return_value=self.seg_map)
          mock_panopticapi = MagicMock()
          mock_panopticapi.utils = mock_utils
          with patch.dict(sys.modules, {
                  'panopticapi': mock_panopticapi,
                  'panopticapi.utils': mock_utils,
          }):
              # test with all False: results pass through unchanged
              transform = LoadPanopticAnnotations(
                  with_bbox=False,
                  with_label=False,
                  with_mask=False,
                  with_seg=False)
              results = transform(copy.deepcopy(self.results))
              self.assertDictEqual(results, self.results)

              # test with with_mask=True
              transform = LoadPanopticAnnotations(
                  with_bbox=False,
                  with_label=False,
                  with_mask=True,
                  with_seg=False)
              results = transform(copy.deepcopy(self.results))
              self.assertTrue(
                  (results['gt_masks'].masks == self.gt_mask.masks).all())

              # test with with_seg=True
              transform = LoadPanopticAnnotations(
                  with_bbox=False,
                  with_label=False,
                  with_mask=False,
                  with_seg=True)
              results = transform(copy.deepcopy(self.results))
              self.assertNotIn('gt_masks', results)
              self.assertTrue(
                  (results['gt_seg_map'] == self.gt_seg_map).all())

              # test with all True
              transform = LoadPanopticAnnotations(
                  with_bbox=True,
                  with_label=True,
                  with_mask=True,
                  with_seg=True,
                  box_type=None)
              results = transform(copy.deepcopy(self.results))
              self.assertTrue(
                  (results['gt_masks'].masks == self.gt_mask.masks).all())
              self.assertTrue((results['gt_bboxes'] == self.gt_bboxes).all())
              self.assertTrue(
                  (results['gt_bboxes_labels'] == self.gt_bboxes_labels).all())
              self.assertTrue(
                  (results['gt_ignore_flags'] == self.gt_ignore_flags).all())
              self.assertTrue(
                  (results['gt_seg_map'] == self.gt_seg_map).all())
  class TestLoadImageFromNDArray(unittest.TestCase):
      """Tests for :class:`LoadImageFromNDArray`."""

      def setUp(self):
          """Prepare a results dict holding an all-zero 256x256x3 uint8 image."""
          self.results = dict(img=np.zeros((256, 256, 3), dtype=np.uint8))

      def test_transform(self):
          """The array is kept as-is and both shapes are recorded."""
          loaded = LoadImageFromNDArray()(copy.deepcopy(self.results))
          image = loaded['img']
          self.assertEqual(image.shape, (256, 256, 3))
          self.assertEqual(image.dtype, np.uint8)
          self.assertEqual(loaded['img_shape'], (256, 256))
          self.assertEqual(loaded['ori_shape'], (256, 256))

          # Re-running on the loaded output with to_float32 casts the image.
          as_float = LoadImageFromNDArray(to_float32=True)(
              copy.deepcopy(loaded))
          self.assertEqual(as_float['img'].dtype, np.float32)

      def test_repr(self):
          """``repr`` reports the default options."""
          expected = ('LoadImageFromNDArray(ignore_empty=False, '
                      "to_float32=False, color_type='color', "
                      "imdecode_backend='cv2', backend_args=None)")
          self.assertEqual(repr(LoadImageFromNDArray()), expected)
  class TestLoadMultiChannelImageFromFiles(unittest.TestCase):
      """Tests for :class:`LoadMultiChannelImageFromFiles`."""

      def setUp(self):
          """Write four single-channel 10x10 images, one per channel.

          The files live in a per-test temporary directory registered with
          ``addCleanup``, so they are removed even if ``setUp`` fails
          part-way and never pollute the working directory.
          """
          import tempfile

          self._tmp_dir = tempfile.TemporaryDirectory()
          self.addCleanup(self._tmp_dir.cleanup)
          self.img_path = []
          for i in range(4):
              img_channel_path = osp.join(self._tmp_dir.name, f'part_{i}.jpg')
              img_channel = np.zeros((10, 10), dtype=np.uint8)
              mmcv.imwrite(img_channel, img_channel_path)
              self.img_path.append(img_channel_path)
          self.results = {'img_path': self.img_path}

      def test_transform(self):
          """The per-file channels are stacked along the last axis."""
          transform = LoadMultiChannelImageFromFiles()
          results = transform(copy.deepcopy(self.results))
          self.assertEqual(results['img'].shape, (10, 10, 4))
          self.assertEqual(results['img'].dtype, np.uint8)
          self.assertEqual(results['img_shape'], (10, 10))
          self.assertEqual(results['ori_shape'], (10, 10))
          # to_float32
          transform = LoadMultiChannelImageFromFiles(to_float32=True)
          results = transform(copy.deepcopy(results))
          self.assertEqual(results['img'].dtype, np.float32)

      def test_repr(self):
          """``repr`` reports the default options."""
          transform = LoadMultiChannelImageFromFiles()
          self.assertEqual(
              repr(transform), ('LoadMultiChannelImageFromFiles('
                                'to_float32=False, '
                                "color_type='unchanged', "
                                "imdecode_backend='cv2', "
                                'backend_args=None)'))
  class TestLoadProposals(unittest.TestCase):
      """Tests for :class:`LoadProposals`."""

      @staticmethod
      def _int_proposals(num_bboxes, num_scores):
          """Return a results dict with zero-valued int64 boxes and scores."""
          return {
              'proposals':
              dict(
                  bboxes=np.zeros((num_bboxes, 4), dtype=np.int64),
                  scores=np.zeros((num_scores, ), dtype=np.int64))
          }

      def test_transform(self):
          """Dtype casting, shape validation, empty input and truncation."""
          transform = LoadProposals()

          # int64 inputs come out as float32 boxes plus float32 scores.
          out = transform(self._int_proposals(5, 5))
          self.assertEqual(out['proposals'].dtype, np.float32)
          self.assertEqual(out['proposals'].shape[-1], 4)
          self.assertEqual(out['proposals_scores'].dtype, np.float32)

          # bboxes.shape[1] must be 4.
          too_wide = {
              'proposals': dict(bboxes=np.zeros((5, 5), dtype=np.int64))
          }
          with self.assertRaises(AssertionError):
              transform(too_wide)

          # bboxes.shape[0] must equal scores.shape[0].
          with self.assertRaises(AssertionError):
              transform(self._int_proposals(5, 3))

          # Empty boxes with no scores given still yield 'proposals_scores'.
          empty = {
              'proposals': dict(bboxes=np.zeros((0, 4), dtype=np.float32))
          }
          out = transform(empty)
          expected_proposals = np.zeros((0, 4), dtype=np.float32)
          expected_scores = np.zeros(0, dtype=np.float32)
          self.assertTrue((out['proposals'] == expected_proposals).all())
          self.assertTrue((out['proposals_scores'] == expected_scores).all())

          # num_max_proposals keeps at most that many boxes.
          capped = LoadProposals(num_max_proposals=2)
          out = capped(self._int_proposals(5, 5))
          self.assertEqual(out['proposals'].shape[0], 2)

      def test_repr(self):
          """``repr`` reports the default ``num_max_proposals``."""
          self.assertEqual(
              repr(LoadProposals()), 'LoadProposals(num_max_proposals=None)')
  class TestLoadEmptyAnnotations(unittest.TestCase):
      """Tests for :class:`LoadEmptyAnnotations`."""

      def test_transform(self):
          """Every requested empty field is created with the expected dtype."""
          transform = LoadEmptyAnnotations(
              with_bbox=True, with_label=True, with_mask=True, with_seg=True)
          out = transform({'img_shape': (224, 224)})

          boxes = out['gt_bboxes']
          self.assertEqual(boxes.dtype, np.float32)
          self.assertEqual(boxes.shape[-1], 4)
          self.assertEqual(out['gt_ignore_flags'].dtype, bool)
          self.assertEqual(out['gt_bboxes_labels'].dtype, np.int64)

          # Masks and the semantic map are sized from img_shape.
          mask_array = out['gt_masks'].masks
          self.assertEqual(mask_array.dtype, np.uint8)
          self.assertEqual(mask_array.shape[-2:], out['img_shape'])
          semantic = out['gt_seg_map']
          self.assertEqual(semantic.dtype, np.uint8)
          self.assertEqual(semantic.shape, out['img_shape'])

      def test_repr(self):
          """``repr`` reports the default options."""
          expected = ('LoadEmptyAnnotations(with_bbox=True, with_label=True, '
                      'with_mask=False, with_seg=False, '
                      'seg_ignore_label=255)')
          self.assertEqual(repr(LoadEmptyAnnotations()), expected)