test_keypoint_partition_metric.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile
from collections import defaultdict
from unittest import TestCase

import numpy as np
from mmengine.fileio import load
from mmengine.structures import InstanceData
from xtcocotools.coco import COCO

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.evaluation.metrics import KeypointPartitionMetric
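
# KeypointPartitionMetric wraps a single keypoint metric (CocoMetric,
# PCKAccuracy, AUC, EPE or NME in the tests below) and reports it per keypoint
# partition. A minimal sketch of the pattern exercised in this file
# (`dataset_meta`, `data_batch`, `data_samples` and `num_samples` are
# placeholder names, not fixtures defined here):
#
#   metric = KeypointPartitionMetric(
#       metric=dict(type='CocoMetric', score_mode='bbox', nms_mode='none'),
#       partitions=dict(body=range(17), foot=range(17, 23), all=range(133)))
#   metric.dataset_meta = dataset_meta
#   metric.process(data_batch, data_samples)
#   results = metric.evaluate(size=num_samples)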


class TestKeypointPartitionMetricWrappingCocoMetric(TestCase):

    def setUp(self):
        """Set up some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.ann_file_coco = \
            'tests/data/coco/test_keypoint_partition_metric.json'
        meta_info_coco = dict(
            from_file='configs/_base_/datasets/coco_wholebody.py')
        self.dataset_meta_coco = parse_pose_metainfo(meta_info_coco)
        self.coco = COCO(self.ann_file_coco)
        self.dataset_meta_coco['CLASSES'] = self.coco.loadCats(
            self.coco.getCatIds())

        self.topdown_data_coco = self._convert_ann_to_topdown_batch_data(
            self.ann_file_coco)
        assert len(self.topdown_data_coco) == 14
        self.bottomup_data_coco = self._convert_ann_to_bottomup_batch_data(
            self.ann_file_coco)
        assert len(self.bottomup_data_coco) == 4
        """
        The target results were obtained from CocoWholebodyMetric with
        score_mode='bbox' and nms_mode='none'. We cannot compare other
        combinations of score_mode and nms_mode because CocoWholebodyMetric
        calculates scores and nms using all keypoints, while
        KeypointPartitionMetric calculates scores and nms part by part.
        As long as this case is tested to be correct, the other cases should
        be correct as well.
        """
        self.target_bbox_none = {
            'body/coco/AP': 0.749,
            'body/coco/AR': 0.800,
            'foot/coco/AP': 0.840,
            'foot/coco/AR': 0.850,
            'face/coco/AP': 0.051,
            'face/coco/AR': 0.050,
            'left_hand/coco/AP': 0.283,
            'left_hand/coco/AR': 0.300,
            'right_hand/coco/AP': 0.383,
            'right_hand/coco/AR': 0.380,
            'all/coco/AP': 0.284,
            'all/coco/AR': 0.450,
        }

    def _convert_ann_to_topdown_batch_data(self, ann_file):
        """Convert annotations to topdown-style batch data."""
        topdown_data = []
        db = load(ann_file)
        imgid2info = dict()
        for img in db['images']:
            imgid2info[img['id']] = img

        for ann in db['annotations']:
            w, h = ann['bbox'][2], ann['bbox'][3]
            bboxes = np.array(ann['bbox'], dtype=np.float32).reshape(-1, 4)
            bbox_scales = np.array([w * 1.25, h * 1.25]).reshape(-1, 2)
            _keypoints = np.array(ann['keypoints']).reshape((1, -1, 3))

            gt_instances = {
                'bbox_scales': bbox_scales,
                'bbox_scores': np.ones((1, ), dtype=np.float32),
                'bboxes': bboxes,
                'keypoints': _keypoints[..., :2],
                'keypoints_visible': _keypoints[..., 2:3]
            }

            # fake predictions
            keypoints = np.zeros_like(_keypoints)
            keypoints[..., 0] = _keypoints[..., 0] * 0.99
            keypoints[..., 1] = _keypoints[..., 1] * 1.02
            keypoints[..., 2] = _keypoints[..., 2] * 0.8

            pred_instances = {
                'keypoints': keypoints[..., :2],
                'keypoint_scores': keypoints[..., -1],
            }

            data = {'inputs': None}
            data_sample = {
                'id': ann['id'],
                'img_id': ann['image_id'],
                'category_id': ann.get('category_id', 1),
                'gt_instances': gt_instances,
                'pred_instances': pred_instances,
                # dummy image_shape for testing
                'ori_shape': [640, 480],
                # store the raw annotation info to test without ann_file
                'raw_ann_info': copy.deepcopy(ann),
            }

            # add crowd_index to data_sample if it is present in the image
            if 'crowdIndex' in imgid2info[ann['image_id']]:
                data_sample['crowd_index'] = imgid2info[
                    ann['image_id']]['crowdIndex']

            # batch size = 1
            data_batch = [data]
            data_samples = [data_sample]
            topdown_data.append((data_batch, data_samples))

        return topdown_data
    def _convert_ann_to_bottomup_batch_data(self, ann_file):
        """Convert annotations to bottomup-style batch data."""
        img2ann = defaultdict(list)
        db = load(ann_file)
        for ann in db['annotations']:
            img2ann[ann['image_id']].append(ann)

        bottomup_data = []
        for img_id, anns in img2ann.items():
            _keypoints = np.array([ann['keypoints'] for ann in anns]).reshape(
                (len(anns), -1, 3))

            gt_instances = {
                'bbox_scores': np.ones((len(anns)), dtype=np.float32),
                'keypoints': _keypoints[..., :2],
                'keypoints_visible': _keypoints[..., 2:3]
            }

            # fake predictions
            keypoints = np.zeros_like(_keypoints)
            keypoints[..., 0] = _keypoints[..., 0] * 0.99
            keypoints[..., 1] = _keypoints[..., 1] * 1.02
            keypoints[..., 2] = _keypoints[..., 2] * 0.8

            pred_instances = {
                'keypoints': keypoints[..., :2],
                'keypoint_scores': keypoints[..., -1],
            }

            data = {'inputs': None}
            data_sample = {
                'id': [ann['id'] for ann in anns],
                'img_id': img_id,
                'gt_instances': gt_instances,
                'pred_instances': pred_instances,
                # dummy image_shape for testing
                'ori_shape': [640, 480],
                'raw_ann_info': copy.deepcopy(anns),
            }

            # batch size = 1
            data_batch = [data]
            data_samples = [data_sample]
            bottomup_data.append((data_batch, data_samples))

        return bottomup_data

    def _assert_outfiles(self, prefix):
        for part in ['body', 'foot', 'face', 'left_hand', 'right_hand', 'all']:
            self.assertTrue(
                osp.isfile(
                    osp.join(self.tmp_dir.name,
                             f'{prefix}.{part}.keypoints.json')))

    def tearDown(self):
        self.tmp_dir.cleanup()
    def test_init(self):
        """test metric init method."""
        # test wrong metric type
        with self.assertRaisesRegex(
                ValueError, 'Metrics supported by KeypointPartitionMetric'):
            _ = KeypointPartitionMetric(
                metric=dict(type='Metric'), partitions=dict(all=range(133)))

        # test ann_file arg warning
        with self.assertWarnsRegex(UserWarning,
                                   'does not support the ann_file argument'):
            _ = KeypointPartitionMetric(
                metric=dict(type='CocoMetric', ann_file=''),
                partitions=dict(all=range(133)))

        # test score_mode arg warning
        with self.assertWarnsRegex(UserWarning, "if score_mode is not 'bbox'"):
            _ = KeypointPartitionMetric(
                metric=dict(type='CocoMetric'),
                partitions=dict(all=range(133)))

        # test nms arg warning
        with self.assertWarnsRegex(UserWarning, 'oks_nms and soft_oks_nms'):
            _ = KeypointPartitionMetric(
                metric=dict(type='CocoMetric'),
                partitions=dict(all=range(133)))

        # test partitions
        with self.assertRaisesRegex(AssertionError, 'at least one partition'):
            _ = KeypointPartitionMetric(
                metric=dict(type='CocoMetric'), partitions=dict())

        with self.assertRaisesRegex(AssertionError, 'should be a sequence'):
            _ = KeypointPartitionMetric(
                metric=dict(type='CocoMetric'), partitions=dict(all={}))

        with self.assertRaisesRegex(AssertionError, 'at least one element'):
            _ = KeypointPartitionMetric(
                metric=dict(type='CocoMetric'), partitions=dict(all=[]))
    def test_bottomup_evaluate(self):
        """test bottomup-style COCO metric evaluation."""
        # case 1: score_mode='bbox', nms_mode='none'
        metric = KeypointPartitionMetric(
            metric=dict(
                type='CocoMetric',
                outfile_prefix=f'{self.tmp_dir.name}/test_bottomup',
                score_mode='bbox',
                nms_mode='none'),
            partitions=dict(
                body=range(17),
                foot=range(17, 23),
                face=range(23, 91),
                left_hand=range(91, 112),
                right_hand=range(112, 133),
                all=range(133)))
        metric.dataset_meta = self.dataset_meta_coco

        # process samples
        for data_batch, data_samples in self.bottomup_data_coco:
            metric.process(data_batch, data_samples)

        eval_results = metric.evaluate(size=len(self.bottomup_data_coco))
        for key in self.target_bbox_none.keys():
            self.assertAlmostEqual(
                eval_results[key], self.target_bbox_none[key], places=3)
        self._assert_outfiles('test_bottomup')

    def test_topdown_evaluate(self):
        """test topdown-style COCO metric evaluation."""
        # case 1: score_mode='bbox', nms_mode='none'
        metric = KeypointPartitionMetric(
            metric=dict(
                type='CocoMetric',
                outfile_prefix=f'{self.tmp_dir.name}/test_topdown1',
                score_mode='bbox',
                nms_mode='none'),
            partitions=dict(
                body=range(17),
                foot=range(17, 23),
                face=range(23, 91),
                left_hand=range(91, 112),
                right_hand=range(112, 133),
                all=range(133)))
        metric.dataset_meta = self.dataset_meta_coco

        # process samples
        for data_batch, data_samples in self.topdown_data_coco:
            metric.process(data_batch, data_samples)

        eval_results = metric.evaluate(size=len(self.topdown_data_coco))
        for key in self.target_bbox_none.keys():
            self.assertAlmostEqual(
                eval_results[key], self.target_bbox_none[key], places=3)
        self._assert_outfiles('test_topdown1')


class TestKeypointPartitionMetricWrappingPCKAccuracy(TestCase):

    def setUp(self):
        """Set up some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        self.batch_size = 8
        num_keypoints = 24
        self.data_batch = []
        self.data_samples = []

        for i in range(self.batch_size):
            gt_instances = InstanceData()
            keypoints = np.zeros((1, num_keypoints, 2))
            for j in range(num_keypoints):
                keypoints[0, j] = [0.5 * i * j, 0.5 * i * j]
            gt_instances.keypoints = keypoints
            gt_instances.keypoints_visible = np.ones(
                (1, num_keypoints, 1)).astype(bool)
            gt_instances.keypoints_visible[0, (2 * i) % 8, 0] = False
            gt_instances.bboxes = np.array([[0.1, 0.2, 0.3, 0.4]]) * 20 * i
            gt_instances.head_size = np.array([[0.1]]) * 10 * i

            pred_instances = InstanceData()
            # fake predictions
            _keypoints = np.zeros_like(keypoints)
            _keypoints[0, :, 0] = keypoints[0, :, 0] * 0.95
            _keypoints[0, :, 1] = keypoints[0, :, 1] * 1.05
            pred_instances.keypoints = _keypoints

            data = {'inputs': None}
            data_sample = {
                'gt_instances': gt_instances.to_dict(),
                'pred_instances': pred_instances.to_dict(),
            }

            self.data_batch.append(data)
            self.data_samples.append(data_sample)
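
    # The synthetic data above places keypoint j of sample i at
    # (0.5 * i * j, 0.5 * i * j), scales predictions by 0.95 / 1.05 along
    # x / y, and marks one keypoint per sample invisible; the PCK, PCKh and
    # tPCK targets below follow from these fixed relative errors together
    # with the chosen thresholds and normalization items.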
    def test_init(self):
        # test norm_item arg warning
        with self.assertWarnsRegex(UserWarning,
                                   'norm_item torso is used in JhmdbDataset'):
            _ = KeypointPartitionMetric(
                metric=dict(
                    type='PCKAccuracy', thr=0.05, norm_item=['bbox', 'torso']),
                partitions=dict(all=range(133)))

    def test_evaluate(self):
        """test PCK accuracy evaluation metric."""
        # test normalized by 'bbox'
        pck_metric = KeypointPartitionMetric(
            metric=dict(type='PCKAccuracy', thr=0.5, norm_item='bbox'),
            partitions=dict(
                p1=range(10),
                p2=range(10, 24),
                all=range(24),
            ))
        pck_metric.process(self.data_batch, self.data_samples)
        pck = pck_metric.evaluate(self.batch_size)
        target = {'p1/PCK': 1.0, 'p2/PCK': 1.0, 'all/PCK': 1.0}
        self.assertDictEqual(pck, target)

        # test normalized by 'head_size'
        pckh_metric = KeypointPartitionMetric(
            metric=dict(type='PCKAccuracy', thr=0.3, norm_item='head'),
            partitions=dict(
                p1=range(10),
                p2=range(10, 24),
                all=range(24),
            ))
        pckh_metric.process(self.data_batch, self.data_samples)
        pckh = pckh_metric.evaluate(self.batch_size)
        target = {'p1/PCKh': 0.9, 'p2/PCKh': 0.0, 'all/PCKh': 0.375}
        self.assertDictEqual(pckh, target)

        # test normalized by 'torso_size'
        tpck_metric = KeypointPartitionMetric(
            metric=dict(
                type='PCKAccuracy', thr=0.05, norm_item=['bbox', 'torso']),
            partitions=dict(
                p1=range(10),
                p2=range(10, 24),
                all=range(24),
            ))
        tpck_metric.process(self.data_batch, self.data_samples)
        tpck = tpck_metric.evaluate(self.batch_size)
        self.assertIsInstance(tpck, dict)
        target = {
            'p1/PCK': 0.6,
            'p1/tPCK': 0.11428571428571428,
            'p2/PCK': 0.0,
            'p2/tPCK': 0.0,
            'all/PCK': 0.25,
            'all/tPCK': 0.047619047619047616
        }
        self.assertDictEqual(tpck, target)


class TestKeypointPartitionMetricWrappingAUCandEPE(TestCase):

    def setUp(self):
        """Set up some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        output = np.zeros((1, 5, 2))
        target = np.zeros((1, 5, 2))
        # first channel
        output[0, 0] = [10, 4]
        target[0, 0] = [10, 0]
        # second channel
        output[0, 1] = [10, 18]
        target[0, 1] = [10, 10]
        # third channel
        output[0, 2] = [0, 0]
        target[0, 2] = [0, -1]
        # fourth channel
        output[0, 3] = [40, 40]
        target[0, 3] = [30, 30]
        # fifth channel
        output[0, 4] = [20, 10]
        target[0, 4] = [0, 10]

        gt_instances = InstanceData()
        gt_instances.keypoints = target
        gt_instances.keypoints_visible = np.array(
            [[True, True, False, True, True]])
        pred_instances = InstanceData()
        pred_instances.keypoints = output

        data = {'inputs': None}
        data_sample = {
            'gt_instances': gt_instances.to_dict(),
            'pred_instances': pred_instances.to_dict()
        }

        self.data_batch = [data]
        self.data_samples = [data_sample]
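
    # Per-keypoint prediction errors implied by the setup above:
    # 4, 8, (masked), sqrt(200) ~= 14.14 and 20 pixels. With partitions
    # p1 = keypoints 0-2 and p2 = keypoints 3-4, the EPE targets below are
    # (4 + 8) / 2 = 6.0 and (14.14 + 20) / 2 ~= 17.07, and ~11.54 over all
    # four visible keypoints; the AUC targets follow from the same distances
    # normalized by norm_factor=20 and averaged over num_thrs=4 thresholds.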
    def test_auc_evaluate(self):
        """test AUC evaluation metric."""
        auc_metric = KeypointPartitionMetric(
            metric=dict(type='AUC', norm_factor=20, num_thrs=4),
            partitions=dict(
                p1=range(3),
                p2=range(3, 5),
                all=range(5),
            ))
        auc_metric.process(self.data_batch, self.data_samples)
        auc = auc_metric.evaluate(1)
        target = {'p1/AUC': 0.625, 'p2/AUC': 0.125, 'all/AUC': 0.375}
        self.assertDictEqual(auc, target)

    def test_epe_evaluate(self):
        """test EPE evaluation metric."""
        epe_metric = KeypointPartitionMetric(
            metric=dict(type='EPE', ),
            partitions=dict(
                p1=range(3),
                p2=range(3, 5),
                all=range(5),
            ))
        epe_metric.process(self.data_batch, self.data_samples)
        epe = epe_metric.evaluate(1)
        target = {
            'p1/EPE': 6.0,
            'p2/EPE': 17.071067810058594,
            'all/EPE': 11.535533905029297
        }
        self.assertDictEqual(epe, target)


class TestKeypointPartitionMetricWrappingNME(TestCase):

    def setUp(self):
        """Set up some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        self.batch_size = 4
        num_keypoints = 19
        self.data_batch = []
        self.data_samples = []

        for i in range(self.batch_size):
            gt_instances = InstanceData()
            keypoints = np.zeros((1, num_keypoints, 2))
            for j in range(num_keypoints):
                keypoints[0, j] = [0.5 * i * j, 0.5 * i * j]
            gt_instances.keypoints = keypoints
            gt_instances.keypoints_visible = np.ones(
                (1, num_keypoints, 1)).astype(bool)
            gt_instances.keypoints_visible[0, (2 * i) % self.batch_size,
                                           0] = False
            gt_instances['box_size'] = np.array([[0.1]]) * 10 * i

            pred_instances = InstanceData()
            # fake predictions
            _keypoints = np.zeros_like(keypoints)
            _keypoints[0, :, 0] = keypoints[0, :, 0] * 0.95
            _keypoints[0, :, 1] = keypoints[0, :, 1] * 1.05
            pred_instances.keypoints = _keypoints

            data = {'inputs': None}
            data_sample = {
                'gt_instances': gt_instances.to_dict(),
                'pred_instances': pred_instances.to_dict(),
            }

            self.data_batch.append(data)
            self.data_samples.append(data_sample)

    def test_init(self):
        # test norm_mode arg missing
        with self.assertRaisesRegex(AssertionError, 'Missing norm_mode'):
            _ = KeypointPartitionMetric(
                metric=dict(type='NME', ), partitions=dict(all=range(133)))

        # test norm_mode = keypoint_distance
        with self.assertRaisesRegex(ValueError,
                                    "NME norm_mode 'keypoint_distance'"):
            _ = KeypointPartitionMetric(
                metric=dict(type='NME', norm_mode='keypoint_distance'),
                partitions=dict(all=range(133)))

    def test_nme_evaluate(self):
        """test NME evaluation metric."""
        # test when norm_mode = 'use_norm_item'
        # test norm_item = 'box_size' like in `AFLWDataset`
        nme_metric = KeypointPartitionMetric(
            metric=dict(
                type='NME', norm_mode='use_norm_item', norm_item='box_size'),
            partitions=dict(
                p1=range(10),
                p2=range(10, 19),
                all=range(19),
            ))
        nme_metric.process(self.data_batch, self.data_samples)
        nme = nme_metric.evaluate(4)
        target = {
            'p1/NME': 0.1715388651247378,
            'p2/NME': 0.4949747721354167,
            'all/NME': 0.333256827460395
        }
        self.assertDictEqual(nme, target)