crowdhuman.py

# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
import os.path as osp
import warnings
from typing import List, Union

import mmcv
from mmengine.dist import get_rank
from mmengine.fileio import dump, get, get_text, load
from mmengine.logging import print_log
from mmengine.utils import ProgressBar

from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset


@DATASETS.register_module()
class CrowdHumanDataset(BaseDetDataset):
    r"""Dataset for CrowdHuman.

    Args:
        data_root (str): The root directory for
            ``data_prefix`` and ``ann_file``.
        ann_file (str): Annotation file path.
        extra_ann_file (str, optional): The path of extra image metas
            for CrowdHuman. It can be created by CrowdHumanDataset
            automatically or by tools/misc/get_crowdhuman_id_hw.py
            manually. Defaults to None.
    """
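
    # A minimal usage sketch; the directory layout, file names and pipeline
    # below are assumptions, adjust them to your local CrowdHuman setup.
    # In a config the dataset is typically built from a dict, e.g.:
    #
    #   train_dataset = dict(
    #       type='CrowdHumanDataset',
    #       data_root='data/CrowdHuman/',
    #       ann_file='annotation_train.odgt',
    #       data_prefix=dict(img='Images/'),
    #       pipeline=[...])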
    METAINFO = {
        'classes': ('person', ),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(220, 20, 60)]
    }

    def __init__(self, data_root, ann_file, extra_ann_file=None, **kwargs):
        # extra_ann_file records the size of each image. This file is
        # automatically created when you first load the CrowdHuman
        # dataset by mmdet.
        if extra_ann_file is not None:
            self.extra_ann_exist = True
            self.extra_anns = load(extra_ann_file)
        else:
            ann_file_name = osp.basename(ann_file)
            if 'train' in ann_file_name:
                self.extra_ann_file = osp.join(data_root, 'id_hw_train.json')
            elif 'val' in ann_file_name:
                self.extra_ann_file = osp.join(data_root, 'id_hw_val.json')
            self.extra_ann_exist = False
            if not osp.isfile(self.extra_ann_file):
                print_log(
                    'extra_ann_file does not exist, prepare to collect '
                    'image height and width...',
                    level=logging.INFO)
                self.extra_anns = {}
            else:
                self.extra_ann_exist = True
                self.extra_anns = load(self.extra_ann_file)
        super().__init__(data_root=data_root, ann_file=ann_file, **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from the annotation file named ``self.ann_file``.

        Returns:
            List[dict]: A list of annotations.
        """  # noqa: E501
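        # The CrowdHuman annotation file (an .odgt file) contains one
        # standalone JSON object per line. A shortened, illustrative line:
        #   {"ID": "<image id>", "gtboxes": [{"tag": "person",
        #    "fbox": [x, y, w, h], "hbox": [...], "vbox": [...],
        #    "extra": {"ignore": 0}}, ...]}
        # Only the 'ID' and 'gtboxes' keys are used below.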
        anno_strs = get_text(
            self.ann_file, backend_args=self.backend_args).strip().split('\n')
        print_log('loading CrowdHuman annotation...', level=logging.INFO)
        data_list = []
        prog_bar = ProgressBar(len(anno_strs))
        for i, anno_str in enumerate(anno_strs):
            anno_dict = json.loads(anno_str)
            parsed_data_info = self.parse_data_info(anno_dict)
            data_list.append(parsed_data_info)
            prog_bar.update()
        if not self.extra_ann_exist and get_rank() == 0:
            # TODO: support file client
            try:
                dump(self.extra_anns, self.extra_ann_file, file_format='json')
            except:  # noqa
                warnings.warn(
                    'Cache files can not be saved automatically! To speed up '
                    'loading the dataset, please manually generate the cache '
                    'file with tools/misc/get_crowdhuman_id_hw.py')
            print_log(
                f'\nsave extra_ann_file in {self.data_root}',
                level=logging.INFO)
        del self.extra_anns
        print_log('\nDone', level=logging.INFO)
        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        data_info = {}
        img_path = osp.join(self.data_prefix['img'],
                            f"{raw_data_info['ID']}.jpg")
        data_info['img_path'] = img_path
        data_info['img_id'] = raw_data_info['ID']

        if not self.extra_ann_exist:
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, backend='cv2')
            data_info['height'], data_info['width'] = img.shape[:2]
            self.extra_anns[raw_data_info['ID']] = img.shape[:2]
            del img, img_bytes
        else:
            data_info['height'], data_info['width'] = self.extra_anns[
                raw_data_info['ID']]

        instances = []
        for i, ann in enumerate(raw_data_info['gtboxes']):
            instance = {}
            if ann['tag'] not in self.metainfo['classes']:
                instance['bbox_label'] = -1
                instance['ignore_flag'] = 1
            else:
                instance['bbox_label'] = self.metainfo['classes'].index(
                    ann['tag'])
                instance['ignore_flag'] = 0
            if 'extra' in ann:
                if 'ignore' in ann['extra']:
                    if ann['extra']['ignore'] != 0:
                        instance['bbox_label'] = -1
                        instance['ignore_flag'] = 1
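
            # CrowdHuman stores boxes as [x, y, w, h]; convert the full-body
            # box to [x1, y1, x2, y2], the format expected downstream.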
            x1, y1, w, h = ann['fbox']
            bbox = [x1, y1, x1 + w, y1 + h]
            instance['bbox'] = bbox

            # Record the full bbox (fbox), head bbox (hbox) and visible
            # bbox (vbox) as additional information. If you need this
            # information, handle it in the data pipeline instead of
            # overriding CrowdHumanDataset.
            instance['fbox'] = bbox
            hbox = ann['hbox']
            instance['hbox'] = [
                hbox[0], hbox[1], hbox[0] + hbox[2], hbox[1] + hbox[3]
            ]
            vbox = ann['vbox']
            instance['vbox'] = [
                vbox[0], vbox[1], vbox[0] + vbox[2], vbox[1] + vbox[3]
            ]
            instances.append(instance)

        data_info['instances'] = instances
        return data_info
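
# For quick reference, a single parsed ``data_info`` roughly looks like the
# sketch below (paths and numbers are illustrative, not taken from the real
# dataset):
#
#   {
#       'img_path': 'data/CrowdHuman/Images/<ID>.jpg',
#       'img_id': '<ID>',
#       'height': 768,
#       'width': 1024,
#       'instances': [{
#           'bbox_label': 0,
#           'ignore_flag': 0,
#           'bbox': [x1, y1, x2, y2],
#           'fbox': [x1, y1, x2, y2],
#           'hbox': [x1, y1, x2, y2],
#           'vbox': [x1, y1, x2, y2],
#       }, ...]
#   }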