dump_proposals_metric.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import os.path as osp
  4. from typing import Optional, Sequence
  5. from mmengine.dist import is_main_process
  6. from mmengine.evaluator import BaseMetric
  7. from mmengine.fileio import dump
  8. from mmengine.logging import MMLogger
  9. from mmengine.structures import InstanceData
  10. from mmdet.registry import METRICS
  11. @METRICS.register_module()
  12. class DumpProposals(BaseMetric):
  13. """Dump proposals pseudo metric.
  14. Args:
  15. output_dir (str): The root directory for ``proposals_file``.
  16. Defaults to ''.
  17. proposals_file (str): Proposals file path. Defaults to 'proposals.pkl'.
  18. num_max_proposals (int, optional): Maximum number of proposals to dump.
  19. If not specified, all proposals will be dumped.
  20. file_client_args (dict, optional): Arguments to instantiate the
  21. corresponding backend in mmdet <= 3.0.0rc6. Defaults to None.
  22. backend_args (dict, optional): Arguments to instantiate the
  23. corresponding backend. Defaults to None.
  24. collect_device (str): Device name used for collecting results from
  25. different ranks during distributed training. Must be 'cpu' or
  26. 'gpu'. Defaults to 'cpu'.
  27. prefix (str, optional): The prefix that will be added in the metric
  28. names to disambiguate homonymous metrics of different evaluators.
  29. If prefix is not provided in the argument, self.default_prefix
  30. will be used instead. Defaults to None.
  31. """
  32. default_prefix: Optional[str] = 'dump_proposals'
  33. def __init__(self,
  34. output_dir: str = '',
  35. proposals_file: str = 'proposals.pkl',
  36. num_max_proposals: Optional[int] = None,
  37. file_client_args: dict = None,
  38. backend_args: dict = None,
  39. collect_device: str = 'cpu',
  40. prefix: Optional[str] = None) -> None:
  41. super().__init__(collect_device=collect_device, prefix=prefix)
  42. self.num_max_proposals = num_max_proposals
  43. # TODO: update after mmengine finish refactor fileio.
  44. self.backend_args = backend_args
  45. if file_client_args is not None:
  46. raise RuntimeError(
  47. 'The `file_client_args` is deprecated, '
  48. 'please use `backend_args` instead, please refer to'
  49. 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
  50. )
  51. self.output_dir = output_dir
  52. assert proposals_file.endswith(('.pkl', '.pickle')), \
  53. 'The output file must be a pkl file.'
  54. self.proposals_file = os.path.join(self.output_dir, proposals_file)
  55. if is_main_process():
  56. os.makedirs(self.output_dir, exist_ok=True)
  57. def process(self, data_batch: Sequence[dict],
  58. data_samples: Sequence[dict]) -> None:
  59. """Process one batch of data samples and predictions. The processed
  60. results should be stored in ``self.results``, which will be used to
  61. compute the metrics when all batches have been processed.
  62. Args:
  63. data_batch (dict): A batch of data from the dataloader.
  64. data_samples (Sequence[dict]): A batch of data samples that
  65. contain annotations and predictions.
  66. """
  67. for data_sample in data_samples:
  68. pred = data_sample['pred_instances']
  69. # `bboxes` is sorted by `scores`
  70. ranked_scores, rank_inds = pred['scores'].sort(descending=True)
  71. ranked_bboxes = pred['bboxes'][rank_inds, :]
  72. ranked_bboxes = ranked_bboxes.cpu().numpy()
  73. ranked_scores = ranked_scores.cpu().numpy()
  74. pred_instance = InstanceData()
  75. pred_instance.bboxes = ranked_bboxes
  76. pred_instance.scores = ranked_scores
  77. if self.num_max_proposals is not None:
  78. pred_instance = pred_instance[:self.num_max_proposals]
  79. img_path = data_sample['img_path']
  80. # `file_name` is the key to obtain the proposals from the
  81. # `proposals_list`.
  82. file_name = osp.join(
  83. osp.split(osp.split(img_path)[0])[-1],
  84. osp.split(img_path)[-1])
  85. result = {file_name: pred_instance}
  86. self.results.append(result)
  87. def compute_metrics(self, results: list) -> dict:
  88. """Dump the processed results.
  89. Args:
  90. results (list): The processed results of each batch.
  91. Returns:
  92. dict: An empty dict.
  93. """
  94. logger: MMLogger = MMLogger.get_current_instance()
  95. dump_results = {}
  96. for result in results:
  97. dump_results.update(result)
  98. dump(
  99. dump_results,
  100. file=self.proposals_file,
  101. backend_args=self.backend_args)
  102. logger.info(f'Results are saved at {self.proposals_file}')
  103. return {}