# keypoints2coco_without_mmdet.py

# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
from argparse import ArgumentParser

from mmcv import track_iter_progress
from PIL import Image
from xtcocotools.coco import COCO

from mmpose.apis import inference_top_down_pose_model, init_pose_model


  9. def main():
  10. """Visualize the demo images.
  11. pose_keypoints require the json_file containing boxes.
  12. """
  13. parser = ArgumentParser()
  14. parser.add_argument('pose_config', help='Config file for detection')
  15. parser.add_argument('pose_checkpoint', help='Checkpoint file')
  16. parser.add_argument('--img-root', type=str, default='', help='Image root')
  17. parser.add_argument(
  18. '--json-file',
  19. type=str,
  20. default='',
  21. help='Json file containing image person bboxes in COCO format.')
  22. parser.add_argument(
  23. '--out-json-file',
  24. type=str,
  25. default='',
  26. help='Output json contains pseudolabeled annotation')
  27. parser.add_argument(
  28. '--show',
  29. action='store_true',
  30. default=False,
  31. help='whether to show img')
  32. parser.add_argument(
  33. '--device', default='cuda:0', help='Device used for inference')
  34. parser.add_argument(
  35. '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
  36. args = parser.parse_args()
  37. coco = COCO(args.json_file)
  38. # build the pose model from a config file and a checkpoint file
  39. pose_model = init_pose_model(
  40. args.pose_config, args.pose_checkpoint, device=args.device.lower())
  41. dataset = pose_model.cfg.data['test']['type']
  42. img_keys = list(coco.imgs.keys())
  43. # optional
  44. return_heatmap = False
  45. # e.g. use ('backbone', ) to return backbone feature
  46. output_layer_names = None
  47. categories = [{'id': 1, 'name': 'person'}]
  48. img_anno_dict = {'images': [], 'annotations': [], 'categories': categories}
  49. # process each image
  50. ann_uniq_id = int(0)
  51. for i in track_iter_progress(range(len(img_keys))):
  52. # get bounding box annotations
  53. image_id = img_keys[i]
  54. image = coco.loadImgs(image_id)[0]
  55. image_name = os.path.join(args.img_root, image['file_name'])
  56. width, height = Image.open(image_name).size
  57. ann_ids = coco.getAnnIds(image_id)
  58. # make person bounding boxes
  59. person_results = []
  60. for ann_id in ann_ids:
  61. person = {}
  62. ann = coco.anns[ann_id]
  63. # bbox format is 'xywh'
  64. person['bbox'] = ann['bbox']
  65. person_results.append(person)
  66. pose_results, returned_outputs = inference_top_down_pose_model(
  67. pose_model,
  68. image_name,
  69. person_results,
  70. bbox_thr=None,
  71. format='xywh',
  72. dataset=dataset,
  73. return_heatmap=return_heatmap,
  74. outputs=output_layer_names)
  75. # add output of model and bboxes to dict
  76. for indx, i in enumerate(pose_results):
  77. pose_results[indx]['keypoints'][
  78. pose_results[indx]['keypoints'][:, 2] < args.kpt_thr, :3] = 0
  79. pose_results[indx]['keypoints'][
  80. pose_results[indx]['keypoints'][:, 2] >= args.kpt_thr, 2] = 2
  81. x = int(pose_results[indx]['bbox'][0])
  82. y = int(pose_results[indx]['bbox'][1])
  83. w = int(pose_results[indx]['bbox'][2] -
  84. pose_results[indx]['bbox'][0])
  85. h = int(pose_results[indx]['bbox'][3] -
  86. pose_results[indx]['bbox'][1])
  87. bbox = [x, y, w, h]
  88. area = round((w * h), 0)
  89. images = {
  90. 'file_name': image_name.split('/')[-1],
  91. 'height': height,
  92. 'width': width,
  93. 'id': int(image_id)
  94. }
  95. annotations = {
  96. 'keypoints': [
  97. int(i) for i in pose_results[indx]['keypoints'].reshape(
  98. -1).tolist()
  99. ],
  100. 'num_keypoints':
  101. len(pose_results[indx]['keypoints']),
  102. 'area':
  103. area,
  104. 'iscrowd':
  105. 0,
  106. 'image_id':
  107. int(image_id),
  108. 'bbox':
  109. bbox,
  110. 'category_id':
  111. 1,
  112. 'id':
  113. ann_uniq_id,
  114. }
  115. img_anno_dict['annotations'].append(annotations)
  116. ann_uniq_id += 1
  117. img_anno_dict['images'].append(images)
  118. # create json
  119. with open(args.out_json_file, 'w') as outfile:
  120. json.dump(img_anno_dict, outfile, indent=2)
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()