xml_style.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import xml.etree.ElementTree as ET
  4. from typing import List, Optional, Union
  5. import mmcv
  6. from mmengine.fileio import get, get_local_path, list_from_file
  7. from mmdet.registry import DATASETS
  8. from .base_det_dataset import BaseDetDataset
  9. @DATASETS.register_module()
  10. class XMLDataset(BaseDetDataset):
  11. """XML dataset for detection.
  12. Args:
  13. img_subdir (str): Subdir where images are stored. Default: JPEGImages.
  14. ann_subdir (str): Subdir where annotations are. Default: Annotations.
  15. backend_args (dict, optional): Arguments to instantiate the
  16. corresponding backend. Defaults to None.
  17. """
  18. def __init__(self,
  19. img_subdir: str = 'JPEGImages',
  20. ann_subdir: str = 'Annotations',
  21. **kwargs) -> None:
  22. self.img_subdir = img_subdir
  23. self.ann_subdir = ann_subdir
  24. super().__init__(**kwargs)
  25. @property
  26. def sub_data_root(self) -> str:
  27. """Return the sub data root."""
  28. return self.data_prefix.get('sub_data_root', '')
  29. def load_data_list(self) -> List[dict]:
  30. """Load annotation from XML style ann_file.
  31. Returns:
  32. list[dict]: Annotation info from XML file.
  33. """
  34. assert self._metainfo.get('classes', None) is not None, \
  35. '`classes` in `XMLDataset` can not be None.'
  36. self.cat2label = {
  37. cat: i
  38. for i, cat in enumerate(self._metainfo['classes'])
  39. }
  40. data_list = []
  41. img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
  42. for img_id in img_ids:
  43. file_name = osp.join(self.img_subdir, f'{img_id}.jpg')
  44. xml_path = osp.join(self.sub_data_root, self.ann_subdir,
  45. f'{img_id}.xml')
  46. raw_img_info = {}
  47. raw_img_info['img_id'] = img_id
  48. raw_img_info['file_name'] = file_name
  49. raw_img_info['xml_path'] = xml_path
  50. parsed_data_info = self.parse_data_info(raw_img_info)
  51. data_list.append(parsed_data_info)
  52. return data_list
  53. @property
  54. def bbox_min_size(self) -> Optional[str]:
  55. """Return the minimum size of bounding boxes in the images."""
  56. if self.filter_cfg is not None:
  57. return self.filter_cfg.get('bbox_min_size', None)
  58. else:
  59. return None
  60. def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
  61. """Parse raw annotation to target format.
  62. Args:
  63. img_info (dict): Raw image information, usually it includes
  64. `img_id`, `file_name`, and `xml_path`.
  65. Returns:
  66. Union[dict, List[dict]]: Parsed annotation.
  67. """
  68. data_info = {}
  69. img_path = osp.join(self.sub_data_root, img_info['file_name'])
  70. data_info['img_path'] = img_path
  71. data_info['img_id'] = img_info['img_id']
  72. data_info['xml_path'] = img_info['xml_path']
  73. # deal with xml file
  74. with get_local_path(
  75. img_info['xml_path'],
  76. backend_args=self.backend_args) as local_path:
  77. raw_ann_info = ET.parse(local_path)
  78. root = raw_ann_info.getroot()
  79. size = root.find('size')
  80. if size is not None:
  81. width = int(size.find('width').text)
  82. height = int(size.find('height').text)
  83. else:
  84. img_bytes = get(img_path, backend_args=self.backend_args)
  85. img = mmcv.imfrombytes(img_bytes, backend='cv2')
  86. height, width = img.shape[:2]
  87. del img, img_bytes
  88. data_info['height'] = height
  89. data_info['width'] = width
  90. data_info['instances'] = self._parse_instance_info(
  91. raw_ann_info, minus_one=True)
  92. return data_info
  93. def _parse_instance_info(self,
  94. raw_ann_info: ET,
  95. minus_one: bool = True) -> List[dict]:
  96. """parse instance information.
  97. Args:
  98. raw_ann_info (ElementTree): ElementTree object.
  99. minus_one (bool): Whether to subtract 1 from the coordinates.
  100. Defaults to True.
  101. Returns:
  102. List[dict]: List of instances.
  103. """
  104. instances = []
  105. for obj in raw_ann_info.findall('object'):
  106. instance = {}
  107. name = obj.find('name').text
  108. if name not in self._metainfo['classes']:
  109. continue
  110. difficult = obj.find('difficult')
  111. difficult = 0 if difficult is None else int(difficult.text)
  112. bnd_box = obj.find('bndbox')
  113. bbox = [
  114. int(float(bnd_box.find('xmin').text)),
  115. int(float(bnd_box.find('ymin').text)),
  116. int(float(bnd_box.find('xmax').text)),
  117. int(float(bnd_box.find('ymax').text))
  118. ]
  119. # VOC needs to subtract 1 from the coordinates
  120. if minus_one:
  121. bbox = [x - 1 for x in bbox]
  122. ignore = False
  123. if self.bbox_min_size is not None:
  124. assert not self.test_mode
  125. w = bbox[2] - bbox[0]
  126. h = bbox[3] - bbox[1]
  127. if w < self.bbox_min_size or h < self.bbox_min_size:
  128. ignore = True
  129. if difficult or ignore:
  130. instance['ignore_flag'] = 1
  131. else:
  132. instance['ignore_flag'] = 0
  133. instance['bbox'] = bbox
  134. instance['bbox_label'] = self.cat2label[name]
  135. instances.append(instance)
  136. return instances
  137. def filter_data(self) -> List[dict]:
  138. """Filter annotations according to filter_cfg.
  139. Returns:
  140. List[dict]: Filtered results.
  141. """
  142. if self.test_mode:
  143. return self.data_list
  144. filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
  145. if self.filter_cfg is not None else False
  146. min_size = self.filter_cfg.get('min_size', 0) \
  147. if self.filter_cfg is not None else 0
  148. valid_data_infos = []
  149. for i, data_info in enumerate(self.data_list):
  150. width = data_info['width']
  151. height = data_info['height']
  152. if filter_empty_gt and len(data_info['instances']) == 0:
  153. continue
  154. if min(width, height) >= min_size:
  155. valid_data_infos.append(data_info)
  156. return valid_data_infos