# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from collections import defaultdict
from unittest import TestCase

import numpy as np
from mmengine.fileio import dump, load

from mmpose.datasets.datasets.utils import parse_pose_metainfo
from mmpose.evaluation.metrics import PoseTrack18Metric
  10. class TestPoseTrack18Metric(TestCase):
  11. def setUp(self):
  12. """Setup some variables which are used in every test method.
  13. TestCase calls functions in this order: setUp() -> testMethod() ->
  14. tearDown() -> cleanUp()
  15. """
  16. self.tmp_dir = tempfile.TemporaryDirectory()
  17. self.ann_file = 'tests/data/posetrack18/annotations/'\
  18. 'test_posetrack18_val.json'
  19. posetrack18_meta_info = dict(
  20. from_file='configs/_base_/datasets/posetrack18.py')
  21. self.posetrack18_dataset_meta = parse_pose_metainfo(
  22. posetrack18_meta_info)
  23. self.db = load(self.ann_file)
  24. self.topdown_data = self._convert_ann_to_topdown_batch_data()
  25. assert len(self.topdown_data) == 14
  26. self.bottomup_data = self._convert_ann_to_bottomup_batch_data()
  27. assert len(self.bottomup_data) == 3
  28. self.target = {
  29. 'posetrack18/Head AP': 100.0,
  30. 'posetrack18/Shou AP': 100.0,
  31. 'posetrack18/Elb AP': 100.0,
  32. 'posetrack18/Wri AP': 100.0,
  33. 'posetrack18/Hip AP': 100.0,
  34. 'posetrack18/Knee AP': 100.0,
  35. 'posetrack18/Ankl AP': 100.0,
  36. 'posetrack18/AP': 100.0,
  37. }
  38. def _convert_ann_to_topdown_batch_data(self):
  39. """Convert annotations to topdown-style batch data."""
  40. topdown_data = []
  41. for ann in self.db['annotations']:
  42. w, h = ann['bbox'][2], ann['bbox'][3]
  43. bboxes = np.array(ann['bbox'], dtype=np.float32).reshape(-1, 4)
  44. bbox_scales = np.array([w * 1.25, h * 1.25]).reshape(-1, 2)
  45. keypoints = np.array(ann['keypoints']).reshape((1, -1, 3))
  46. gt_instances = {
  47. 'bbox_scales': bbox_scales,
  48. 'bboxes': bboxes,
  49. 'bbox_scores': np.ones((1, ), dtype=np.float32),
  50. }
  51. pred_instances = {
  52. 'keypoints': keypoints[..., :2],
  53. 'keypoint_scores': keypoints[..., -1],
  54. }
  55. data = {'inputs': None}
  56. data_sample = {
  57. 'id': ann['id'],
  58. 'img_id': ann['image_id'],
  59. 'gt_instances': gt_instances,
  60. 'pred_instances': pred_instances
  61. }
  62. # batch size = 1
  63. data_batch = [data]
  64. data_samples = [data_sample]
  65. topdown_data.append((data_batch, data_samples))
  66. return topdown_data
  67. def _convert_ann_to_bottomup_batch_data(self):
  68. """Convert annotations to bottomup-style batch data."""
  69. img2ann = defaultdict(list)
  70. for ann in self.db['annotations']:
  71. img2ann[ann['image_id']].append(ann)
  72. bottomup_data = []
  73. for img_id, anns in img2ann.items():
  74. keypoints = np.array([ann['keypoints'] for ann in anns]).reshape(
  75. (len(anns), -1, 3))
  76. gt_instances = {
  77. 'bbox_scores': np.ones((len(anns)), dtype=np.float32)
  78. }
  79. pred_instances = {
  80. 'keypoints': keypoints[..., :2],
  81. 'keypoint_scores': keypoints[..., -1],
  82. }
  83. data = {'inputs': None}
  84. data_sample = {
  85. 'id': [ann['id'] for ann in anns],
  86. 'img_id': img_id,
  87. 'gt_instances': gt_instances,
  88. 'pred_instances': pred_instances
  89. }
  90. # batch size = 1
  91. data_batch = [data]
  92. data_samples = [data_sample]
  93. bottomup_data.append((data_batch, data_samples))
  94. return bottomup_data
  95. def tearDown(self):
  96. self.tmp_dir.cleanup()
  97. def test_init(self):
  98. """test metric init method."""
  99. # test score_mode option
  100. with self.assertRaisesRegex(ValueError,
  101. '`score_mode` should be one of'):
  102. _ = PoseTrack18Metric(ann_file=self.ann_file, score_mode='invalid')
  103. # test nms_mode option
  104. with self.assertRaisesRegex(ValueError, '`nms_mode` should be one of'):
  105. _ = PoseTrack18Metric(ann_file=self.ann_file, nms_mode='invalid')
  106. # test `format_only` option
  107. with self.assertRaisesRegex(
  108. AssertionError,
  109. '`outfile_prefix` can not be None when `format_only` is True'):
  110. _ = PoseTrack18Metric(
  111. ann_file=self.ann_file, format_only=True, outfile_prefix=None)
  112. def test_topdown_evaluate(self):
  113. """test topdown-style posetrack18 metric evaluation."""
  114. # case 1: score_mode='bbox', nms_mode='none'
  115. posetrack18_metric = PoseTrack18Metric(
  116. ann_file=self.ann_file,
  117. outfile_prefix=f'{self.tmp_dir.name}/test',
  118. score_mode='bbox',
  119. nms_mode='none')
  120. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  121. # process samples
  122. for data_batch, data_samples in self.topdown_data:
  123. posetrack18_metric.process(data_batch, data_samples)
  124. eval_results = posetrack18_metric.evaluate(size=len(self.topdown_data))
  125. self.assertDictEqual(eval_results, self.target)
  126. self.assertTrue(
  127. osp.isfile(osp.join(self.tmp_dir.name, '003418_mpii_test.json')))
  128. # case 2: score_mode='bbox_keypoint', nms_mode='oks_nms'
  129. posetrack18_metric = PoseTrack18Metric(
  130. ann_file=self.ann_file,
  131. outfile_prefix=f'{self.tmp_dir.name}/test',
  132. score_mode='bbox_keypoint',
  133. nms_mode='oks_nms')
  134. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  135. # process samples
  136. for data_batch, data_samples in self.topdown_data:
  137. posetrack18_metric.process(data_batch, data_samples)
  138. eval_results = posetrack18_metric.evaluate(size=len(self.topdown_data))
  139. self.assertDictEqual(eval_results, self.target)
  140. self.assertTrue(
  141. osp.isfile(osp.join(self.tmp_dir.name, '009473_mpii_test.json')))
  142. # case 3: score_mode='bbox_keypoint', nms_mode='soft_oks_nms'
  143. posetrack18_metric = PoseTrack18Metric(
  144. ann_file=self.ann_file,
  145. outfile_prefix=f'{self.tmp_dir.name}/test',
  146. score_mode='bbox_keypoint',
  147. nms_mode='soft_oks_nms')
  148. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  149. # process samples
  150. for data_batch, data_samples in self.topdown_data:
  151. posetrack18_metric.process(data_batch, data_samples)
  152. eval_results = posetrack18_metric.evaluate(size=len(self.topdown_data))
  153. self.assertDictEqual(eval_results, self.target)
  154. self.assertTrue(
  155. osp.isfile(osp.join(self.tmp_dir.name, '012834_mpii_test.json')))
  156. def test_bottomup_evaluate(self):
  157. """test bottomup-style posetrack18 metric evaluation."""
  158. # case 1: score_mode='bbox', nms_mode='none'
  159. posetrack18_metric = PoseTrack18Metric(
  160. ann_file=self.ann_file, outfile_prefix=f'{self.tmp_dir.name}/test')
  161. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  162. # process samples
  163. for data_batch, data_samples in self.bottomup_data:
  164. posetrack18_metric.process(data_batch, data_samples)
  165. eval_results = posetrack18_metric.evaluate(
  166. size=len(self.bottomup_data))
  167. self.assertDictEqual(eval_results, self.target)
  168. self.assertTrue(
  169. osp.isfile(osp.join(self.tmp_dir.name, '009473_mpii_test.json')))
  170. def test_other_methods(self):
  171. """test other useful methods."""
  172. # test `_sort_and_unique_bboxes` method
  173. posetrack18_metric = PoseTrack18Metric(ann_file=self.ann_file)
  174. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  175. # process samples
  176. for data_batch, data_samples in self.topdown_data:
  177. posetrack18_metric.process(data_batch, data_samples)
  178. # process one extra sample
  179. data_batch, data_samples = self.topdown_data[0]
  180. posetrack18_metric.process(data_batch, data_samples)
  181. # an extra sample
  182. eval_results = posetrack18_metric.evaluate(
  183. size=len(self.topdown_data) + 1)
  184. self.assertDictEqual(eval_results, self.target)
  185. def test_format_only(self):
  186. """test `format_only` option."""
  187. posetrack18_metric = PoseTrack18Metric(
  188. ann_file=self.ann_file,
  189. format_only=True,
  190. outfile_prefix=f'{self.tmp_dir.name}/test')
  191. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  192. # process samples
  193. for data_batch, data_samples in self.topdown_data:
  194. posetrack18_metric.process(data_batch, data_samples)
  195. eval_results = posetrack18_metric.evaluate(size=len(self.topdown_data))
  196. self.assertDictEqual(eval_results, {})
  197. self.assertTrue(
  198. osp.isfile(osp.join(self.tmp_dir.name, '012834_mpii_test.json')))
  199. # test when gt annotations are absent
  200. db_ = load(self.ann_file)
  201. del db_['annotations']
  202. tmp_ann_file = osp.join(self.tmp_dir.name, 'temp_ann.json')
  203. dump(db_, tmp_ann_file, sort_keys=True, indent=4)
  204. with self.assertRaisesRegex(
  205. AssertionError,
  206. 'Ground truth annotations are required for evaluation'):
  207. _ = PoseTrack18Metric(ann_file=tmp_ann_file, format_only=False)
  208. def test_topdown_alignment(self):
  209. """Test whether the output of PoseTrack18Metric and the original
  210. TopDownPoseTrack18Dataset are the same."""
  211. self.skipTest('Skip test.')
  212. topdown_data = []
  213. for ann in self.db['annotations']:
  214. w, h = ann['bbox'][2], ann['bbox'][3]
  215. bboxes = np.array(ann['bbox'], dtype=np.float32).reshape(-1, 4)
  216. bbox_scales = np.array([w * 1.25, h * 1.25]).reshape(-1, 2)
  217. keypoints = np.array(
  218. ann['keypoints'], dtype=np.float32).reshape(1, 17, 3)
  219. keypoints[..., 0] = keypoints[..., 0] * 0.98
  220. keypoints[..., 1] = keypoints[..., 1] * 1.02
  221. keypoints[..., 2] = keypoints[..., 2] * 0.8
  222. gt_instances = {
  223. 'bbox_scales': bbox_scales,
  224. 'bbox_scores': np.ones((1, ), dtype=np.float32) * 0.98,
  225. 'bboxes': bboxes,
  226. }
  227. pred_instances = {
  228. 'keypoints': keypoints[..., :2],
  229. 'keypoint_scores': keypoints[..., -1],
  230. }
  231. data = {'inputs': None}
  232. data_sample = {
  233. 'id': ann['id'],
  234. 'img_id': ann['image_id'],
  235. 'gt_instances': gt_instances,
  236. 'pred_instances': pred_instances
  237. }
  238. # batch size = 1
  239. data_batch = [data]
  240. data_samples = [data_sample]
  241. topdown_data.append((data_batch, data_samples))
  242. # case 1:
  243. # typical setting: score_mode='bbox_keypoint', nms_mode='oks_nms'
  244. posetrack18_metric = PoseTrack18Metric(
  245. ann_file=self.ann_file,
  246. outfile_prefix=f'{self.tmp_dir.name}/test',
  247. score_mode='bbox_keypoint',
  248. nms_mode='oks_nms')
  249. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  250. # process samples
  251. for data_batch, data_samples in topdown_data:
  252. posetrack18_metric.process(data_batch, data_samples)
  253. eval_results = posetrack18_metric.evaluate(size=len(topdown_data))
  254. target = {
  255. 'posetrack18/Head AP': 84.6677132391418,
  256. 'posetrack18/Shou AP': 80.86734693877551,
  257. 'posetrack18/Elb AP': 83.0204081632653,
  258. 'posetrack18/Wri AP': 85.12396694214877,
  259. 'posetrack18/Hip AP': 75.14792899408285,
  260. 'posetrack18/Knee AP': 66.76515151515152,
  261. 'posetrack18/Ankl AP': 71.78571428571428,
  262. 'posetrack18/Total AP': 78.62827822638012,
  263. }
  264. for key in eval_results.keys():
  265. self.assertAlmostEqual(eval_results[key], target[key])
  266. self.assertTrue(
  267. osp.isfile(osp.join(self.tmp_dir.name, '012834_mpii_test.json')))
  268. topdown_data = []
  269. anns = self.db['annotations']
  270. for i, ann in enumerate(anns):
  271. w, h = ann['bbox'][2], ann['bbox'][3]
  272. bboxes = np.array(ann['bbox'], dtype=np.float32).reshape(-1, 4)
  273. bbox_scales = np.array([w * 1.25, h * 1.25]).reshape(-1, 2)
  274. keypoints = np.array(
  275. ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
  276. keypoints[..., 0] = keypoints[..., 0] * (1 - i / 100)
  277. keypoints[..., 1] = keypoints[..., 1] * (1 + i / 100)
  278. keypoints[..., 2] = keypoints[..., 2] * (1 - i / 100)
  279. gt_instances0 = {
  280. 'bbox_scales': bbox_scales,
  281. 'bbox_scores': np.ones((1, ), dtype=np.float32),
  282. 'bboxes': bboxes,
  283. }
  284. pred_instances0 = {
  285. 'keypoints': keypoints[..., :2],
  286. 'keypoint_scores': keypoints[..., -1],
  287. }
  288. data0 = {'inputs': None}
  289. data_sample0 = {
  290. 'id': ann['id'],
  291. 'img_id': ann['image_id'],
  292. 'gt_instances': gt_instances0,
  293. 'pred_instances': pred_instances0
  294. }
  295. keypoints = np.array(
  296. ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
  297. keypoints[..., 0] = keypoints[..., 0] * (1 + i / 100)
  298. keypoints[..., 1] = keypoints[..., 1] * (1 - i / 100)
  299. keypoints[..., 2] = keypoints[..., 2] * (1 - 2 * i / 100)
  300. gt_instances1 = {
  301. 'bbox_scales': bbox_scales,
  302. 'bboxes': bboxes,
  303. 'bbox_scores': np.ones(
  304. (1, ), dtype=np.float32) * (1 - 2 * i / 100)
  305. }
  306. pred_instances1 = {
  307. 'keypoints': keypoints[..., :2],
  308. 'keypoint_scores': keypoints[..., -1],
  309. }
  310. data1 = {'inputs': None}
  311. data_sample1 = {
  312. 'id': ann['id'] + 1,
  313. 'img_id': ann['image_id'],
  314. 'gt_instances': gt_instances1,
  315. 'pred_instances': pred_instances1
  316. }
  317. # batch size = 2
  318. data_batch = [data0, data1]
  319. data_samples = [data_sample0, data_sample1]
  320. topdown_data.append((data_batch, data_samples))
  321. # case 3: score_mode='bbox_keypoint', nms_mode='soft_oks_nms'
  322. posetrack18_metric = PoseTrack18Metric(
  323. ann_file=self.ann_file,
  324. outfile_prefix=f'{self.tmp_dir.name}/test',
  325. score_mode='bbox_keypoint',
  326. keypoint_score_thr=0.2,
  327. nms_thr=0.9,
  328. nms_mode='soft_oks_nms')
  329. posetrack18_metric.dataset_meta = self.posetrack18_dataset_meta
  330. # process samples
  331. for data_batch, data_samples in topdown_data:
  332. posetrack18_metric.process(data_batch, data_samples)
  333. eval_results = posetrack18_metric.evaluate(size=len(topdown_data) * 2)
  334. target = {
  335. 'posetrack18/Head AP': 27.1062271062271068,
  336. 'posetrack18/Shou AP': 25.918367346938776,
  337. 'posetrack18/Elb AP': 22.67857142857143,
  338. 'posetrack18/Wri AP': 29.090909090909093,
  339. 'posetrack18/Hip AP': 18.40659340659341,
  340. 'posetrack18/Knee AP': 32.0,
  341. 'posetrack18/Ankl AP': 20.0,
  342. 'posetrack18/Total AP': 25.167170924313783,
  343. }
  344. for key in eval_results.keys():
  345. self.assertAlmostEqual(eval_results[key], target[key])
  346. self.assertTrue(
  347. osp.isfile(osp.join(self.tmp_dir.name, '009473_mpii_test.json')))