wider_face.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import xml.etree.ElementTree as ET
  4. from mmengine.dist import is_main_process
  5. from mmengine.fileio import get_local_path, list_from_file
  6. from mmengine.utils import ProgressBar
  7. from mmdet.registry import DATASETS
  8. from mmdet.utils.typing_utils import List, Union
  9. from .xml_style import XMLDataset
  10. @DATASETS.register_module()
  11. class WIDERFaceDataset(XMLDataset):
  12. """Reader for the WIDER Face dataset in PASCAL VOC format.
  13. Conversion scripts can be found in
  14. https://github.com/sovrasov/wider-face-pascal-voc-annotations
  15. """
  16. METAINFO = {'classes': ('face', ), 'palette': [(0, 255, 0)]}
  17. def load_data_list(self) -> List[dict]:
  18. """Load annotation from XML style ann_file.
  19. Returns:
  20. list[dict]: Annotation info from XML file.
  21. """
  22. assert self._metainfo.get('classes', None) is not None, \
  23. 'classes in `XMLDataset` can not be None.'
  24. self.cat2label = {
  25. cat: i
  26. for i, cat in enumerate(self._metainfo['classes'])
  27. }
  28. data_list = []
  29. img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
  30. # loading process takes around 10 mins
  31. if is_main_process():
  32. prog_bar = ProgressBar(len(img_ids))
  33. for img_id in img_ids:
  34. raw_img_info = {}
  35. raw_img_info['img_id'] = img_id
  36. raw_img_info['file_name'] = f'{img_id}.jpg'
  37. parsed_data_info = self.parse_data_info(raw_img_info)
  38. data_list.append(parsed_data_info)
  39. if is_main_process():
  40. prog_bar.update()
  41. return data_list
  42. def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
  43. """Parse raw annotation to target format.
  44. Args:
  45. img_info (dict): Raw image information, usually it includes
  46. `img_id`, `file_name`, and `xml_path`.
  47. Returns:
  48. Union[dict, List[dict]]: Parsed annotation.
  49. """
  50. data_info = {}
  51. img_id = img_info['img_id']
  52. xml_path = osp.join(self.data_prefix['img'], 'Annotations',
  53. f'{img_id}.xml')
  54. data_info['img_id'] = img_id
  55. data_info['xml_path'] = xml_path
  56. # deal with xml file
  57. with get_local_path(
  58. xml_path, backend_args=self.backend_args) as local_path:
  59. raw_ann_info = ET.parse(local_path)
  60. root = raw_ann_info.getroot()
  61. size = root.find('size')
  62. width = int(size.find('width').text)
  63. height = int(size.find('height').text)
  64. folder = root.find('folder').text
  65. img_path = osp.join(self.data_prefix['img'], folder,
  66. img_info['file_name'])
  67. data_info['img_path'] = img_path
  68. data_info['height'] = height
  69. data_info['width'] = width
  70. # Coordinates are in range [0, width - 1 or height - 1]
  71. data_info['instances'] = self._parse_instance_info(
  72. raw_ann_info, minus_one=False)
  73. return data_info