123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import os
- import os.path as osp
- from typing import Optional, Sequence
- from mmengine.dist import is_main_process
- from mmengine.evaluator import BaseMetric
- from mmengine.fileio import dump
- from mmengine.logging import MMLogger
- from mmengine.structures import InstanceData
- from mmdet.registry import METRICS
@METRICS.register_module()
class DumpProposals(BaseMetric):
    """Dump proposals pseudo metric.

    No score is computed. Instead, the predicted proposals of every sample
    are ranked by score, optionally truncated, and merged into one dict.
    When evaluation finishes, that dict is dumped to a pickle file. Each
    entry is keyed by ``"<parent_dir>/<image_file_name>"``, derived from
    the sample's ``img_path``.

    Args:
        output_dir (str): The root directory for ``proposals_file``.
            An empty string leaves ``proposals_file`` as a relative path and
            no directory is created. Defaults to ''.
        proposals_file (str): Proposals file path. Must end with ``.pkl``
            or ``.pickle``. Defaults to 'proposals.pkl'.
        num_max_proposals (int, optional): Maximum number of proposals to
            dump per image. The highest-scoring ones are kept. If not
            specified, all proposals will be dumped.
        file_client_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Deprecated: any
            value other than None raises ``RuntimeError``. Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix: Optional[str] = 'dump_proposals'

    def __init__(self,
                 output_dir: str = '',
                 proposals_file: str = 'proposals.pkl',
                 num_max_proposals: Optional[int] = None,
                 file_client_args: Optional[dict] = None,
                 backend_args: Optional[dict] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.num_max_proposals = num_max_proposals
        # TODO: update after mmengine finish refactor fileio.
        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )
        self.output_dir = output_dir
        assert proposals_file.endswith(('.pkl', '.pickle')), \
            'The output file must be a pkl file.'
        self.proposals_file = osp.join(self.output_dir, proposals_file)
        # `os.makedirs('')` raises FileNotFoundError (even with
        # exist_ok=True), so only create a directory when one was given.
        # Only the main process creates it.
        if self.output_dir and is_main_process():
            os.makedirs(self.output_dir, exist_ok=True)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Collect the ranked proposals of one batch into ``self.results``.

        For each sample, one single-entry ``{file_name: InstanceData}`` dict
        is appended. The ``InstanceData`` holds numpy ``bboxes`` and
        ``scores`` sorted by descending score.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
                Not used by this metric.
            data_samples (Sequence[dict]): A batch of data samples. Each must
                provide ``pred_instances`` (with ``bboxes`` and ``scores``)
                and ``img_path``.
        """
        for data_sample in data_samples:
            pred = data_sample['pred_instances']
            # Rank by descending score, so the truncation below keeps the
            # highest-scoring proposals.
            ranked_scores, rank_inds = pred['scores'].sort(descending=True)
            ranked_bboxes = pred['bboxes'][rank_inds, :]

            pred_instance = InstanceData()
            pred_instance.bboxes = ranked_bboxes.cpu().numpy()
            pred_instance.scores = ranked_scores.cpu().numpy()
            if self.num_max_proposals is not None:
                pred_instance = pred_instance[:self.num_max_proposals]

            img_path = data_sample['img_path']
            # `file_name` ("<parent_dir>/<file_name>") is the key used to
            # obtain the proposals from the dumped `proposals_list`.
            file_name = osp.join(
                osp.basename(osp.dirname(img_path)), osp.basename(img_path))
            self.results.append({file_name: pred_instance})

    def compute_metrics(self, results: list) -> dict:
        """Merge the per-sample results and dump them to ``proposals_file``.

        Args:
            results (list): The processed results of each batch, i.e. the
                single-entry ``{file_name: InstanceData}`` dicts appended by
                ``process``.

        Returns:
            dict: An empty dict, since no metric value is computed.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        # Merge in order; on duplicate keys the later result wins.
        dump_results = {}
        for result in results:
            dump_results.update(result)
        dump(
            dump_results,
            file=self.proposals_file,
            backend_args=self.backend_args)
        logger.info(f'Results are saved at {self.proposals_file}')
        return {}
|