# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
from mmengine.structures import InstanceData

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.evaluation.metrics import (AUC, EPE, NME, JhmdbPCKAccuracy,
                                       MpiiPCKAccuracy, PCKAccuracy)


class TestPCKAccuracy(TestCase):

    def setUp(self):
        """Setup some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        self.batch_size = 8
        num_keypoints = 15
        self.data_batch = []
        self.data_samples = []

        for i in range(self.batch_size):
            gt_instances = InstanceData()
            keypoints = np.zeros((1, num_keypoints, 2))
            keypoints[0, i] = [0.5 * i, 0.5 * i]
            gt_instances.keypoints = keypoints
            gt_instances.keypoints_visible = np.ones(
                (1, num_keypoints, 1)).astype(bool)
            gt_instances.keypoints_visible[0, (2 * i) % 8, 0] = False
            gt_instances.bboxes = np.random.random((1, 4)) * 20 * i
            gt_instances.head_size = np.random.random((1, 1)) * 10 * i
            # predictions equal the ground truth, so every accuracy
            # checked below is expected to be perfect
            pred_instances = InstanceData()
            pred_instances.keypoints = keypoints

            data = {'inputs': None}
            data_sample = {
                'gt_instances': gt_instances.to_dict(),
                'pred_instances': pred_instances.to_dict(),
            }

            self.data_batch.append(data)
            self.data_samples.append(data_sample)

    def test_init(self):
        """test metric init method."""
        # test invalid normalized_items
        with self.assertRaisesRegex(
                KeyError, "Should be one of 'bbox', 'head', 'torso'"):
            PCKAccuracy(norm_item='invalid')

    def test_evaluate(self):
        """test PCK accuracy evaluation metric."""
        # test normalized by 'bbox'
        pck_metric = PCKAccuracy(thr=0.5, norm_item='bbox')
        pck_metric.process(self.data_batch, self.data_samples)
        pck = pck_metric.evaluate(self.batch_size)
        target = {'PCK': 1.0}
        self.assertDictEqual(pck, target)

        # test normalized by 'head'
        pckh_metric = PCKAccuracy(thr=0.3, norm_item='head')
        pckh_metric.process(self.data_batch, self.data_samples)
        pckh = pckh_metric.evaluate(self.batch_size)
        target = {'PCKh': 1.0}
        self.assertDictEqual(pckh, target)

        # test normalized by both 'bbox' and 'torso'
        tpck_metric = PCKAccuracy(thr=0.05, norm_item=['bbox', 'torso'])
        tpck_metric.process(self.data_batch, self.data_samples)
        tpck = tpck_metric.evaluate(self.batch_size)
        self.assertIsInstance(tpck, dict)
        target = {
            'PCK': 1.0,
            'tPCK': 1.0,
        }
        self.assertDictEqual(tpck, target)
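

# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the PCK computation exercised above
# (illustrative only; `_naive_pck` and its signature are not part of
# mmpose). A keypoint counts as correct when its distance to the ground
# truth, divided by a normalization factor (bbox size for 'bbox', head size
# for 'head', torso size for 'torso'), falls below `thr`.
# ---------------------------------------------------------------------------
def _naive_pck(pred, gt, mask, norm_factor, thr=0.5):
    """PCK sketch: pred/gt are (N, K, 2), mask is (N, K) bool,
    norm_factor is (N, 2)."""
    dist = np.linalg.norm((pred - gt) / norm_factor[:, None, :], axis=-1)
    return (dist[mask] < thr).mean()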


class TestMpiiPCKAccuracy(TestCase):

    def setUp(self):
        """Setup some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        self.batch_size = 8
        num_keypoints = 16
        self.data_batch = []
        self.data_samples = []

        for i in range(self.batch_size):
            gt_instances = InstanceData()
            keypoints = np.zeros((1, num_keypoints, 2))
            keypoints[0, i] = [0.5 * i, 0.5 * i]
            # MPII annotations use the 1-based Matlab convention, so the
            # ground truth is shifted by +1.0 while the 0-based predictions
            # are left unchanged; the metric is expected to bridge the gap
            gt_instances.keypoints = keypoints + 1.0
            gt_instances.keypoints_visible = np.ones(
                (1, num_keypoints, 1)).astype(bool)
            gt_instances.keypoints_visible[0, (2 * i) % 8, 0] = False
            gt_instances.bboxes = np.random.random((1, 4)) * 20 * i
            gt_instances.head_size = np.random.random((1, 1)) * 10 * i
            pred_instances = InstanceData()
            pred_instances.keypoints = keypoints

            data = {'inputs': None}
            data_sample = {
                'gt_instances': gt_instances.to_dict(),
                'pred_instances': pred_instances.to_dict(),
            }

            self.data_batch.append(data)
            self.data_samples.append(data_sample)

    def test_init(self):
        """test metric init method."""
        # test invalid normalized_items
        with self.assertRaisesRegex(
                KeyError, "Should be one of 'bbox', 'head', 'torso'"):
            MpiiPCKAccuracy(norm_item='invalid')

    def test_evaluate(self):
        """test MPII PCK accuracy evaluation metric."""
        # test normalized by 'head'
        mpii_pck_metric = MpiiPCKAccuracy(thr=0.3, norm_item='head')
        mpii_pck_metric.process(self.data_batch, self.data_samples)
        pck_results = mpii_pck_metric.evaluate(self.batch_size)
        # MpiiPCKAccuracy reports percentages rather than ratios
        target = {
            'Head PCK': 100.0,
            'Shoulder PCK': 100.0,
            'Elbow PCK': 100.0,
            'Wrist PCK': 100.0,
            'Hip PCK': 100.0,
            'Knee PCK': 100.0,
            'Ankle PCK': 100.0,
            'PCK': 100.0,
            'PCK@0.1': 100.0,
        }
        self.assertDictEqual(pck_results, target)


class TestJhmdbPCKAccuracy(TestCase):

    def setUp(self):
        """Setup some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        self.batch_size = 8
        num_keypoints = 15
        self.data_batch = []
        self.data_samples = []

        for i in range(self.batch_size):
            gt_instances = InstanceData()
            keypoints = np.zeros((1, num_keypoints, 2))
            keypoints[0, i] = [0.5 * i, 0.5 * i]
            gt_instances.keypoints = keypoints
            gt_instances.keypoints_visible = np.ones(
                (1, num_keypoints, 1)).astype(bool)
            gt_instances.keypoints_visible[0, (2 * i) % 8, 0] = False
            gt_instances.bboxes = np.random.random((1, 4)) * 20 * i
            gt_instances.head_size = np.random.random((1, 1)) * 10 * i
            pred_instances = InstanceData()
            pred_instances.keypoints = keypoints

            data = {'inputs': None}
            data_sample = {
                'gt_instances': gt_instances.to_dict(),
                'pred_instances': pred_instances.to_dict(),
            }

            self.data_batch.append(data)
            self.data_samples.append(data_sample)

    def test_init(self):
        """test metric init method."""
        # test invalid normalized_items
        with self.assertRaisesRegex(
                KeyError, "Should be one of 'bbox', 'head', 'torso'"):
            JhmdbPCKAccuracy(norm_item='invalid')

    def test_evaluate(self):
        """test JHMDB PCK accuracy evaluation metric."""
        # test normalized by 'bbox'
        jhmdb_pck_metric = JhmdbPCKAccuracy(thr=0.5, norm_item='bbox')
        jhmdb_pck_metric.process(self.data_batch, self.data_samples)
        pck_results = jhmdb_pck_metric.evaluate(self.batch_size)
        target = {
            'Head PCK': 1.0,
            'Sho PCK': 1.0,
            'Elb PCK': 1.0,
            'Wri PCK': 1.0,
            'Hip PCK': 1.0,
            'Knee PCK': 1.0,
            'Ank PCK': 1.0,
            'PCK': 1.0,
        }
        self.assertDictEqual(pck_results, target)

        # test normalized by 'torso'
        jhmdb_tpck_metric = JhmdbPCKAccuracy(thr=0.2, norm_item='torso')
        jhmdb_tpck_metric.process(self.data_batch, self.data_samples)
        tpck_results = jhmdb_tpck_metric.evaluate(self.batch_size)
        target = {
            'Head tPCK': 1.0,
            'Sho tPCK': 1.0,
            'Elb tPCK': 1.0,
            'Wri tPCK': 1.0,
            'Hip tPCK': 1.0,
            'Knee tPCK': 1.0,
            'Ank tPCK': 1.0,
            'tPCK': 1.0,
        }
        self.assertDictEqual(tpck_results, target)


class TestAUCandEPE(TestCase):

    def setUp(self):
        """Setup some variables which are used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """
        output = np.zeros((1, 5, 2))
        target = np.zeros((1, 5, 2))
        # first keypoint
        output[0, 0] = [10, 4]
        target[0, 0] = [10, 0]
        # second keypoint
        output[0, 1] = [10, 18]
        target[0, 1] = [10, 10]
        # third keypoint (masked as invisible below)
        output[0, 2] = [0, 0]
        target[0, 2] = [0, -1]
        # fourth keypoint
        output[0, 3] = [40, 40]
        target[0, 3] = [30, 30]
        # fifth keypoint
        output[0, 4] = [20, 10]
        target[0, 4] = [0, 10]

        gt_instances = InstanceData()
        gt_instances.keypoints = target
        gt_instances.keypoints_visible = np.array(
            [[True, True, False, True, True]])

        pred_instances = InstanceData()
        pred_instances.keypoints = output

        data = {'inputs': None}
        data_sample = {
            'gt_instances': gt_instances.to_dict(),
            'pred_instances': pred_instances.to_dict()
        }

        self.data_batch = [data]
        self.data_samples = [data_sample]

    def test_auc_evaluate(self):
        """test AUC evaluation metric."""
        auc_metric = AUC(norm_factor=20, num_thrs=4)
        auc_metric.process(self.data_batch, self.data_samples)
        auc = auc_metric.evaluate(1)
        target = {'AUC': 0.375}
        self.assertDictEqual(auc, target)

    def test_epe_evaluate(self):
        """test EPE evaluation metric."""
        epe_metric = EPE()
        epe_metric.process(self.data_batch, self.data_samples)
        epe = epe_metric.evaluate(1)
        self.assertAlmostEqual(epe['EPE'], 11.5355339)
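

# ---------------------------------------------------------------------------
# Hand derivation of the expected values above. The visible prediction/GT
# distances are 4, 8, sqrt(200) ~ 14.142 and 20 (keypoint 2 is masked):
#   EPE = (4 + 8 + 14.142 + 20) / 4 ~ 11.5355339
#   AUC: dividing by norm_factor=20 gives [0.2, 0.4, 0.707, 1.0]; the PCK
#   at the 4 thresholds [0, 0.25, 0.5, 0.75] is [0, 0.25, 0.5, 0.75],
#   whose mean is 0.375.
# A runnable sketch (`_naive_auc` is not a mmpose API; the even threshold
# grid over [0, 1) is an assumption that reproduces the expected value):
# ---------------------------------------------------------------------------
def _naive_auc(dists, norm, num_thrs):
    """AUC sketch: mean PCK over `num_thrs` thresholds spaced over [0, 1)."""
    nd = np.asarray(dists) / norm
    return np.mean([(nd < t).mean() for t in np.arange(num_thrs) / num_thrs])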


class TestNME(TestCase):

    def _generate_data(self,
                       batch_size: int = 1,
                       num_keypoints: int = 5,
                       norm_item: str = 'box_size') -> tuple:
        """Generate data_batch and data_samples according to different
        settings."""
        data_batch = []
        data_samples = []

        for i in range(batch_size):
            gt_instances = InstanceData()
            keypoints = np.zeros((1, num_keypoints, 2))
            keypoints[0, i] = [0.5 * i, 0.5 * i]
            gt_instances.keypoints = keypoints
            gt_instances.keypoints_visible = np.ones(
                (1, num_keypoints, 1)).astype(bool)
            gt_instances.keypoints_visible[0, (2 * i) % batch_size, 0] = False
            gt_instances[norm_item] = np.random.random((1, 1)) * 20 * i
            pred_instances = InstanceData()
            pred_instances.keypoints = keypoints

            data = {'inputs': None}
            data_sample = {
                'gt_instances': gt_instances.to_dict(),
                'pred_instances': pred_instances.to_dict(),
            }

            data_batch.append(data)
            data_samples.append(data_sample)
        return data_batch, data_samples

    def test_nme_evaluate(self):
        """test NME evaluation metric."""
        # test when norm_mode = 'use_norm_item'
        # test norm_item = 'box_size' like in `AFLWDataset`
        norm_item = 'box_size'
        nme_metric = NME(norm_mode='use_norm_item', norm_item=norm_item)
        aflw_meta_info = dict(from_file='configs/_base_/datasets/aflw.py')
        aflw_dataset_meta = parse_pose_metainfo(aflw_meta_info)
        nme_metric.dataset_meta = aflw_dataset_meta

        data_batch, data_samples = self._generate_data(
            batch_size=4, num_keypoints=19, norm_item=norm_item)
        nme_metric.process(data_batch, data_samples)
        nme = nme_metric.evaluate(4)
        target = {'NME': 0.0}
        self.assertDictEqual(nme, target)

        # test when norm_mode = 'keypoint_distance'
        # when `keypoint_indices = None`,
        # use default `keypoint_indices` like in `Horse10Dataset`
        nme_metric = NME(norm_mode='keypoint_distance')
        horse10_meta_info = dict(
            from_file='configs/_base_/datasets/horse10.py')
        horse10_dataset_meta = parse_pose_metainfo(horse10_meta_info)
        nme_metric.dataset_meta = horse10_dataset_meta

        data_batch, data_samples = self._generate_data(
            batch_size=4, num_keypoints=22)
        nme_metric.process(data_batch, data_samples)
        nme = nme_metric.evaluate(4)
        target = {'NME': 0.0}
        self.assertDictEqual(nme, target)

        # test when norm_mode = 'keypoint_distance'
        # specify custom `keypoint_indices`
        keypoint_indices = [2, 4]
        nme_metric = NME(
            norm_mode='keypoint_distance', keypoint_indices=keypoint_indices)
        coco_meta_info = dict(from_file='configs/_base_/datasets/coco.py')
        coco_dataset_meta = parse_pose_metainfo(coco_meta_info)
        nme_metric.dataset_meta = coco_dataset_meta

        data_batch, data_samples = self._generate_data(
            batch_size=2, num_keypoints=17)
        nme_metric.process(data_batch, data_samples)
        nme = nme_metric.evaluate(2)
        target = {'NME': 0.0}
        self.assertDictEqual(nme, target)
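
    # NOTE: in all three cases above the predictions equal the ground
    # truth, so the normalized mean error is exactly 0 regardless of the
    # normalization used ('box_size', the dataset's default keypoint
    # pair, or custom `keypoint_indices`).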

    def test_exceptions_and_warnings(self):
        """test exceptions and warnings."""
        # test invalid norm_mode
        with self.assertRaisesRegex(
                KeyError,
                "`norm_mode` should be 'use_norm_item' or 'keypoint_distance'"
        ):
            nme_metric = NME(norm_mode='invalid')

        # test when norm_mode = 'use_norm_item' but norm_item is not given
        with self.assertRaisesRegex(
                KeyError, '`norm_mode` is set to `"use_norm_item"`, '
                'please specify the `norm_item`'):
            nme_metric = NME(norm_mode='use_norm_item', norm_item=None)

        # test when norm_mode = 'use_norm_item'
        # but the `norm_item` is not in data_info
        with self.assertRaisesRegex(
                AssertionError,
                'The ground truth data info do not have the expected '
                'normalized factor'):
            nme_metric = NME(norm_mode='use_norm_item', norm_item='norm_item1')
            coco_meta_info = dict(from_file='configs/_base_/datasets/coco.py')
            coco_dataset_meta = parse_pose_metainfo(coco_meta_info)
            nme_metric.dataset_meta = coco_dataset_meta

            data_batch, data_samples = self._generate_data(
                norm_item='norm_item2')
            # raise AssertionError here
            nme_metric.process(data_batch, data_samples)

        # test when norm_mode = 'keypoint_distance', `keypoint_indices` = None
        # but the dataset_name is not in `DEFAULT_KEYPOINT_INDICES`
        with self.assertRaisesRegex(
                KeyError, 'can not find the keypoint_indices in '
                '`DEFAULT_KEYPOINT_INDICES`'):
            nme_metric = NME(
                norm_mode='keypoint_distance', keypoint_indices=None)
            coco_meta_info = dict(from_file='configs/_base_/datasets/coco.py')
            coco_dataset_meta = parse_pose_metainfo(coco_meta_info)
            nme_metric.dataset_meta = coco_dataset_meta

            data_batch, data_samples = self._generate_data()
            nme_metric.process(data_batch, data_samples)
            # raise KeyError here
            _ = nme_metric.evaluate(1)

        # test when len(keypoint_indices) is not 2
        with self.assertRaisesRegex(
                AssertionError,
                'The keypoint indices used for normalization should be a pair.'
        ):
            nme_metric = NME(
                norm_mode='keypoint_distance', keypoint_indices=[0, 1, 2])
            coco_meta_info = dict(from_file='configs/_base_/datasets/coco.py')
            coco_dataset_meta = parse_pose_metainfo(coco_meta_info)
            nme_metric.dataset_meta = coco_dataset_meta

            data_batch, data_samples = self._generate_data()
            nme_metric.process(data_batch, data_samples)
            # raise AssertionError here
            _ = nme_metric.evaluate(1)

        # test when the dataset does not contain the required keypoints
        with self.assertRaisesRegex(AssertionError,
                                    'dataset does not contain the required'):
            nme_metric = NME(
                norm_mode='keypoint_distance', keypoint_indices=[17, 18])
            coco_meta_info = dict(from_file='configs/_base_/datasets/coco.py')
            coco_dataset_meta = parse_pose_metainfo(coco_meta_info)
            nme_metric.dataset_meta = coco_dataset_meta

            data_batch, data_samples = self._generate_data()
            nme_metric.process(data_batch, data_samples)
            # raise AssertionError here
            _ = nme_metric.evaluate(1)
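

# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the NME computation tested above
# (illustrative only; `_naive_nme` is not part of mmpose). The mean
# visible-keypoint error of each sample is divided by a per-sample norm:
# a stored factor such as `box_size` for `norm_mode='use_norm_item'`, or
# the distance between a dataset-specific keypoint pair for
# `norm_mode='keypoint_distance'`.
# ---------------------------------------------------------------------------
def _naive_nme(pred, gt, mask, norm):
    """NME sketch: pred/gt are (N, K, 2), mask is (N, K) bool, norm is (N,)."""
    dist = np.linalg.norm(pred - gt, axis=-1)  # (N, K) keypoint errors
    per_sample = np.array([d[m].mean() for d, m in zip(dist, mask)])
    return (per_sample / norm).mean()


# Allow running this file directly; upstream typically runs it through
# pytest (e.g. `pytest tests/test_evaluation/test_metrics/
# test_keypoint_2d_metrics.py` -- the exact path is an assumption).
if __name__ == '__main__':
    import unittest
    unittest.main()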