base_det_dataset.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. from typing import List, Optional
  4. from mmengine.dataset import BaseDataset
  5. from mmengine.fileio import load
  6. from mmengine.utils import is_abs
  7. from ..registry import DATASETS
  8. @DATASETS.register_module()
  9. class BaseDetDataset(BaseDataset):
  10. """Base dataset for detection.
  11. Args:
  12. proposal_file (str, optional): Proposals file path. Defaults to None.
  13. file_client_args (dict): Arguments to instantiate the
  14. corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
  15. backend_args (dict, optional): Arguments to instantiate the
  16. corresponding backend. Defaults to None.
  17. """
  18. def __init__(self,
  19. *args,
  20. seg_map_suffix: str = '.png',
  21. proposal_file: Optional[str] = None,
  22. file_client_args: dict = None,
  23. backend_args: dict = None,
  24. **kwargs) -> None:
  25. self.seg_map_suffix = seg_map_suffix
  26. self.proposal_file = proposal_file
  27. self.backend_args = backend_args
  28. if file_client_args is not None:
  29. raise RuntimeError(
  30. 'The `file_client_args` is deprecated, '
  31. 'please use `backend_args` instead, please refer to'
  32. 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
  33. )
  34. super().__init__(*args, **kwargs)
  35. def full_init(self) -> None:
  36. """Load annotation file and set ``BaseDataset._fully_initialized`` to
  37. True.
  38. If ``lazy_init=False``, ``full_init`` will be called during the
  39. instantiation and ``self._fully_initialized`` will be set to True. If
  40. ``obj._fully_initialized=False``, the class method decorated by
  41. ``force_full_init`` will call ``full_init`` automatically.
  42. Several steps to initialize annotation:
  43. - load_data_list: Load annotations from annotation file.
  44. - load_proposals: Load proposals from proposal file, if
  45. `self.proposal_file` is not None.
  46. - filter data information: Filter annotations according to
  47. filter_cfg.
  48. - slice_data: Slice dataset according to ``self._indices``
  49. - serialize_data: Serialize ``self.data_list`` if
  50. ``self.serialize_data`` is True.
  51. """
  52. if self._fully_initialized:
  53. return
  54. # load data information
  55. self.data_list = self.load_data_list()
  56. # get proposals from file
  57. if self.proposal_file is not None:
  58. self.load_proposals()
  59. # filter illegal data, such as data that has no annotations.
  60. self.data_list = self.filter_data()
  61. # Get subset data according to indices.
  62. if self._indices is not None:
  63. self.data_list = self._get_unserialized_subset(self._indices)
  64. # serialize data_list
  65. if self.serialize_data:
  66. self.data_bytes, self.data_address = self._serialize_data()
  67. self._fully_initialized = True
  68. def load_proposals(self) -> None:
  69. """Load proposals from proposals file.
  70. The `proposals_list` should be a dict[img_path: proposals]
  71. with the same length as `data_list`. And the `proposals` should be
  72. a `dict` or :obj:`InstanceData` usually contains following keys.
  73. - bboxes (np.ndarry): Has a shape (num_instances, 4),
  74. the last dimension 4 arrange as (x1, y1, x2, y2).
  75. - scores (np.ndarry): Classification scores, has a shape
  76. (num_instance, ).
  77. """
  78. # TODO: Add Unit Test after fully support Dump-Proposal Metric
  79. if not is_abs(self.proposal_file):
  80. self.proposal_file = osp.join(self.data_root, self.proposal_file)
  81. proposals_list = load(
  82. self.proposal_file, backend_args=self.backend_args)
  83. assert len(self.data_list) == len(proposals_list)
  84. for data_info in self.data_list:
  85. img_path = data_info['img_path']
  86. # `file_name` is the key to obtain the proposals from the
  87. # `proposals_list`.
  88. file_name = osp.join(
  89. osp.split(osp.split(img_path)[0])[-1],
  90. osp.split(img_path)[-1])
  91. proposals = proposals_list[file_name]
  92. data_info['proposals'] = proposals
  93. def get_cat_ids(self, idx: int) -> List[int]:
  94. """Get COCO category ids by index.
  95. Args:
  96. idx (int): Index of data.
  97. Returns:
  98. List[int]: All categories in the image of specified index.
  99. """
  100. instances = self.get_data_info(idx)['instances']
  101. return [instance['bbox_label'] for instance in instances]