123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
- from typing import List, Optional
- from mmengine.dataset import BaseDataset
- from mmengine.fileio import load
- from mmengine.utils import is_abs
- from ..registry import DATASETS
@DATASETS.register_module()
class BaseDetDataset(BaseDataset):
    """Base dataset for detection.

    Args:
        seg_map_suffix (str): Suffix stored as ``self.seg_map_suffix``;
            presumably the file extension of segmentation maps used by
            subclasses (verify against subclasses). Defaults to '.png'.
        proposal_file (str, optional): Proposals file path. A relative path
            is joined with ``data_root`` (when set) before loading.
            Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Deprecated: passing
            anything other than None raises ``RuntimeError``; use
            ``backend_args`` instead. Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Forwarded to ``load`` when reading the
            proposal file. Defaults to None.

    Raises:
        RuntimeError: If ``file_client_args`` is not None.
    """

    def __init__(self,
                 *args,
                 seg_map_suffix: str = '.png',
                 proposal_file: Optional[str] = None,
                 file_client_args: Optional[dict] = None,
                 backend_args: Optional[dict] = None,
                 **kwargs) -> None:
        # These attributes must be set before ``super().__init__``: with
        # ``lazy_init=False`` the parent constructor calls ``full_init``,
        # which reads ``proposal_file`` and ``backend_args``.
        self.seg_map_suffix = seg_map_suffix
        self.proposal_file = proposal_file
        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )
        super().__init__(*args, **kwargs)

    def full_init(self) -> None:
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        If ``lazy_init=False``, ``full_init`` will be called during the
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_data_list: Load annotations from annotation file.
            - load_proposals: Load proposals from proposal file, if
              ``self.proposal_file`` is not None.
            - filter data information: Filter annotations according to
              filter_cfg.
            - slice_data: Slice dataset according to ``self._indices``.
            - serialize_data: Serialize ``self.data_list`` if
              ``self.serialize_data`` is True.
        """
        if self._fully_initialized:
            return
        # load data information
        self.data_list = self.load_data_list()
        # Get proposals from file. This runs before filtering, so the
        # proposal count is checked against the unfiltered data list.
        if self.proposal_file is not None:
            self.load_proposals()
        # filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()

        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # serialize data_list
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True

    def load_proposals(self) -> None:
        """Load proposals from ``self.proposal_file`` into ``self.data_list``.

        The loaded ``proposals_list`` must have the same length as
        ``data_list``. It is indexed by ``'<parent dir>/<file name>'`` derived
        from each sample's ``img_path`` (e.g. an ``img_path`` of
        ``'data/coco/val2017/a.jpg'`` is looked up as ``'val2017/a.jpg'``,
        joined with the OS path separator). Each value (``proposals``) is
        usually a ``dict`` or :obj:`InstanceData` containing the following
        keys.

            - bboxes (np.ndarray): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - scores (np.ndarray): Classification scores, has a shape
              (num_instance, ).

        The looked-up value is stored in each data info under
        ``'proposals'``.

        Raises:
            AssertionError: If the number of proposals differs from the
                number of samples in ``data_list``.
            KeyError: If a sample's key is missing from ``proposals_list``.
        """
        # TODO: Add Unit Test after fully support Dump-Proposal Metric
        # Resolve a relative path against ``data_root``. Skip the join when
        # ``data_root`` is unset: ``osp.join(None, ...)`` raises TypeError.
        if self.data_root and not is_abs(self.proposal_file):
            self.proposal_file = osp.join(self.data_root, self.proposal_file)
        proposals_list = load(
            self.proposal_file, backend_args=self.backend_args)
        assert len(self.data_list) == len(proposals_list), (
            f'The number of proposals ({len(proposals_list)}) does not match '
            f'the number of samples ({len(self.data_list)}).')
        for data_info in self.data_list:
            img_path = data_info['img_path']
            # ``file_name`` ('<parent dir>/<file name>') is the key used to
            # obtain the proposals from ``proposals_list``.
            file_name = osp.join(
                osp.basename(osp.dirname(img_path)), osp.basename(img_path))
            data_info['proposals'] = proposals_list[file_name]

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get the category labels of all instances by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: The ``bbox_label`` of every instance in the image of
            the specified index, in instance order (duplicates are kept).
        """
        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]
|